Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/agentex/lib/adk/_modules/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from agentex.lib.adk.utils._modules.client import create_async_agentex_client
from agentex.lib.core.adapters.streams.adapter_redis import RedisStreamRepository
from agentex.lib.core.services.adk.streaming import (
StreamingMode,
StreamingService,
StreamingTaskMessageContext,
)
Expand Down Expand Up @@ -50,6 +51,7 @@ def streaming_task_message_context(
self,
task_id: str,
initial_content: TaskMessageContent,
streaming_mode: StreamingMode = "coalesced",
) -> StreamingTaskMessageContext:
"""
Create a streaming context for managing TaskMessage lifecycle.
Expand All @@ -60,7 +62,11 @@ def streaming_task_message_context(
Args:
task_id: The ID of the task
initial_content: The initial content for the TaskMessage
agentex_client: The agentex client for creating/updating messages
streaming_mode: How per-delta updates are published. Defaults to
"coalesced" (50ms / 128-char windowed batches with an immediate
first-delta flush). Pass "per_token" for the legacy publish-every-
delta behavior, or "off" to suppress per-delta publishes entirely
while still recording the full message body on close.

Returns:
StreamingTaskMessageContext: Context manager for streaming operations
Expand All @@ -76,4 +82,5 @@ def streaming_task_message_context(
return self._streaming_service.streaming_task_message_context(
task_id=task_id,
initial_content=initial_content,
streaming_mode=streaming_mode,
)
235 changes: 231 additions & 4 deletions src/agentex/lib/core/services/adk/streaming.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import json
from typing import Literal
import asyncio
import contextlib
from typing import Literal, Callable, Awaitable

from agentex import AsyncAgentex
from agentex.lib.utils.logging import make_logger
Expand Down Expand Up @@ -39,6 +41,198 @@ def _get_stream_topic(task_id: str) -> str:
return f"task:{task_id}"


StreamingMode = Literal["off", "per_token", "coalesced"]
"""Controls how a StreamingTaskMessageContext publishes deltas.

