@@ -54,6 +54,7 @@
ResponseReasoningSummaryTextDeltaEvent,
ResponseFunctionCallArgumentsDeltaEvent,
)
from openai import NOT_GIVEN  # sentinel used below; may already be imported elsewhere in this file
from openai.types.responses.response_prompt_param import ResponsePromptParam

# AgentEx SDK imports
from agentex.lib import adk
@@ -481,12 +482,23 @@ async def get_response(
output_schema: Optional[AgentOutputSchemaBase],
handoffs: list[Handoff],
tracing: ModelTracing, # noqa: ARG002
*,
previous_response_id: Optional[str] = None,
conversation_id: Optional[str] = None,
prompt: Optional[ResponsePromptParam] = None,
**kwargs,  # noqa: ARG002 -- must follow the keyword-only parameters
) -> ModelResponse:
"""Get a non-streaming response from the model with streaming to Redis.

This method is used by Temporal activities and needs to return a complete
response, but we stream the response to Redis while generating it.

``previous_response_id``, ``conversation_id``, and ``prompt`` are all
Responses API server-state parameters threaded through by the OpenAI
Agents SDK. Each is forwarded to ``responses.create`` only when
explicitly set — defaults resolve to ``NOT_GIVEN`` and are omitted from
the request body. Not all OpenAI-compatible backends recognize these
fields, so callers on alternative providers see no wire-level change
unless they opt in.
"""

task_id = streaming_task_id.get()
@@ -575,6 +587,11 @@ async def get_response(
if model_settings.top_logprobs is not None:
extra_args["top_logprobs"] = model_settings.top_logprobs

# Opt-in prompt_cache_key: forwarded only when the caller supplies it via
# model_settings.extra_args["prompt_cache_key"]. Not all OpenAI-compatible
# endpoints recognize this parameter, so we don't auto-inject a default.
prompt_cache_key = extra_args.pop("prompt_cache_key", NOT_GIVEN)
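# Sketch of the opt-in from a hypothetical caller, assuming the Agents SDK
# ``ModelSettings.extra_args`` passthrough; the key name is the only
# contract this code relies on (the value shown is made up):
#
#     model_settings = ModelSettings(
#         extra_args={"prompt_cache_key": "agentex-task-prompts-v1"},
#     )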

# Create the response stream using Responses API
logger.debug(f"[TemporalStreamingModel] Creating response stream with Responses API")
stream = await self.client.responses.create( # type: ignore[call-overload]
@@ -605,12 +622,20 @@
extra_headers=model_settings.extra_headers,
extra_query=model_settings.extra_query,
extra_body=model_settings.extra_body,
prompt_cache_key=prompt_cache_key,
previous_response_id=self._non_null_or_not_given(previous_response_id),
# The SDK's abstract method names this parameter ``conversation_id``;
# the Responses API endpoint kwarg is ``conversation`` (it accepts a
# str id directly).
conversation=self._non_null_or_not_given(conversation_id),
prompt=self._non_null_or_not_given(prompt),
# Any additional parameters from extra_args
**extra_args,
)

# Process the stream of events from Responses API
output_items = []
captured_usage = None
captured_response_id = None
current_text = ""
streaming_context = None
reasoning_context = None
@@ -821,10 +846,13 @@ async def get_response(
# Response completed
logger.debug(f"[TemporalStreamingModel] Response completed")
response = getattr(event, 'response', None)
if response is not None:
if hasattr(response, 'output'):
# Use the final output from the response
output_items = response.output
logger.debug(f"[TemporalStreamingModel] Found {len(output_items)} output items in final response")
captured_usage = getattr(response, 'usage', None)
captured_response_id = getattr(response, 'id', None)
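# For reference, the event being read here resembles the following
# (abridged; field names per the openai Python SDK, values illustrative):
#
#     ResponseCompletedEvent(
#         type="response.completed",
#         response=Response(id="resp_abc123", output=[...],
#                           usage=ResponseUsage(...)),
#     )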

# End of event processing loop - close any open contexts
if reasoning_context:
@@ -863,14 +891,33 @@ async def get_response(
)
response_output.append(message)

# Use the real usage from the streaming Response if available;
# fall back to zeros only when the stream ended without a
# ResponseCompletedEvent (error paths).
if captured_usage is not None:
usage = Usage(
input_tokens=captured_usage.input_tokens,
output_tokens=captured_usage.output_tokens,
total_tokens=captured_usage.total_tokens,
input_tokens_details=InputTokensDetails(
cached_tokens=getattr(
captured_usage.input_tokens_details, "cached_tokens", 0
),
),
output_tokens_details=OutputTokensDetails(
reasoning_tokens=getattr(
captured_usage.output_tokens_details, "reasoning_tokens", 0
),
),
)
else:
usage = Usage(
input_tokens=0,
output_tokens=0,
total_tokens=0,
input_tokens_details=InputTokensDetails(cached_tokens=0),
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
)
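# A populated ``captured_usage`` resembles the following (illustrative
# values; the nested details objects are the fields read above):
#
#     ResponseUsage(input_tokens=52, output_tokens=18, total_tokens=70,
#                   input_tokens_details=InputTokensDetails(cached_tokens=0),
#                   output_tokens_details=OutputTokensDetails(reasoning_tokens=6))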

# Serialize response output items for span tracing
new_items = []
@@ -919,21 +966,34 @@ async def get_response(
output_data = {
"new_items": new_items,
"final_output": final_output,
"usage": {
"input_tokens": usage.input_tokens,
"output_tokens": usage.output_tokens,
"total_tokens": usage.total_tokens,
"cached_input_tokens": usage.input_tokens_details.cached_tokens,
"reasoning_tokens": usage.output_tokens_details.reasoning_tokens,
},
}
# Include tool calls if any were in the input
if tool_calls:
output_data["tool_calls"] = tool_calls
# Include tool outputs if any were processed
if tool_outputs:
output_data["tool_outputs"] = tool_outputs

span.output = output_data

# Return the response. response_id is the server-issued id from
# ResponseCompletedEvent.response.id, or None when the stream ended
# without a completed event (error path) — matching the documented
# `str | None` contract on `ModelResponse.response_id`. Returning
# None lets callers use it safely as `previous_response_id` for
# multi-turn chaining; a fabricated UUID would 400 against any real
# server.
return ModelResponse(
output=response_output,
usage=usage,
response_id=f"resp_{uuid.uuid4().hex[:8]}",
response_id=captured_response_id,
)

except Exception as e: