diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py index d301af09a..4f18ae379 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py @@ -54,6 +54,7 @@ ResponseReasoningSummaryTextDeltaEvent, ResponseFunctionCallArgumentsDeltaEvent, ) +from openai.types.responses.response_prompt_param import ResponsePromptParam # AgentEx SDK imports from agentex.lib import adk @@ -481,12 +482,23 @@ async def get_response( output_schema: Optional[AgentOutputSchemaBase], handoffs: list[Handoff], tracing: ModelTracing, # noqa: ARG002 - **kwargs, # noqa: ARG002 + *, + previous_response_id: Optional[str] = None, + conversation_id: Optional[str] = None, + prompt: Optional[ResponsePromptParam] = None, ) -> ModelResponse: """Get a non-streaming response from the model with streaming to Redis. This method is used by Temporal activities and needs to return a complete response, but we stream the response to Redis while generating it. + + ``previous_response_id``, ``conversation_id``, and ``prompt`` are all + Responses API server-state parameters threaded through by the OpenAI + Agents SDK. Each is forwarded to ``responses.create`` only when + explicitly set — defaults resolve to ``NOT_GIVEN`` and are omitted from + the request body. Not all OpenAI-compatible backends recognize these + fields, so callers on alternative providers see no wire-level change + unless they opt in. """ task_id = streaming_task_id.get() @@ -575,6 +587,11 @@ async def get_response( if model_settings.top_logprobs is not None: extra_args["top_logprobs"] = model_settings.top_logprobs + # Opt-in prompt_cache_key: forwarded only when the caller supplies it via + # model_settings.extra_args["prompt_cache_key"]. Not all OpenAI-compatible + # endpoints recognize this parameter, so we don't auto-inject a default. + prompt_cache_key = extra_args.pop("prompt_cache_key", NOT_GIVEN) + # Create the response stream using Responses API logger.debug(f"[TemporalStreamingModel] Creating response stream with Responses API") stream = await self.client.responses.create( # type: ignore[call-overload] @@ -605,12 +622,20 @@ async def get_response( extra_headers=model_settings.extra_headers, extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, + prompt_cache_key=prompt_cache_key, + previous_response_id=self._non_null_or_not_given(previous_response_id), + # SDK abstract names this conversation_id; the Responses API + # endpoint kwarg is `conversation` (accepts a str id directly). 
+ conversation=self._non_null_or_not_given(conversation_id), + prompt=self._non_null_or_not_given(prompt), # Any additional parameters from extra_args **extra_args, ) # Process the stream of events from Responses API output_items = [] + captured_usage = None + captured_response_id = None current_text = "" streaming_context = None reasoning_context = None @@ -821,10 +846,13 @@ async def get_response( # Response completed logger.debug(f"[TemporalStreamingModel] Response completed") response = getattr(event, 'response', None) - if response and hasattr(response, 'output'): - # Use the final output from the response - output_items = response.output - logger.debug(f"[TemporalStreamingModel] Found {len(output_items)} output items in final response") + if response is not None: + if hasattr(response, 'output'): + # Use the final output from the response + output_items = response.output + logger.debug(f"[TemporalStreamingModel] Found {len(output_items)} output items in final response") + captured_usage = getattr(response, 'usage', None) + captured_response_id = getattr(response, 'id', None) # End of event processing loop - close any open contexts if reasoning_context: @@ -863,14 +891,33 @@ async def get_response( ) response_output.append(message) - # Create usage object - usage = Usage( - input_tokens=0, - output_tokens=0, - total_tokens=0, - input_tokens_details=InputTokensDetails(cached_tokens=0), - output_tokens_details=OutputTokensDetails(reasoning_tokens=len(''.join(reasoning_contents)) // 4), # Approximate - ) + # Use the real usage from the streaming Response if available; + # fall back to zeros only when the stream ended without a + # ResponseCompletedEvent (error paths). + if captured_usage is not None: + usage = Usage( + input_tokens=captured_usage.input_tokens, + output_tokens=captured_usage.output_tokens, + total_tokens=captured_usage.total_tokens, + input_tokens_details=InputTokensDetails( + cached_tokens=getattr( + captured_usage.input_tokens_details, "cached_tokens", 0 + ), + ), + output_tokens_details=OutputTokensDetails( + reasoning_tokens=getattr( + captured_usage.output_tokens_details, "reasoning_tokens", 0 + ), + ), + ) + else: + usage = Usage( + input_tokens=0, + output_tokens=0, + total_tokens=0, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + ) # Serialize response output items for span tracing new_items = [] @@ -919,6 +966,13 @@ async def get_response( output_data = { "new_items": new_items, "final_output": final_output, + "usage": { + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + "total_tokens": usage.total_tokens, + "cached_input_tokens": usage.input_tokens_details.cached_tokens, + "reasoning_tokens": usage.output_tokens_details.reasoning_tokens, + }, } # Include tool calls if any were in the input if tool_calls: @@ -926,14 +980,20 @@ async def get_response( # Include tool outputs if any were processed if tool_outputs: output_data["tool_outputs"] = tool_outputs - + span.output = output_data - # Return the response + # Return the response. response_id is the server-issued id from + # ResponseCompletedEvent.response.id, or None when the stream ended + # without a completed event (error path) — matching the documented + # `str | None` contract on `ModelResponse.response_id`. Returning + # None lets callers use it safely as `previous_response_id` for + # multi-turn chaining; a fabricated UUID would 400 against any real + # server. 
return ModelResponse( output=response_output, usage=usage, - response_id=f"resp_{uuid.uuid4().hex[:8]}", + response_id=captured_response_id, ) except Exception as e: diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py index 809dee2e0..97dda0e61 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py @@ -2,6 +2,7 @@ Comprehensive tests for StreamingModel with all configurations and tool types. """ +from typing import Optional from unittest.mock import AsyncMock, MagicMock import pytest @@ -20,7 +21,7 @@ class TestStreamingModelSettings: """Test that all ModelSettings parameters work with Responses API""" @pytest.mark.asyncio - async def test_temperature_setting(self, streaming_model, _streaming_context_vars, sample_task_id): + async def test_temperature_setting(self, streaming_model, _streaming_context_vars): """Test that temperature parameter is properly passed to Responses API""" streaming_model.client.responses.create = AsyncMock() @@ -43,7 +44,6 @@ async def test_temperature_setting(self, streaming_model, _streaming_context_var output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) # Verify temperature was passed correctly @@ -51,7 +51,7 @@ async def test_temperature_setting(self, streaming_model, _streaming_context_var assert create_call.kwargs['temperature'] == temp @pytest.mark.asyncio - async def test_top_p_setting(self, streaming_model, _streaming_context_vars, sample_task_id): + async def test_top_p_setting(self, streaming_model, _streaming_context_vars): """Test that top_p parameter is properly passed to Responses API""" streaming_model.client.responses.create = AsyncMock() @@ -73,7 +73,6 @@ async def test_top_p_setting(self, streaming_model, _streaming_context_vars, sam output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -81,7 +80,7 @@ async def test_top_p_setting(self, streaming_model, _streaming_context_vars, sam assert create_call.kwargs['top_p'] == expected @pytest.mark.asyncio - async def test_max_tokens_setting(self, streaming_model, _streaming_context_vars, sample_task_id): + async def test_max_tokens_setting(self, streaming_model, _streaming_context_vars): """Test that max_tokens is properly mapped to max_output_tokens""" streaming_model.client.responses.create = AsyncMock() @@ -101,14 +100,13 @@ async def test_max_tokens_setting(self, streaming_model, _streaming_context_vars output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args assert create_call.kwargs['max_output_tokens'] == 2000 @pytest.mark.asyncio - async def test_reasoning_effort_settings(self, streaming_model, _streaming_context_vars, sample_task_id): + async def test_reasoning_effort_settings(self, streaming_model, _streaming_context_vars): """Test reasoning effort levels (low/medium/high)""" streaming_model.client.responses.create = AsyncMock() @@ -131,14 +129,13 @@ async def test_reasoning_effort_settings(self, streaming_model, _streaming_conte output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args assert create_call.kwargs['reasoning'] == {"effort": effort} @pytest.mark.asyncio - async def 
test_reasoning_summary_settings(self, streaming_model, _streaming_context_vars, sample_task_id): + async def test_reasoning_summary_settings(self, streaming_model, _streaming_context_vars): """Test reasoning summary settings (auto/none)""" streaming_model.client.responses.create = AsyncMock() @@ -161,14 +158,13 @@ async def test_reasoning_summary_settings(self, streaming_model, _streaming_cont output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args assert create_call.kwargs['reasoning'] == {"effort": "medium", "summary": summary} @pytest.mark.asyncio - async def test_tool_choice_variations(self, streaming_model, _streaming_context_vars, sample_task_id, sample_function_tool): + async def test_tool_choice_variations(self, streaming_model, _streaming_context_vars, sample_function_tool): """Test various tool_choice settings""" streaming_model.client.responses.create = AsyncMock() @@ -199,14 +195,13 @@ async def test_tool_choice_variations(self, streaming_model, _streaming_context_ output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args assert create_call.kwargs['tool_choice'] == expected @pytest.mark.asyncio - async def test_parallel_tool_calls(self, streaming_model, _streaming_context_vars, sample_task_id, sample_function_tool): + async def test_parallel_tool_calls(self, streaming_model, _streaming_context_vars, sample_function_tool): """Test parallel tool calls setting""" streaming_model.client.responses.create = AsyncMock() @@ -227,14 +222,13 @@ async def test_parallel_tool_calls(self, streaming_model, _streaming_context_var output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args assert create_call.kwargs['parallel_tool_calls'] == parallel @pytest.mark.asyncio - async def test_truncation_strategy(self, streaming_model, _streaming_context_vars, sample_task_id): + async def test_truncation_strategy(self, streaming_model, _streaming_context_vars): """Test truncation parameter""" streaming_model.client.responses.create = AsyncMock() @@ -255,14 +249,13 @@ async def test_truncation_strategy(self, streaming_model, _streaming_context_var output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args assert create_call.kwargs['truncation'] == "auto" @pytest.mark.asyncio - async def test_response_include(self, streaming_model, _streaming_context_vars, sample_task_id, sample_file_search_tool): + async def test_response_include(self, streaming_model, _streaming_context_vars, sample_file_search_tool): """Test response include parameter""" streaming_model.client.responses.create = AsyncMock() @@ -284,7 +277,6 @@ async def test_response_include(self, streaming_model, _streaming_context_vars, output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -294,7 +286,7 @@ async def test_response_include(self, streaming_model, _streaming_context_vars, assert "file_search_call.results" in include_list # Added by file search tool @pytest.mark.asyncio - async def test_verbosity(self, streaming_model, _streaming_context_vars, sample_task_id): + async def test_verbosity(self, streaming_model, _streaming_context_vars): """Test verbosity settings""" streaming_model.client.responses.create = AsyncMock() @@ -314,14 +306,13 @@ 
async def test_verbosity(self, streaming_model, _streaming_context_vars, sample_ output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args assert create_call.kwargs['text'] == {"verbosity": "high"} @pytest.mark.asyncio - async def test_metadata_and_store(self, streaming_model, _streaming_context_vars, sample_task_id): + async def test_metadata_and_store(self, streaming_model, _streaming_context_vars): """Test metadata and store parameters""" streaming_model.client.responses.create = AsyncMock() @@ -347,7 +338,6 @@ async def test_metadata_and_store(self, streaming_model, _streaming_context_vars output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -355,7 +345,7 @@ async def test_metadata_and_store(self, streaming_model, _streaming_context_vars assert create_call.kwargs['store'] == store @pytest.mark.asyncio - async def test_extra_headers_and_body(self, streaming_model, _streaming_context_vars, sample_task_id): + async def test_extra_headers_and_body(self, streaming_model, _streaming_context_vars): """Test extra customization parameters""" streaming_model.client.responses.create = AsyncMock() @@ -383,7 +373,6 @@ async def test_extra_headers_and_body(self, streaming_model, _streaming_context_ output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -392,7 +381,7 @@ async def test_extra_headers_and_body(self, streaming_model, _streaming_context_ assert create_call.kwargs['extra_query'] == extra_query @pytest.mark.asyncio - async def test_top_logprobs(self, streaming_model, _streaming_context_vars, sample_task_id): + async def test_top_logprobs(self, streaming_model, _streaming_context_vars): """Test top_logprobs parameter""" streaming_model.client.responses.create = AsyncMock() @@ -412,7 +401,6 @@ async def test_top_logprobs(self, streaming_model, _streaming_context_vars, samp output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -427,7 +415,7 @@ class TestStreamingModelTools: """Test that all tool types work with streaming""" @pytest.mark.asyncio - async def test_function_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_function_tool): + async def test_function_tool(self, streaming_model, _streaming_context_vars, sample_function_tool): """Test FunctionTool conversion and streaming""" streaming_model.client.responses.create = AsyncMock() @@ -445,7 +433,6 @@ async def test_function_tool(self, streaming_model, _streaming_context_vars, sam output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -457,7 +444,7 @@ async def test_function_tool(self, streaming_model, _streaming_context_vars, sam assert 'parameters' in tools[0] @pytest.mark.asyncio - async def test_web_search_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_web_search_tool): + async def test_web_search_tool(self, streaming_model, _streaming_context_vars, sample_web_search_tool): """Test WebSearchTool conversion""" streaming_model.client.responses.create = AsyncMock() @@ -475,7 +462,6 @@ async def test_web_search_tool(self, streaming_model, _streaming_context_vars, s output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = 
streaming_model.client.responses.create.call_args @@ -484,7 +470,7 @@ async def test_web_search_tool(self, streaming_model, _streaming_context_vars, s assert tools[0]['type'] == 'web_search' @pytest.mark.asyncio - async def test_file_search_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_file_search_tool): + async def test_file_search_tool(self, streaming_model, _streaming_context_vars, sample_file_search_tool): """Test FileSearchTool conversion""" streaming_model.client.responses.create = AsyncMock() @@ -502,7 +488,6 @@ async def test_file_search_tool(self, streaming_model, _streaming_context_vars, output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -513,7 +498,7 @@ async def test_file_search_tool(self, streaming_model, _streaming_context_vars, assert tools[0]['max_num_results'] == 10 @pytest.mark.asyncio - async def test_computer_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_computer_tool): + async def test_computer_tool(self, streaming_model, _streaming_context_vars, sample_computer_tool): """Test ComputerTool conversion""" streaming_model.client.responses.create = AsyncMock() @@ -531,7 +516,6 @@ async def test_computer_tool(self, streaming_model, _streaming_context_vars, sam output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -543,7 +527,7 @@ async def test_computer_tool(self, streaming_model, _streaming_context_vars, sam assert tools[0]['display_height'] == 1080 @pytest.mark.asyncio - async def test_multiple_computer_tools_error(self, streaming_model, _streaming_context_vars, sample_task_id, sample_computer_tool): + async def test_multiple_computer_tools_error(self, streaming_model, _streaming_context_vars, sample_computer_tool): """Test that multiple computer tools raise an error""" streaming_model.client.responses.create = AsyncMock() @@ -563,11 +547,10 @@ async def test_multiple_computer_tools_error(self, streaming_model, _streaming_c output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) @pytest.mark.asyncio - async def test_hosted_mcp_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_hosted_mcp_tool): + async def test_hosted_mcp_tool(self, streaming_model, _streaming_context_vars, sample_hosted_mcp_tool): """Test HostedMCPTool conversion""" streaming_model.client.responses.create = AsyncMock() @@ -585,7 +568,6 @@ async def test_hosted_mcp_tool(self, streaming_model, _streaming_context_vars, s output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -595,7 +577,7 @@ async def test_hosted_mcp_tool(self, streaming_model, _streaming_context_vars, s assert tools[0]['server_label'] == 'test_server' @pytest.mark.asyncio - async def test_image_generation_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_image_generation_tool): + async def test_image_generation_tool(self, streaming_model, _streaming_context_vars, sample_image_generation_tool): """Test ImageGenerationTool conversion""" streaming_model.client.responses.create = AsyncMock() @@ -613,7 +595,6 @@ async def test_image_generation_tool(self, streaming_model, _streaming_context_v output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -622,7 +603,7 @@ async def 
test_image_generation_tool(self, streaming_model, _streaming_context_v assert tools[0]['type'] == 'image_generation' @pytest.mark.asyncio - async def test_code_interpreter_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_code_interpreter_tool): + async def test_code_interpreter_tool(self, streaming_model, _streaming_context_vars, sample_code_interpreter_tool): """Test CodeInterpreterTool conversion""" streaming_model.client.responses.create = AsyncMock() @@ -640,7 +621,6 @@ async def test_code_interpreter_tool(self, streaming_model, _streaming_context_v output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -649,7 +629,7 @@ async def test_code_interpreter_tool(self, streaming_model, _streaming_context_v assert tools[0]['type'] == 'code_interpreter' @pytest.mark.asyncio - async def test_local_shell_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_local_shell_tool): + async def test_local_shell_tool(self, streaming_model, _streaming_context_vars, sample_local_shell_tool): """Test LocalShellTool conversion""" streaming_model.client.responses.create = AsyncMock() @@ -667,7 +647,6 @@ async def test_local_shell_tool(self, streaming_model, _streaming_context_vars, output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -677,7 +656,7 @@ async def test_local_shell_tool(self, streaming_model, _streaming_context_vars, # working_directory no longer in API - LocalShellTool uses executor internally @pytest.mark.asyncio - async def test_handoffs(self, streaming_model, _streaming_context_vars, sample_task_id, sample_handoff): + async def test_handoffs(self, streaming_model, _streaming_context_vars, sample_handoff): """Test Handoff conversion to function tools""" streaming_model.client.responses.create = AsyncMock() @@ -695,7 +674,6 @@ async def test_handoffs(self, streaming_model, _streaming_context_vars, sample_t output_schema=None, handoffs=[sample_handoff], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -706,7 +684,7 @@ async def test_handoffs(self, streaming_model, _streaming_context_vars, sample_t assert tools[0]['description'] == 'Transfer to support agent' @pytest.mark.asyncio - async def test_mixed_tools(self, streaming_model, _streaming_context_vars, sample_task_id, + async def test_mixed_tools(self, streaming_model, _streaming_context_vars, sample_function_tool, sample_web_search_tool, sample_handoff): """Test multiple tools together""" streaming_model.client.responses.create = AsyncMock() @@ -725,7 +703,6 @@ async def test_mixed_tools(self, streaming_model, _streaming_context_vars, sampl output_schema=None, handoffs=[sample_handoff], tracing=None, - task_id=sample_task_id ) create_call = streaming_model.client.responses.create.call_args @@ -757,7 +734,7 @@ async def test_responses_api_streaming(self, streaming_model, mock_adk_streaming text_delta_2 = MagicMock(spec=ResponseTextDeltaEvent) text_delta_2.delta = "world!" 
completed = MagicMock(spec=ResponseCompletedEvent) - completed.response = MagicMock(output=[], usage=MagicMock()) + completed.response = MagicMock(output=[], usage=MagicMock(), id=None) mock_stream = AsyncMock() mock_stream.__aiter__.return_value = iter([item_added, text_delta_1, text_delta_2, completed]) streaming_model.client.responses.create.return_value = mock_stream @@ -770,7 +747,6 @@ async def test_responses_api_streaming(self, streaming_model, mock_adk_streaming output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) # Verify streaming context was created with the right task_id. We @@ -798,7 +774,7 @@ async def test_task_id_threading(self, streaming_model, mock_adk_streaming, _str item_added.item = MagicMock(type="message") item_added.output_index = 0 completed = MagicMock(spec=ResponseCompletedEvent) - completed.response = MagicMock(output=[], usage=MagicMock()) + completed.response = MagicMock(output=[], usage=MagicMock(), id=None) mock_stream = AsyncMock() mock_stream.__aiter__.return_value = iter([item_added, completed]) streaming_model.client.responses.create.return_value = mock_stream @@ -821,7 +797,7 @@ async def test_task_id_threading(self, streaming_model, mock_adk_streaming, _str assert call_args.kwargs['task_id'] == expected_task_id @pytest.mark.asyncio - async def test_redis_context_creation(self, streaming_model, mock_adk_streaming, _streaming_context_vars, sample_task_id): + async def test_redis_context_creation(self, streaming_model, mock_adk_streaming, _streaming_context_vars): """Test that Redis streaming contexts are created properly""" streaming_model.client.responses.create = AsyncMock() @@ -834,7 +810,7 @@ async def test_redis_context_creation(self, streaming_model, mock_adk_streaming, reasoning_delta.delta = "Thinking..." 
reasoning_delta.summary_index = 0 completed = MagicMock(spec=ResponseCompletedEvent) - completed.response = MagicMock(output=[], usage=MagicMock()) + completed.response = MagicMock(output=[], usage=MagicMock(), id=None) mock_stream = AsyncMock() mock_stream.__aiter__.return_value = iter([item_added, reasoning_delta, completed]) streaming_model.client.responses.create.return_value = mock_stream @@ -847,7 +823,6 @@ async def test_redis_context_creation(self, streaming_model, mock_adk_streaming, output_schema=None, handoffs=[], tracing=None, - task_id=sample_task_id ) # Should create at least one context for reasoning @@ -873,4 +848,386 @@ async def test_missing_task_id_error(self, streaming_model): output_schema=None, handoffs=[], tracing=None, - ) \ No newline at end of file + ) + + +class TestStreamingModelUsageResponseIdAndCacheKey: + """Cover real-Usage capture, real response_id, span emission, and opt-in prompt_cache_key.""" + + @staticmethod + def _async_iter(events): + async def _gen(): + for event in events: + yield event + return _gen() + + @staticmethod + def _make_response_completed_event( + *, + input_tokens: int = 0, + output_tokens: int = 0, + total_tokens: int = 0, + cached_tokens: int = 0, + reasoning_tokens: int = 0, + with_usage: bool = True, + response_id: Optional[str] = "resp_real_server_id", + ): + usage = MagicMock() + usage.input_tokens = input_tokens + usage.output_tokens = output_tokens + usage.total_tokens = total_tokens + usage.input_tokens_details = MagicMock(cached_tokens=cached_tokens) + usage.output_tokens_details = MagicMock(reasoning_tokens=reasoning_tokens) + + response = MagicMock() + response.output = [] + response.usage = usage if with_usage else None + response.id = response_id + + event = MagicMock(spec=ResponseCompletedEvent) + event.response = response + return event + + @pytest.fixture + def mock_span(self): + return MagicMock() + + @pytest.fixture + def streaming_model_with_mock_tracer(self, streaming_model, mock_span): + """A streaming_model whose tracer.trace().span(...) 
yields a captured mock span.""" + async_cm = MagicMock() + async_cm.__aenter__ = AsyncMock(return_value=mock_span) + async_cm.__aexit__ = AsyncMock(return_value=False) + trace_obj = MagicMock() + trace_obj.span = MagicMock(return_value=async_cm) + streaming_model.tracer = MagicMock() + streaming_model.tracer.trace = MagicMock(return_value=trace_obj) + return streaming_model + + @pytest.mark.asyncio + async def test_usage_captured_from_completed_event( + self, + streaming_model_with_mock_tracer, + _streaming_context_vars, # noqa: ARG002 + ): + model = streaming_model_with_mock_tracer + completed = self._make_response_completed_event( + input_tokens=1234, output_tokens=56, total_tokens=1290, + cached_tokens=987, reasoning_tokens=42, + ) + model.client.responses.create = AsyncMock(return_value=self._async_iter([completed])) + + response = await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + ) + + assert response.usage.input_tokens == 1234 + assert response.usage.output_tokens == 56 + assert response.usage.total_tokens == 1290 + assert response.usage.input_tokens_details.cached_tokens == 987 + assert response.usage.output_tokens_details.reasoning_tokens == 42 + + @pytest.mark.asyncio + async def test_usage_falls_back_when_no_completed_event( + self, + streaming_model_with_mock_tracer, + _streaming_context_vars, # noqa: ARG002 + ): + """Stream ending without a ResponseCompletedEvent (error path) → zero Usage.""" + model = streaming_model_with_mock_tracer + model.client.responses.create = AsyncMock(return_value=self._async_iter([])) + + response = await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + ) + + assert response.usage.input_tokens == 0 + assert response.usage.output_tokens == 0 + assert response.usage.total_tokens == 0 + assert response.usage.input_tokens_details.cached_tokens == 0 + assert response.usage.output_tokens_details.reasoning_tokens == 0 + + @pytest.mark.asyncio + async def test_usage_emitted_in_span_output( + self, + streaming_model_with_mock_tracer, + _streaming_context_vars, # noqa: ARG002 + mock_span, + ): + model = streaming_model_with_mock_tracer + completed = self._make_response_completed_event( + input_tokens=100, output_tokens=10, total_tokens=110, + cached_tokens=80, reasoning_tokens=5, + ) + model.client.responses.create = AsyncMock(return_value=self._async_iter([completed])) + + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + ) + + assert isinstance(mock_span.output, dict) + usage_block = mock_span.output["usage"] + assert usage_block == { + "input_tokens": 100, + "output_tokens": 10, + "total_tokens": 110, + "cached_input_tokens": 80, + "reasoning_tokens": 5, + } + + @pytest.mark.asyncio + async def test_response_id_captured_from_completed_event( + self, + streaming_model_with_mock_tracer, + _streaming_context_vars, # noqa: ARG002 + ): + """Real server-issued id flows back on ModelResponse.response_id.""" + model = streaming_model_with_mock_tracer + completed = self._make_response_completed_event(response_id="resp_abcdef123456") + model.client.responses.create = AsyncMock(return_value=self._async_iter([completed])) + + response = await model.get_response( + system_instructions=None, + input="hi", + 
model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + ) + + assert response.response_id == "resp_abcdef123456" + + @pytest.mark.asyncio + async def test_response_id_is_none_when_no_completed_event( + self, + streaming_model_with_mock_tracer, + _streaming_context_vars, # noqa: ARG002 + ): + """Stream ending without ResponseCompletedEvent → response_id is None. + + Critical: must NOT fabricate a UUID. Returning a fake id would cause + downstream `previous_response_id` chaining to 400 against the server. + """ + model = streaming_model_with_mock_tracer + model.client.responses.create = AsyncMock(return_value=self._async_iter([])) + + response = await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + ) + + assert response.response_id is None + + @pytest.mark.asyncio + async def test_prompt_cache_key_not_sent_by_default( + self, + streaming_model_with_mock_tracer, + _streaming_context_vars, # noqa: ARG002 + ): + """Without an opt-in, prompt_cache_key resolves to NOT_GIVEN (omitted from request).""" + model = streaming_model_with_mock_tracer + completed = self._make_response_completed_event() + model.client.responses.create = AsyncMock(return_value=self._async_iter([completed])) + + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + ) + + kwargs = model.client.responses.create.call_args.kwargs + assert kwargs["prompt_cache_key"] is NOT_GIVEN + + @pytest.mark.asyncio + async def test_prompt_cache_key_forwarded_when_opted_in( + self, + streaming_model_with_mock_tracer, + _streaming_context_vars, # noqa: ARG002 + ): + """Caller opt-in via model_settings.extra_args is forwarded to responses.create.""" + model = streaming_model_with_mock_tracer + completed = self._make_response_completed_event() + model.client.responses.create = AsyncMock(return_value=self._async_iter([completed])) + + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(extra_args={"prompt_cache_key": "my-key"}), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + ) + + kwargs = model.client.responses.create.call_args.kwargs + assert kwargs["prompt_cache_key"] == "my-key" + # Must be popped from extra_args so the SDK doesn't see it twice. + assert list(kwargs).count("prompt_cache_key") == 1 + + @pytest.mark.asyncio + async def test_previous_response_id_not_sent_by_default( + self, + streaming_model_with_mock_tracer, + _streaming_context_vars, # noqa: ARG002 + ): + """Without an opt-in, previous_response_id resolves to NOT_GIVEN. + + Critical for non-Responses-API-native backends (e.g. Claude-via-LiteLLM) + where unknown fields on the request body could be rejected. NOT_GIVEN + is filtered before serialization, so the field is omitted entirely. 
+ """ + model = streaming_model_with_mock_tracer + completed = self._make_response_completed_event() + model.client.responses.create = AsyncMock(return_value=self._async_iter([completed])) + + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + ) + + kwargs = model.client.responses.create.call_args.kwargs + assert kwargs["previous_response_id"] is NOT_GIVEN + + @pytest.mark.asyncio + async def test_previous_response_id_forwarded_via_sdk_kwarg( + self, + streaming_model_with_mock_tracer, + _streaming_context_vars, # noqa: ARG002 + ): + """The SDK threads previous_response_id as a keyword arg per Model.get_response + abstract contract. Verify it reaches responses.create instead of being silently + swallowed (which was the prior behavior under **kwargs).""" + model = streaming_model_with_mock_tracer + completed = self._make_response_completed_event() + model.client.responses.create = AsyncMock(return_value=self._async_iter([completed])) + + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + previous_response_id="resp_prior_turn", + ) + + kwargs = model.client.responses.create.call_args.kwargs + assert kwargs["previous_response_id"] == "resp_prior_turn" + + @pytest.mark.asyncio + async def test_conversation_and_prompt_not_sent_by_default( + self, + streaming_model_with_mock_tracer, + _streaming_context_vars, # noqa: ARG002 + ): + """Without an opt-in, conversation/prompt resolve to NOT_GIVEN. + + Same opt-in pattern as previous_response_id and prompt_cache_key — the + wire request is unchanged for callers (and non-OpenAI backends) that + don't supply these. + """ + model = streaming_model_with_mock_tracer + completed = self._make_response_completed_event() + model.client.responses.create = AsyncMock(return_value=self._async_iter([completed])) + + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + ) + + kwargs = model.client.responses.create.call_args.kwargs + assert kwargs["conversation"] is NOT_GIVEN + assert kwargs["prompt"] is NOT_GIVEN + + @pytest.mark.asyncio + async def test_conversation_id_forwarded_via_sdk_kwarg( + self, + streaming_model_with_mock_tracer, + _streaming_context_vars, # noqa: ARG002 + ): + """The SDK abstract names this `conversation_id`; the Responses API + endpoint kwarg is `conversation`. 
Caller passes a string id; we forward + it as-is (the Conversation type accepts str).""" + model = streaming_model_with_mock_tracer + completed = self._make_response_completed_event() + model.client.responses.create = AsyncMock(return_value=self._async_iter([completed])) + + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + conversation_id="conv_abc123", + ) + + kwargs = model.client.responses.create.call_args.kwargs + assert kwargs["conversation"] == "conv_abc123" + + @pytest.mark.asyncio + async def test_prompt_forwarded_via_sdk_kwarg( + self, + streaming_model_with_mock_tracer, + _streaming_context_vars, # noqa: ARG002 + ): + """ResponsePromptParam (a TypedDict for pre-built prompts) is forwarded + as-is to responses.create.""" + model = streaming_model_with_mock_tracer + completed = self._make_response_completed_event() + model.client.responses.create = AsyncMock(return_value=self._async_iter([completed])) + + prompt_param = {"id": "prompt_test_id", "version": "1"} + await model.get_response( + system_instructions=None, + input="hi", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + prompt=prompt_param, # type: ignore[arg-type] + ) + + kwargs = model.client.responses.create.call_args.kwargs + assert kwargs["prompt"] == prompt_param \ No newline at end of file