diff --git a/src/agentex/lib/adk/_modules/streaming.py b/src/agentex/lib/adk/_modules/streaming.py index ab53ed68c..b57a552a1 100644 --- a/src/agentex/lib/adk/_modules/streaming.py +++ b/src/agentex/lib/adk/_modules/streaming.py @@ -7,6 +7,7 @@ from agentex.lib.adk.utils._modules.client import create_async_agentex_client from agentex.lib.core.adapters.streams.adapter_redis import RedisStreamRepository from agentex.lib.core.services.adk.streaming import ( + StreamingMode, StreamingService, StreamingTaskMessageContext, ) @@ -50,6 +51,7 @@ def streaming_task_message_context( self, task_id: str, initial_content: TaskMessageContent, + streaming_mode: StreamingMode = "coalesced", ) -> StreamingTaskMessageContext: """ Create a streaming context for managing TaskMessage lifecycle. @@ -60,7 +62,11 @@ def streaming_task_message_context( Args: task_id: The ID of the task initial_content: The initial content for the TaskMessage - agentex_client: The agentex client for creating/updating messages + streaming_mode: How per-delta updates are published. Defaults to + "coalesced" (50ms / 128-char windowed batches with an immediate + first-delta flush). Pass "per_token" for the legacy publish-every- + delta behavior, or "off" to suppress per-delta publishes entirely + while still recording the full message body on close. Returns: StreamingTaskMessageContext: Context manager for streaming operations @@ -76,4 +82,5 @@ def streaming_task_message_context( return self._streaming_service.streaming_task_message_context( task_id=task_id, initial_content=initial_content, + streaming_mode=streaming_mode, ) diff --git a/src/agentex/lib/core/services/adk/streaming.py b/src/agentex/lib/core/services/adk/streaming.py index 9fc3fc959..7799ea1eb 100644 --- a/src/agentex/lib/core/services/adk/streaming.py +++ b/src/agentex/lib/core/services/adk/streaming.py @@ -1,7 +1,9 @@ from __future__ import annotations import json -from typing import Literal +import asyncio +import contextlib +from typing import Literal, Callable, Awaitable from agentex import AsyncAgentex from agentex.lib.utils.logging import make_logger @@ -39,6 +41,198 @@ def _get_stream_topic(task_id: str) -> str: return f"task:{task_id}" +StreamingMode = Literal["off", "per_token", "coalesced"] +"""Controls how a StreamingTaskMessageContext publishes deltas. + +- "off": Feed the accumulator (so the persisted message body is correct) + but never publish per-delta events. Consumers see start + done + only. Lowest latency. +- "per_token": Publish every delta immediately. Highest UX fidelity for + token-by-token rendering, highest Redis cost, and re-introduces + head-of-line blocking on the producer's event loop. +- "coalesced": Buffer deltas in a small time/size window and publish them as + merged batches. The first delta flushes immediately for fast + perceived responsiveness; subsequent deltas flush every 50ms or + whenever 128 buffered chars accumulate, whichever comes first. + Order within each (delta type, index) channel is preserved + exactly; only granularity changes. 
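+
+Illustrative usage of the service-level entry point added in this change
+(``streaming_service`` and ``delta_update`` are placeholders for an existing
+``StreamingService`` instance and any ``StreamTaskMessageDelta``):
+
+    async with streaming_service.streaming_task_message_context(
+        task_id=task_id,
+        initial_content=TextContent(author="agent", content="", format="markdown"),
+        streaming_mode="coalesced",  # the default; "per_token" and "off" are also accepted
+    ) as ctx:
+        await ctx.stream_update(delta_update)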
+""" + + +def _delta_char_len(delta: TaskMessageDelta | None) -> int: + if delta is None: + return 0 + if isinstance(delta, TextDelta): + return len(delta.text_delta or "") + if isinstance(delta, DataDelta): + return len(delta.data_delta or "") + if isinstance(delta, ReasoningSummaryDelta): + return len(delta.summary_delta or "") + if isinstance(delta, ReasoningContentDelta): + return len(delta.content_delta or "") + if isinstance(delta, ToolRequestDelta): + return len(delta.arguments_delta or "") + if isinstance(delta, ToolResponseDelta): + return len(delta.content_delta or "") + return 0 + + +def _can_merge(a: TaskMessageDelta, b: TaskMessageDelta) -> bool: + if type(a) is not type(b): + return False + if isinstance(a, ReasoningSummaryDelta) and isinstance(b, ReasoningSummaryDelta): + return a.summary_index == b.summary_index + if isinstance(a, ReasoningContentDelta) and isinstance(b, ReasoningContentDelta): + return a.content_index == b.content_index + if isinstance(a, ToolRequestDelta) and isinstance(b, ToolRequestDelta): + return a.tool_call_id == b.tool_call_id + if isinstance(a, ToolResponseDelta) and isinstance(b, ToolResponseDelta): + return a.tool_call_id == b.tool_call_id + return True + + +def _merge_pair(a: TaskMessageDelta, b: TaskMessageDelta) -> TaskMessageDelta: + if isinstance(a, TextDelta) and isinstance(b, TextDelta): + return TextDelta(type="text", text_delta=(a.text_delta or "") + (b.text_delta or "")) + if isinstance(a, DataDelta) and isinstance(b, DataDelta): + return DataDelta(type="data", data_delta=(a.data_delta or "") + (b.data_delta or "")) + if isinstance(a, ReasoningSummaryDelta) and isinstance(b, ReasoningSummaryDelta): + return ReasoningSummaryDelta( + type="reasoning_summary", + summary_index=a.summary_index, + summary_delta=(a.summary_delta or "") + (b.summary_delta or ""), + ) + if isinstance(a, ReasoningContentDelta) and isinstance(b, ReasoningContentDelta): + return ReasoningContentDelta( + type="reasoning_content", + content_index=a.content_index, + content_delta=(a.content_delta or "") + (b.content_delta or ""), + ) + if isinstance(a, ToolRequestDelta) and isinstance(b, ToolRequestDelta): + return ToolRequestDelta( + type="tool_request", + tool_call_id=a.tool_call_id, + name=a.name, + arguments_delta=(a.arguments_delta or "") + (b.arguments_delta or ""), + ) + if isinstance(a, ToolResponseDelta) and isinstance(b, ToolResponseDelta): + return ToolResponseDelta( + type="tool_response", + tool_call_id=a.tool_call_id, + name=a.name, + content_delta=(a.content_delta or "") + (b.content_delta or ""), + ) + raise AssertionError( + f"_can_merge approved {type(a).__name__} pair but _merge_pair has no handler — " + "a new TaskMessageDelta variant was added without updating both functions" + ) + + +def _merge_consecutive(updates: list[StreamTaskMessageDelta]) -> list[StreamTaskMessageDelta]: + """Merge consecutive same-channel deltas. Order across channels is preserved exactly.""" + result: list[StreamTaskMessageDelta] = [] + for u in updates: + if u.delta is None or not result: + result.append(u) + continue + last = result[-1] + if last.delta is not None and _can_merge(last.delta, u.delta): + result[-1] = StreamTaskMessageDelta( + parent_task_message=last.parent_task_message, + delta=_merge_pair(last.delta, u.delta), + type="delta", + ) + else: + result.append(u) + return result + + +class CoalescingBuffer: + """Time-and-size-windowed buffer that merges consecutive same-channel deltas. 
+ + Decouples the producer (model event loop) from the publisher (Redis): ``add`` + only enqueues and may signal an early flush; the actual publish always runs + on a background ticker, so the producer never awaits on a Redis round-trip. + """ + + FLUSH_INTERVAL_S = 0.050 + MAX_BUFFERED_CHARS = 128 + + def __init__(self, on_flush: Callable[[StreamTaskMessageDelta], Awaitable[object]]): + self._on_flush = on_flush + self._buf: list[StreamTaskMessageDelta] = [] + self._buf_chars = 0 + self._first_flushed = False + self._closed = False + self._lock = asyncio.Lock() + self._flush_signal = asyncio.Event() + self._task: asyncio.Task[None] | None = None + + def start(self) -> None: + if self._task is None: + self._task = asyncio.create_task(self._run(), name="coalescing-buffer") + + async def add(self, update: StreamTaskMessageDelta) -> None: + if self._closed: + return + async with self._lock: + self._buf.append(update) + self._buf_chars += _delta_char_len(update.delta) + if not self._first_flushed or self._buf_chars >= self.MAX_BUFFERED_CHARS: + self._first_flushed = True + self._flush_signal.set() + + async def _run(self) -> None: + try: + while not self._closed: + try: + await asyncio.wait_for(self._flush_signal.wait(), timeout=self.FLUSH_INTERVAL_S) + except asyncio.TimeoutError: + pass + async with self._lock: + self._flush_signal.clear() + drained = self._drain_locked() + for idx, u in enumerate(drained): + try: + await self._on_flush(u) + except asyncio.CancelledError: + # Re-enqueue the item being flushed plus any remaining so + # close()'s final drain can recover them. May cause a + # duplicate publish of the in-flight item, which is + # preferable to silent loss for a streaming UX. + async with self._lock: + self._buf = drained[idx:] + self._buf + raise + except Exception as e: + logger.exception(f"CoalescingBuffer flush failed: {e}") + except asyncio.CancelledError: + pass + + async def close(self) -> None: + self._closed = True + if self._task is not None: + self._flush_signal.set() + self._task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._task + self._task = None + async with self._lock: + drained = self._drain_locked() + for u in drained: + try: + await self._on_flush(u) + except Exception as e: + logger.exception(f"CoalescingBuffer final flush failed: {e}") + + def _drain_locked(self) -> list[StreamTaskMessageDelta]: + if not self._buf: + return [] + merged = _merge_consecutive(self._buf) + self._buf = [] + self._buf_chars = 0 + return merged + + class DeltaAccumulator: def __init__(self): self._accumulated_deltas: list[TaskMessageDelta] = [] @@ -176,6 +370,7 @@ def __init__( initial_content: TaskMessageContent, agentex_client: AsyncAgentex, streaming_service: "StreamingService", + streaming_mode: StreamingMode = "coalesced", ): self.task_id = task_id self.initial_content = initial_content @@ -184,6 +379,8 @@ def __init__( self._streaming_service = streaming_service self._is_closed = False self._delta_accumulator = DeltaAccumulator() + self._streaming_mode: StreamingMode = streaming_mode + self._buffer: CoalescingBuffer | None = None async def __aenter__(self) -> "StreamingTaskMessageContext": return await self.open() @@ -208,6 +405,10 @@ async def open(self) -> "StreamingTaskMessageContext": ) await self._streaming_service.stream_update(start_event) + if self._streaming_mode == "coalesced": + self._buffer = CoalescingBuffer(on_flush=self._streaming_service.stream_update) + self._buffer.start() + return self async def close(self) -> TaskMessage: @@ 
-218,6 +419,12 @@ async def close(self) -> TaskMessage: if self._is_closed: return self.task_message # Already done + # Drain any buffered deltas before announcing DONE so consumers see the + # full sequence in order. + if self._buffer is not None: + await self._buffer.close() + self._buffer = None + # Send the DONE event done_event = StreamTaskMessageDone( parent_task_message=self.task_message, @@ -227,8 +434,8 @@ async def close(self) -> TaskMessage: # Update the task message with the final content has_deltas = ( - self._delta_accumulator._accumulated_deltas or - self._delta_accumulator._reasoning_summaries or + self._delta_accumulator._accumulated_deltas or + self._delta_accumulator._reasoning_summaries or self._delta_accumulator._reasoning_contents ) if has_deltas: @@ -248,7 +455,20 @@ async def close(self) -> TaskMessage: async def stream_update( self, update: TaskMessageUpdate ) -> TaskMessageUpdate | None: - """Stream an update to the repository.""" + """Stream an update to the repository. + + Behavior depends on the context's ``streaming_mode``: + - "off": delta updates feed the accumulator (so the persisted message + body is correct) but are never published. + - "per_token": delta updates are published immediately. + - "coalesced": delta updates are queued in a 50ms / 128-char window and + flushed as merged batches on a background ticker; the first delta + flushes immediately for fast perceived responsiveness. + + ``StreamTaskMessageDone`` and ``StreamTaskMessageFull`` updates always + publish synchronously regardless of mode so consumers and persistence + stay in sync. + """ if self._is_closed: raise ValueError("Context is already done") @@ -258,6 +478,11 @@ async def stream_update( if isinstance(update, StreamTaskMessageDelta): if update.delta is not None: self._delta_accumulator.add_delta(update.delta) + if self._streaming_mode == "off": + return update + if self._streaming_mode == "coalesced" and self._buffer is not None: + await self._buffer.add(update) + return update result = await self._streaming_service.stream_update(update) @@ -288,12 +513,14 @@ def streaming_task_message_context( self, task_id: str, initial_content: TaskMessageContent, + streaming_mode: StreamingMode = "coalesced", ) -> StreamingTaskMessageContext: return StreamingTaskMessageContext( task_id=task_id, initial_content=initial_content, agentex_client=self._agentex_client, streaming_service=self, + streaming_mode=streaming_mode, ) async def stream_update( diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py index b21694e88..d301af09a 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py @@ -2,7 +2,6 @@ from __future__ import annotations import uuid -import logging from typing import Any, List, Union, Optional, override from agents import ( @@ -29,6 +28,10 @@ ) from agents.computer import Computer, AsyncComputer +# Re-export the canonical StreamingMode literal from the streaming service so +# all layers share a single definition. 
+from agentex.lib.core.services.adk.streaming import StreamingMode as StreamingMode + try: from agents.tool import ShellTool # type: ignore[attr-defined] except ImportError: @@ -54,6 +57,7 @@ # AgentEx SDK imports from agentex.lib import adk +from agentex.lib.utils.logging import make_logger from agentex.lib.core.tracing.tracer import AsyncTracer from agentex.types.task_message_delta import TextDelta, ReasoningContentDelta, ReasoningSummaryDelta from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta @@ -65,8 +69,12 @@ streaming_parent_span_id, ) -# Create logger for this module -logger = logging.getLogger("agentex.temporal.streaming") +# Use the SDK's make_logger so this module's INFO/DEBUG output is actually +# visible (raw ``logging.getLogger`` returns a logger with no handler/level +# configured, which silently drops anything below WARNING). Keep the explicit +# name "agentex.temporal.streaming" so any external logging config targeting +# that name keeps working. +logger = make_logger("agentex.temporal.streaming") def _serialize_item(item: Any) -> dict[str, Any]: @@ -113,6 +121,7 @@ def __init__( model_name: str = "gpt-4o", _use_responses_api: bool = True, openai_client: Optional[AsyncOpenAI] = None, + streaming_mode: StreamingMode = "coalesced", ): """Initialize the streaming model with OpenAI client and model name. @@ -121,18 +130,25 @@ def __init__( _use_responses_api: Internal flag for responses API (deprecated, always True) openai_client: Optional custom AsyncOpenAI client. If not provided, a default client with max_retries=0 will be created (since Temporal handles retries) + streaming_mode: How per-delta updates flow to consumers. Defaults to + "coalesced" (50ms / 128-char windowed batches with an + immediate first-delta flush) for low latency without + giving up streaming UX. Use "per_token" for legacy + publish-every-delta behavior, or "off" to suppress + per-delta publishes entirely. 
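+
+        Example (illustrative; every argument shown is a parameter of this constructor):
+
+            model = TemporalStreamingModel(
+                model_name="gpt-4o",
+                streaming_mode="off",  # suppress per-delta publishes for this model only
+            )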
""" # Use provided client or create default (Temporal handles retries) self.client = openai_client if openai_client is not None else AsyncOpenAI(max_retries=0) self.model_name = model_name # Always use Responses API for all models self.use_responses_api = True + self.streaming_mode: StreamingMode = streaming_mode # Initialize tracer as a class variable agentex_client = create_async_agentex_client() self.tracer = AsyncTracer(agentex_client) - logger.info(f"[TemporalStreamingModel] Initialized model={self.model_name}, use_responses_api={self.use_responses_api}, custom_client={openai_client is not None}, tracer=initialized") + logger.info(f"[TemporalStreamingModel] Initialized model={self.model_name}, use_responses_api={self.use_responses_api}, custom_client={openai_client is not None}, streaming_mode={self.streaming_mode}, tracer=initialized") def _non_null_or_not_given(self, value: Any) -> Any: """Convert None to NOT_GIVEN sentinel, matching OpenAI SDK pattern.""" @@ -634,6 +650,7 @@ async def get_response( type="reasoning", style="active", ), + streaming_mode=self.streaming_mode, ).__aenter__() elif item and getattr(item, 'type', None) == 'function_call': # Track the function call being streamed @@ -654,6 +671,7 @@ async def get_response( content="", format="markdown", ), + streaming_mode=self.streaming_mode, ).__aenter__() elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent): @@ -732,7 +750,8 @@ async def get_response( delta=delta_obj, type="delta", ) - await streaming_context.stream_update(update) if streaming_context else None + if streaming_context: + await streaming_context.stream_update(update) except Exception as e: logger.warning(f"Failed to send text delta: {e}") @@ -935,16 +954,24 @@ def stream_response(self, *args, **kwargs): class TemporalStreamingModelProvider(ModelProvider): """Custom model provider that returns a streaming-capable model.""" - def __init__(self, openai_client: Optional[AsyncOpenAI] = None): + def __init__( + self, + openai_client: Optional[AsyncOpenAI] = None, + streaming_mode: StreamingMode = "coalesced", + ): """Initialize the provider. Args: openai_client: Optional custom AsyncOpenAI client to use for all models. If not provided, each model will create its own default client. + streaming_mode: Default streaming mode applied to every model returned by + this provider. See ``StreamingMode`` for the meaning of + each value. Defaults to "coalesced" — fast but still streamy. 
""" super().__init__() self.openai_client = openai_client - logger.info(f"[TemporalStreamingModelProvider] Initialized, custom_client={openai_client is not None}") + self.streaming_mode: StreamingMode = streaming_mode + logger.info(f"[TemporalStreamingModelProvider] Initialized, custom_client={openai_client is not None}, streaming_mode={self.streaming_mode}") @override def get_model(self, model_name: Union[str, None]) -> Model: @@ -959,5 +986,9 @@ def get_model(self, model_name: Union[str, None]) -> Model: # Use the provided model_name or default to gpt-4o actual_model = model_name if model_name else "gpt-4o" logger.info(f"[TemporalStreamingModelProvider] Creating TemporalStreamingModel for model_name: {actual_model}") - model = TemporalStreamingModel(model_name=actual_model, openai_client=self.openai_client) + model = TemporalStreamingModel( + model_name=actual_model, + openai_client=self.openai_client, + streaming_mode=self.streaming_mode, + ) return model diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py index 817e5e5b7..809dee2e0 100644 --- a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py @@ -773,11 +773,13 @@ async def test_responses_api_streaming(self, streaming_model, mock_adk_streaming task_id=sample_task_id ) - # Verify streaming context was created - mock_adk_streaming.streaming_task_message_context.assert_called_with( - task_id=sample_task_id, - initial_content=mock_adk_streaming.streaming_task_message_context.call_args.kwargs['initial_content'] - ) + # Verify streaming context was created with the right task_id. We + # don't strict-match the full kwargs because production also passes + # ``streaming_mode``, which is an implementation detail this test + # doesn't care about. + mock_adk_streaming.streaming_task_message_context.assert_called() + call_kwargs = mock_adk_streaming.streaming_task_message_context.call_args.kwargs + assert call_kwargs['task_id'] == sample_task_id # Verify result is returned as ModelResponse from agents import ModelResponse diff --git a/tests/lib/core/services/__init__.py b/tests/lib/core/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/lib/core/services/adk/__init__.py b/tests/lib/core/services/adk/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/lib/core/services/adk/test_streaming.py b/tests/lib/core/services/adk/test_streaming.py new file mode 100644 index 000000000..8b5fe9a35 --- /dev/null +++ b/tests/lib/core/services/adk/test_streaming.py @@ -0,0 +1,497 @@ +"""Tests for the streaming service: ``CoalescingBuffer``, merge helpers, and +``StreamingTaskMessageContext`` mode dispatch. + +These exercise the in-process behavior of the streaming layer without hitting +Redis or any AgentEx HTTP endpoints — everything below the +``StreamingService.stream_update`` boundary is mocked. 
+""" +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agentex.types.task_message import TaskMessage +from agentex.types.text_content import TextContent +from agentex.types.task_message_delta import ( + DataDelta, + TextDelta, + ToolRequestDelta, + ToolResponseDelta, + ReasoningSummaryDelta, +) +from agentex.types.task_message_update import StreamTaskMessageDelta +from agentex.lib.core.services.adk.streaming import ( + CoalescingBuffer, + StreamingTaskMessageContext, + _can_merge, + _merge_pair, + _delta_char_len, + _merge_consecutive, +) + + +@pytest.fixture +def task_message() -> TaskMessage: + return TaskMessage( + id="m1", + task_id="t1", + content=TextContent(author="agent", content="", format="markdown"), + streaming_status="IN_PROGRESS", + ) + + +def _text(tm: TaskMessage, s: str) -> StreamTaskMessageDelta: + return StreamTaskMessageDelta( + parent_task_message=tm, + delta=TextDelta(type="text", text_delta=s), + type="delta", + ) + + +def _reasoning_summary(tm: TaskMessage, idx: int, s: str) -> StreamTaskMessageDelta: + return StreamTaskMessageDelta( + parent_task_message=tm, + delta=ReasoningSummaryDelta( + type="reasoning_summary", summary_index=idx, summary_delta=s + ), + type="delta", + ) + + +async def _make_context(streaming_mode: str) -> tuple[StreamingTaskMessageContext, MagicMock, TaskMessage]: + tm = TaskMessage( + id="m1", + task_id="t1", + content=TextContent(author="agent", content="", format="markdown"), + streaming_status="IN_PROGRESS", + ) + svc = MagicMock() + svc.stream_update = AsyncMock() + client = MagicMock() + client.messages.create = AsyncMock(return_value=tm) + client.messages.update = AsyncMock() + ctx = StreamingTaskMessageContext( + task_id="t1", + initial_content=TextContent(author="agent", content="", format="markdown"), + agentex_client=client, + streaming_service=svc, + streaming_mode=streaming_mode, # type: ignore[arg-type] + ) + await ctx.open() + return ctx, svc, tm + + +class TestDeltaCharLen: + def test_text_delta(self) -> None: + assert _delta_char_len(TextDelta(type="text", text_delta="hello")) == 5 + + def test_reasoning_summary_delta(self) -> None: + assert ( + _delta_char_len( + ReasoningSummaryDelta( + type="reasoning_summary", summary_index=0, summary_delta="abc" + ) + ) + == 3 + ) + + def test_none_delta_is_zero(self) -> None: + assert _delta_char_len(None) == 0 + + def test_empty_string_delta(self) -> None: + assert _delta_char_len(TextDelta(type="text", text_delta="")) == 0 + + +class TestCanMerge: + def test_same_text_type(self) -> None: + a = TextDelta(type="text", text_delta="a") + b = TextDelta(type="text", text_delta="b") + assert _can_merge(a, b) is True + + def test_different_types_never_merge(self) -> None: + text = TextDelta(type="text", text_delta="a") + data = DataDelta(type="data", data_delta="b") + assert _can_merge(text, data) is False + + def test_reasoning_summary_same_index_merges(self) -> None: + a = ReasoningSummaryDelta(type="reasoning_summary", summary_index=0, summary_delta="x") + b = ReasoningSummaryDelta(type="reasoning_summary", summary_index=0, summary_delta="y") + assert _can_merge(a, b) is True + + def test_reasoning_summary_different_index_blocks_merge(self) -> None: + a = ReasoningSummaryDelta(type="reasoning_summary", summary_index=0, summary_delta="x") + b = ReasoningSummaryDelta(type="reasoning_summary", summary_index=1, summary_delta="y") + assert _can_merge(a, b) is False + + def test_tool_request_same_call_id_merges(self) -> 
None: + a = ToolRequestDelta(type="tool_request", tool_call_id="c1", name="t", arguments_delta="{") + b = ToolRequestDelta(type="tool_request", tool_call_id="c1", name="t", arguments_delta="}") + assert _can_merge(a, b) is True + + def test_tool_request_different_call_id_blocks_merge(self) -> None: + a = ToolRequestDelta(type="tool_request", tool_call_id="c1", name="t", arguments_delta="{") + b = ToolRequestDelta(type="tool_request", tool_call_id="c2", name="t", arguments_delta="}") + assert _can_merge(a, b) is False + + +class TestMergePair: + def test_text_concatenates(self) -> None: + merged = _merge_pair( + TextDelta(type="text", text_delta="Hello "), + TextDelta(type="text", text_delta="world"), + ) + assert isinstance(merged, TextDelta) + assert merged.text_delta == "Hello world" + + def test_reasoning_summary_concatenates_and_keeps_index(self) -> None: + merged = _merge_pair( + ReasoningSummaryDelta( + type="reasoning_summary", summary_index=2, summary_delta="hello " + ), + ReasoningSummaryDelta( + type="reasoning_summary", summary_index=2, summary_delta="world" + ), + ) + assert isinstance(merged, ReasoningSummaryDelta) + assert merged.summary_index == 2 + assert merged.summary_delta == "hello world" + + def test_tool_response_concatenates_and_keeps_call_id(self) -> None: + merged = _merge_pair( + ToolResponseDelta( + type="tool_response", tool_call_id="c1", name="t", content_delta="part1 " + ), + ToolResponseDelta( + type="tool_response", tool_call_id="c1", name="t", content_delta="part2" + ), + ) + assert isinstance(merged, ToolResponseDelta) + assert merged.tool_call_id == "c1" + assert merged.content_delta == "part1 part2" + + def test_handles_none_string_fields(self) -> None: + """Pydantic allows the *_delta fields to be None; merge must coerce to empty.""" + merged = _merge_pair( + TextDelta(type="text", text_delta=None), + TextDelta(type="text", text_delta="late"), + ) + assert isinstance(merged, TextDelta) + assert merged.text_delta == "late" + + +class TestMergeConsecutive: + def test_pure_text_collapses_to_one(self, task_message: TaskMessage) -> None: + deltas = [_text(task_message, s) for s in ["Hello", " ", "world", "!"]] + merged = _merge_consecutive(deltas) + assert len(merged) == 1 + assert merged[0].delta is not None + assert isinstance(merged[0].delta, TextDelta) + assert merged[0].delta.text_delta == "Hello world!" + + def test_empty_input_returns_empty_list(self) -> None: + assert _merge_consecutive([]) == [] + + def test_single_delta_passes_through(self, task_message: TaskMessage) -> None: + deltas = [_text(task_message, "lone")] + merged = _merge_consecutive(deltas) + assert len(merged) == 1 + assert merged[0] is deltas[0] # same object, no merge happened + + def test_cross_channel_order_preserved_for_reasoning( + self, task_message: TaskMessage + ) -> None: + """Consecutive same-(type, index) merges; distinct channels never reorder.""" + deltas = [ + _reasoning_summary(task_message, 0, "Let me "), + _reasoning_summary(task_message, 0, "think..."), + _reasoning_summary(task_message, 1, "Maybe "), + _reasoning_summary(task_message, 0, " Actually,"), + _reasoning_summary(task_message, 0, " yes."), + ] + merged = _merge_consecutive(deltas) + # Three groups: idx=0 run, idx=1 single, idx=0 run again — order preserved. 
+ assert len(merged) == 3 + assert merged[0].delta is not None and isinstance( + merged[0].delta, ReasoningSummaryDelta + ) + assert merged[1].delta is not None and isinstance( + merged[1].delta, ReasoningSummaryDelta + ) + assert merged[2].delta is not None and isinstance( + merged[2].delta, ReasoningSummaryDelta + ) + assert merged[0].delta.summary_index == 0 + assert merged[0].delta.summary_delta == "Let me think..." + assert merged[1].delta.summary_index == 1 + assert merged[1].delta.summary_delta == "Maybe " + assert merged[2].delta.summary_index == 0 + assert merged[2].delta.summary_delta == " Actually, yes." + + def test_per_channel_concat_matches_per_token_semantics( + self, task_message: TaskMessage + ) -> None: + """Reconstructing per-channel content from the merged stream must match + what a per-token consumer would have seen.""" + deltas = [ + _reasoning_summary(task_message, 0, "Hel"), + _reasoning_summary(task_message, 0, "lo"), + _reasoning_summary(task_message, 1, "World"), + _reasoning_summary(task_message, 0, "!"), + ] + merged = _merge_consecutive(deltas) + + per_index: dict[int, str] = {} + for u in merged: + d = u.delta + assert isinstance(d, ReasoningSummaryDelta) + per_index[d.summary_index] = per_index.get(d.summary_index, "") + ( + d.summary_delta or "" + ) + + assert per_index == {0: "Hello!", 1: "World"} + + +class TestCoalescingBufferTimeWindow: + @pytest.mark.asyncio + async def test_first_delta_flushes_immediately( + self, task_message: TaskMessage + ) -> None: + """The first-delta-immediate optimization should trip a flush in <=20ms, + well below the 50ms time window, so consumers see ``something started``.""" + flushed: list[StreamTaskMessageDelta] = [] + + async def on_flush(u: StreamTaskMessageDelta) -> None: + flushed.append(u) + + buf = CoalescingBuffer(on_flush=on_flush) + buf.start() + try: + await buf.add(_text(task_message, "hi")) + # Give the ticker a single tick to drain the signal. + await asyncio.sleep(0.020) + assert len(flushed) == 1 + assert flushed[0].delta is not None and isinstance( + flushed[0].delta, TextDelta + ) + assert flushed[0].delta.text_delta == "hi" + finally: + await buf.close() + + @pytest.mark.asyncio + async def test_size_threshold_triggers_early_flush( + self, task_message: TaskMessage + ) -> None: + """Adding more than MAX_BUFFERED_CHARS in one shot should flush within + a single asyncio tick, well before the 50ms timer would fire.""" + flushed: list[StreamTaskMessageDelta] = [] + + async def on_flush(u: StreamTaskMessageDelta) -> None: + flushed.append(u) + + buf = CoalescingBuffer(on_flush=on_flush) + buf.start() + try: + # Burn the first-delta-immediate slot so we're on the steady-state path. + await buf.add(_text(task_message, "x")) + await asyncio.sleep(0.020) + flushed.clear() + + # Now add 200 chars in one delta — well over MAX_BUFFERED_CHARS=128. 
+ await buf.add(_text(task_message, "A" * 200)) + await asyncio.sleep(0.010) # half the timer interval; only size can fire here + assert len(flushed) == 1 + assert flushed[0].delta is not None and isinstance( + flushed[0].delta, TextDelta + ) + assert flushed[0].delta.text_delta == "A" * 200 + finally: + await buf.close() + + @pytest.mark.asyncio + async def test_subsequent_deltas_coalesce_within_window( + self, task_message: TaskMessage + ) -> None: + """Three small deltas added inside one timer window should publish as + one merged delta (after the initial first-flush burns).""" + flushed: list[StreamTaskMessageDelta] = [] + + async def on_flush(u: StreamTaskMessageDelta) -> None: + flushed.append(u) + + buf = CoalescingBuffer(on_flush=on_flush) + buf.start() + try: + await buf.add(_text(task_message, "first")) # immediate flush + await asyncio.sleep(0.020) + flushed.clear() + + for chunk in ("ab", "cd", "ef"): + await buf.add(_text(task_message, chunk)) + # Wait past the 50ms window so the timer fires. + await asyncio.sleep(0.080) + # All three small deltas merge into a single publish. + assert len(flushed) == 1 + assert flushed[0].delta is not None and isinstance( + flushed[0].delta, TextDelta + ) + assert flushed[0].delta.text_delta == "abcdef" + finally: + await buf.close() + + +class TestCoalescingBufferClose: + @pytest.mark.asyncio + async def test_close_drains_remaining_buffered_items( + self, task_message: TaskMessage + ) -> None: + """Items added after the last timer tick must still flush before close() + completes — the persisted message body and the stream contract both + require it.""" + flushed: list[StreamTaskMessageDelta] = [] + + async def on_flush(u: StreamTaskMessageDelta) -> None: + flushed.append(u) + + buf = CoalescingBuffer(on_flush=on_flush) + buf.start() + await buf.add(_text(task_message, "first")) # immediate + await asyncio.sleep(0.020) + flushed.clear() + + # Add an item and immediately close — too fast for the 50ms timer. + await buf.add(_text(task_message, "last")) + await buf.close() + + assert len(flushed) == 1 + assert flushed[0].delta is not None and isinstance(flushed[0].delta, TextDelta) + assert flushed[0].delta.text_delta == "last" + + @pytest.mark.asyncio + async def test_close_when_idle_is_safe(self, task_message: TaskMessage) -> None: + """``close()`` with no buffered items must not raise.""" + buf = CoalescingBuffer(on_flush=AsyncMock()) + buf.start() + await buf.close() # no items, no signal, just exit cleanly + + @pytest.mark.asyncio + async def test_add_after_close_is_noop(self, task_message: TaskMessage) -> None: + """Defensive: ``add`` after ``close`` must silently do nothing rather + than raise. Real flows shouldn't hit this but tests racing close() + should not blow up.""" + flushed: list[StreamTaskMessageDelta] = [] + + async def on_flush(u: StreamTaskMessageDelta) -> None: + flushed.append(u) + + buf = CoalescingBuffer(on_flush=on_flush) + buf.start() + await buf.close() + # Fully drained and closed; this should silently no-op. + await buf.add(_text(task_message, "after")) + assert flushed == [] + + +class TestCoalescingBufferCancelDuringFlush: + @pytest.mark.asyncio + async def test_cancel_during_flush_recovers_remaining_items( + self, task_message: TaskMessage + ) -> None: + """Regression: when ``close()`` cancels the ticker mid-flush, items in + the local ``drained`` list must be re-enqueued so the final drain in + ``close()`` can recover them. 
Otherwise the last coalesced batch is + silently dropped — visible to consumers as a truncated stream. + """ + flushed: list[StreamTaskMessageDelta] = [] + first_started = asyncio.Event() + first_continue = asyncio.Event() + + async def slow_flush(u: StreamTaskMessageDelta) -> None: + flushed.append(u) + if len(flushed) == 1: + first_started.set() + # Block the first publish until the test releases it. This + # guarantees the cancellation lands inside the flush loop. + await first_continue.wait() + + buf = CoalescingBuffer(on_flush=slow_flush) + buf.start() + # Add five items quickly; they all land in self._buf and the ticker + # will drain them as one merged batch. + for i in range(5): + await buf.add(_text(task_message, f"chunk{i}")) + + await asyncio.wait_for(first_started.wait(), timeout=2.0) + # Trigger close() while the first flush is blocked, then release it. + close_task = asyncio.create_task(buf.close()) + first_continue.set() + await close_task + + # All five chunks must appear at least once across all publishes. + # (The first-flushed item may duplicate; that's the documented + # trade-off — duplicate > silent loss.) + full = "".join( + u.delta.text_delta or "" + for u in flushed + if isinstance(u.delta, TextDelta) + ) + for i in range(5): + assert f"chunk{i}" in full, ( + f"chunk{i} missing — silent data loss across cancel-during-flush boundary. " + f"flushed payloads: {[u.delta.text_delta for u in flushed if isinstance(u.delta, TextDelta)]}" + ) + + +class TestStreamingTaskMessageContextModes: + @pytest.mark.asyncio + async def test_off_mode_skips_publishes_but_persists_full_content(self) -> None: + ctx, svc, tm = await _make_context("off") + svc.stream_update.reset_mock() + for chunk in ("Hello", " ", "world"): + await ctx.stream_update(_text(tm, chunk)) + # Plenty of time for any background ticker — none should exist. + await asyncio.sleep(0.080) + assert svc.stream_update.call_count == 0, ( + "off mode must publish zero per-delta updates" + ) + + await ctx.close() + # The persisted message body must still contain the full assembled text, + # because the accumulator was fed even when publishing was suppressed. + update_kwargs = ctx._agentex_client.messages.update.call_args.kwargs + assert update_kwargs["content"]["content"] == "Hello world" + + @pytest.mark.asyncio + async def test_per_token_mode_publishes_each_delta_immediately(self) -> None: + ctx, svc, tm = await _make_context("per_token") + svc.stream_update.reset_mock() + for chunk in ("a", "b", "c"): + await ctx.stream_update(_text(tm, chunk)) + # Per-token mode must publish synchronously, no waiting required. + assert svc.stream_update.call_count == 3 + await ctx.close() + + @pytest.mark.asyncio + async def test_coalesced_mode_batches_and_persists_full_content(self) -> None: + ctx, svc, tm = await _make_context("coalesced") + svc.stream_update.reset_mock() + for chunk in ("Hello", " ", "world", "!"): + await ctx.stream_update(_text(tm, chunk)) + await ctx.close() + + # Assembled content is the union of all per-delta text. + update_kwargs = ctx._agentex_client.messages.update.call_args.kwargs + assert update_kwargs["content"]["content"] == "Hello world!" + + # Coalesced mode produces fewer publishes than per_token (4) but at + # least the start + at least one delta + done. 
+ delta_publishes = [ + call + for call in svc.stream_update.call_args_list + if isinstance(call.args[0] if call.args else None, StreamTaskMessageDelta) + ] + assert len(delta_publishes) >= 1, "coalesced mode should publish at least one delta" + assert len(delta_publishes) < 4, ( + "coalesced mode should batch at least some of the four chunks" + )
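
Illustrative end-to-end wiring (a sketch only; it assumes the standard openai-agents
``Runner``/``RunConfig`` entry points, which are outside this diff, and it elides any
Temporal-specific setup):

    import asyncio

    from agents import Agent, Runner, RunConfig

    from agentex.lib.core.temporal.plugins.openai_agents.models.temporal_streaming_model import (
        TemporalStreamingModelProvider,
    )

    async def main() -> None:
        # Every model returned by this provider uses the coalesced streaming mode.
        provider = TemporalStreamingModelProvider(streaming_mode="coalesced")
        result = await Runner.run(
            Agent(name="assistant", instructions="Answer briefly."),
            input="hello",
            run_config=RunConfig(model_provider=provider),
        )
        print(result.final_output)

    asyncio.run(main())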