- "off": Feed the accumulator (so the persisted message body is correct)
but never publish per-delta events. Consumers see start + done
only. Lowest latency.
- "per_token": Publish every delta immediately. Highest UX fidelity for
token-by-token rendering, highest Redis cost, and re-introduces
head-of-line blocking on the producer's event loop.
- "coalesced": Buffer deltas in a small time/size window and publish them as
merged batches. The first delta flushes immediately for fast
perceived responsiveness; subsequent deltas flush every 50ms or
whenever 128 buffered chars accumulate, whichever comes first.
Order within each (delta type, index) channel is preserved
exactly; only granularity changes.
"""


def _delta_char_len(delta: TaskMessageDelta | None) -> int:
    """Return the length of *delta*'s string payload.

    Used by CoalescingBuffer to decide when the size window is full. ``None``
    payload fields and unrecognized delta variants both count as 0 characters.
    """
    if delta is None:
        return 0
    # Each known variant carries exactly one string payload attribute.
    payload_attr_by_type = (
        (TextDelta, "text_delta"),
        (DataDelta, "data_delta"),
        (ReasoningSummaryDelta, "summary_delta"),
        (ReasoningContentDelta, "content_delta"),
        (ToolRequestDelta, "arguments_delta"),
        (ToolResponseDelta, "content_delta"),
    )
    for delta_cls, attr in payload_attr_by_type:
        if isinstance(delta, delta_cls):
            return len(getattr(delta, attr) or "")
    return 0


def _can_merge(a: TaskMessageDelta, b: TaskMessageDelta) -> bool:
    """Return True when *a* and *b* belong to the same logical delta channel.

    Two deltas are mergeable only when they are the same concrete type AND
    target the same sub-channel (summary/content index, or tool_call_id for
    tool deltas). Text and data deltas have a single channel per message.

    Unknown delta variants are treated as NON-mergeable. Previously this
    function fell through to ``return True`` for any same-type pair, which
    sent unrecognized variants into ``_merge_pair``'s AssertionError inside
    the CoalescingBuffer ticker task — outside that loop's per-publish
    ``except Exception`` guard — silently killing coalescing. Refusing to
    merge instead degrades gracefully: the new variant is published per-item,
    order preserved, only batching granularity lost.
    """
    if type(a) is not type(b):
        return False
    if isinstance(a, (TextDelta, DataDelta)):
        return True
    if isinstance(a, ReasoningSummaryDelta) and isinstance(b, ReasoningSummaryDelta):
        return a.summary_index == b.summary_index
    if isinstance(a, ReasoningContentDelta) and isinstance(b, ReasoningContentDelta):
        return a.content_index == b.content_index
    if isinstance(a, ToolRequestDelta) and isinstance(b, ToolRequestDelta):
        return a.tool_call_id == b.tool_call_id
    if isinstance(a, ToolResponseDelta) and isinstance(b, ToolResponseDelta):
        return a.tool_call_id == b.tool_call_id
    # Unknown TaskMessageDelta variant: safe default is "don't merge".
    return False


def _merge_pair(a: TaskMessageDelta, b: TaskMessageDelta) -> TaskMessageDelta:
    """Combine two same-channel deltas (pre-approved by ``_can_merge``) into one.

    The string payloads are concatenated in order; channel identifiers
    (indices, tool_call_id, name) are taken from the first delta.
    """
    if isinstance(a, TextDelta) and isinstance(b, TextDelta):
        combined_text = (a.text_delta or "") + (b.text_delta or "")
        return TextDelta(type="text", text_delta=combined_text)
    if isinstance(a, DataDelta) and isinstance(b, DataDelta):
        combined_data = (a.data_delta or "") + (b.data_delta or "")
        return DataDelta(type="data", data_delta=combined_data)
    if isinstance(a, ReasoningSummaryDelta) and isinstance(b, ReasoningSummaryDelta):
        combined_summary = (a.summary_delta or "") + (b.summary_delta or "")
        return ReasoningSummaryDelta(
            type="reasoning_summary",
            summary_index=a.summary_index,
            summary_delta=combined_summary,
        )
    if isinstance(a, ReasoningContentDelta) and isinstance(b, ReasoningContentDelta):
        combined_content = (a.content_delta or "") + (b.content_delta or "")
        return ReasoningContentDelta(
            type="reasoning_content",
            content_index=a.content_index,
            content_delta=combined_content,
        )
    if isinstance(a, ToolRequestDelta) and isinstance(b, ToolRequestDelta):
        combined_args = (a.arguments_delta or "") + (b.arguments_delta or "")
        return ToolRequestDelta(
            type="tool_request",
            tool_call_id=a.tool_call_id,
            name=a.name,
            arguments_delta=combined_args,
        )
    if isinstance(a, ToolResponseDelta) and isinstance(b, ToolResponseDelta):
        combined_response = (a.content_delta or "") + (b.content_delta or "")
        return ToolResponseDelta(
            type="tool_response",
            tool_call_id=a.tool_call_id,
            name=a.name,
            content_delta=combined_response,
        )
    # Keep the fail-fast guard: reaching here means _can_merge and _merge_pair
    # disagree about which variants are mergeable.
    raise AssertionError(
        f"_can_merge approved {type(a).__name__} pair but _merge_pair has no handler — "
        "a new TaskMessageDelta variant was added without updating both functions"
    )


def _merge_consecutive(updates: list[StreamTaskMessageDelta]) -> list[StreamTaskMessageDelta]:
    """Merge runs of consecutive same-channel deltas into single updates.

    Order across channels is preserved exactly; only adjacent, mergeable
    deltas collapse. Updates carrying a ``None`` delta are kept as-is and
    break any run in progress.
    """
    merged: list[StreamTaskMessageDelta] = []
    for update in updates:
        previous = merged[-1] if merged else None
        mergeable = (
            update.delta is not None
            and previous is not None
            and previous.delta is not None
            and _can_merge(previous.delta, update.delta)
        )
        if mergeable:
            # Replace the tail in place with the combined delta.
            merged[-1] = StreamTaskMessageDelta(
                parent_task_message=previous.parent_task_message,
                delta=_merge_pair(previous.delta, update.delta),
                type="delta",
            )
        else:
            merged.append(update)
    return merged


class CoalescingBuffer:
    """Time-and-size-windowed buffer that merges consecutive same-channel deltas.

    Decouples the producer (model event loop) from the publisher (Redis): ``add``
    only enqueues and may signal an early flush; the actual publish always runs
    on a background ticker, so the producer never awaits on a Redis round-trip.

    Lifecycle: ``start()`` spawns the ticker task; ``add()`` enqueues under the
    lock; ``close()`` cancels the ticker and performs a final synchronous drain.
    """

    # Ticker cadence: drain at most every 50ms, or sooner once 128 buffered
    # chars accumulate (see add()).
    FLUSH_INTERVAL_S = 0.050
    MAX_BUFFERED_CHARS = 128

    def __init__(self, on_flush: Callable[[StreamTaskMessageDelta], Awaitable[object]]):
        # Async callback that publishes one (possibly merged) delta update.
        self._on_flush = on_flush
        self._buf: list[StreamTaskMessageDelta] = []
        self._buf_chars = 0  # running payload size of _buf; drives the size window
        self._first_flushed = False  # first delta bypasses the window (fast first paint)
        self._closed = False
        self._lock = asyncio.Lock()  # guards _buf/_buf_chars between producer and ticker
        self._flush_signal = asyncio.Event()  # wakes the ticker before its timeout
        self._task: asyncio.Task[None] | None = None

    def start(self) -> None:
        """Spawn the background ticker task. Idempotent."""
        if self._task is None:
            self._task = asyncio.create_task(self._run(), name="coalescing-buffer")

    async def add(self, update: StreamTaskMessageDelta) -> None:
        """Enqueue *update* for a later batched publish; never publishes inline.

        Updates added after close() are silently dropped.
        """
        if self._closed:
            return
        async with self._lock:
            self._buf.append(update)
            self._buf_chars += _delta_char_len(update.delta)
            # Signal an early flush for the very first delta (responsiveness)
            # or once the size window fills; the time window is the ticker's
            # wait_for timeout in _run().
            if not self._first_flushed or self._buf_chars >= self.MAX_BUFFERED_CHARS:
                self._first_flushed = True
                self._flush_signal.set()

    async def _run(self) -> None:
        """Ticker loop: wake on signal or every FLUSH_INTERVAL_S, drain, publish."""
        try:
            while not self._closed:
                try:
                    await asyncio.wait_for(self._flush_signal.wait(), timeout=self.FLUSH_INTERVAL_S)
                except asyncio.TimeoutError:
                    pass  # time window elapsed with no early-flush signal; drain anyway
                async with self._lock:
                    self._flush_signal.clear()
                    drained = self._drain_locked()
                # Publish outside the lock so producers are never blocked on Redis.
                for idx, u in enumerate(drained):
                    try:
                        await self._on_flush(u)
                    except asyncio.CancelledError:
                        # Re-enqueue the item being flushed plus any remaining so
                        # close()'s final drain can recover them. May cause a
                        # duplicate publish of the in-flight item, which is
                        # preferable to silent loss for a streaming UX.
                        async with self._lock:
                            self._buf = drained[idx:] + self._buf
                        raise
                    except Exception as e:
                        logger.exception(f"CoalescingBuffer flush failed: {e}")
        except asyncio.CancelledError:
            pass  # normal shutdown: close() cancels the ticker

    async def close(self) -> None:
        """Stop the ticker, then synchronously publish anything still buffered.

        Runs the final drain on the caller's task so close() returns only after
        every recoverable delta has been handed to on_flush.
        """
        self._closed = True
        if self._task is not None:
            self._flush_signal.set()
            self._task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._task
            self._task = None
        async with self._lock:
            drained = self._drain_locked()
        for u in drained:
            try:
                await self._on_flush(u)
            except Exception as e:
                logger.exception(f"CoalescingBuffer final flush failed: {e}")

    def _drain_locked(self) -> list[StreamTaskMessageDelta]:
        """Merge and empty the buffer. Caller must hold self._lock."""
        if not self._buf:
            return []
        merged = _merge_consecutive(self._buf)
        self._buf = []
        self._buf_chars = 0
        return merged


class DeltaAccumulator:
def __init__(self):
self._accumulated_deltas: list[TaskMessageDelta] = []
Expand Down Expand Up @@ -176,6 +370,7 @@ def __init__(
initial_content: TaskMessageContent,
agentex_client: AsyncAgentex,
streaming_service: "StreamingService",
streaming_mode: StreamingMode = "coalesced",
):
self.task_id = task_id
self.initial_content = initial_content
Expand All @@ -184,6 +379,8 @@ def __init__(
self._streaming_service = streaming_service
self._is_closed = False
self._delta_accumulator = DeltaAccumulator()
self._streaming_mode: StreamingMode = streaming_mode
self._buffer: CoalescingBuffer | None = None

async def __aenter__(self) -> "StreamingTaskMessageContext":
return await self.open()
Expand All @@ -208,6 +405,10 @@ async def open(self) -> "StreamingTaskMessageContext":
)
await self._streaming_service.stream_update(start_event)

if self._streaming_mode == "coalesced":
self._buffer = CoalescingBuffer(on_flush=self._streaming_service.stream_update)
self._buffer.start()

return self

async def close(self) -> TaskMessage:
Expand All @@ -218,6 +419,12 @@ async def close(self) -> TaskMessage:
if self._is_closed:
return self.task_message # Already done

# Drain any buffered deltas before announcing DONE so consumers see the
# full sequence in order.
if self._buffer is not None:
await self._buffer.close()
self._buffer = None

# Send the DONE event
done_event = StreamTaskMessageDone(
parent_task_message=self.task_message,
Expand All @@ -227,8 +434,8 @@ async def close(self) -> TaskMessage:

# Update the task message with the final content
has_deltas = (
self._delta_accumulator._accumulated_deltas or
self._delta_accumulator._reasoning_summaries or
self._delta_accumulator._accumulated_deltas or
self._delta_accumulator._reasoning_summaries or
self._delta_accumulator._reasoning_contents
)
if has_deltas:
Expand All @@ -248,7 +455,20 @@ async def close(self) -> TaskMessage:
async def stream_update(
self, update: TaskMessageUpdate
) -> TaskMessageUpdate | None:
"""Stream an update to the repository."""
"""Stream an update to the repository.

Behavior depends on the context's ``streaming_mode``:
- "off": delta updates feed the accumulator (so the persisted message
body is correct) but are never published.
- "per_token": delta updates are published immediately.
- "coalesced": delta updates are queued in a 50ms / 128-char window and
flushed as merged batches on a background ticker; the first delta
flushes immediately for fast perceived responsiveness.

``StreamTaskMessageDone`` and ``StreamTaskMessageFull`` updates always
publish synchronously regardless of mode so consumers and persistence
stay in sync.
"""
if self._is_closed:
raise ValueError("Context is already done")

Expand All @@ -258,6 +478,11 @@ async def stream_update(
if isinstance(update, StreamTaskMessageDelta):
if update.delta is not None:
self._delta_accumulator.add_delta(update.delta)
if self._streaming_mode == "off":
return update
if self._streaming_mode == "coalesced" and self._buffer is not None:
await self._buffer.add(update)
return update

result = await self._streaming_service.stream_update(update)

Expand Down Expand Up @@ -288,12 +513,14 @@ def streaming_task_message_context(
self,
task_id: str,
initial_content: TaskMessageContent,
streaming_mode: StreamingMode = "coalesced",
) -> StreamingTaskMessageContext:
return StreamingTaskMessageContext(
task_id=task_id,
initial_content=initial_content,
agentex_client=self._agentex_client,
streaming_service=self,
streaming_mode=streaming_mode,
)

async def stream_update(
Expand Down
Loading
Loading