From 065264d98fc9f361e9fbb9c3041aaadbde0e8af8 Mon Sep 17 00:00:00 2001 From: Dayna Blackwell Date: Sat, 2 May 2026 07:56:53 -0700 Subject: [PATCH] fix: resolve lost-wakeup races in InMemoryTaskStore.wait_for_update Two races in wait_for_update: 1. Concurrent waiters: second caller overwrites the first's event in _update_events[task_id], so the first waiter hangs forever. Fix: use a list of events per task_id so each waiter gets its own. 2. Notify before wait: if update_task completes before wait_for_update is called, the signal is lost because no event exists yet. Fix: track pending updates in a set; wait_for_update checks and consumes pending flags before creating an event. Both races are reachable via task_result_handler.py:126 when multiple clients poll the same task or when a task completes between status checks. Adds two tests: concurrent waiters and notify-before-wait. Fixes #2535 --- .../tasks/in_memory_task_store.py | 36 +++++++++++++---- tests/experimental/tasks/server/test_store.py | 40 +++++++++++++++++++ 2 files changed, 69 insertions(+), 7 deletions(-) diff --git a/src/mcp/shared/experimental/tasks/in_memory_task_store.py b/src/mcp/shared/experimental/tasks/in_memory_task_store.py index 42f4fb703..1982931fe 100644 --- a/src/mcp/shared/experimental/tasks/in_memory_task_store.py +++ b/src/mcp/shared/experimental/tasks/in_memory_task_store.py @@ -46,7 +46,8 @@ class InMemoryTaskStore(TaskStore): def __init__(self, page_size: int = 10) -> None: self._tasks: dict[str, StoredTask] = {} self._page_size = page_size - self._update_events: dict[str, anyio.Event] = {} + self._update_events: dict[str, list[anyio.Event]] = {} + self._pending_updates: set[str] = set() def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None: """Calculate expiry time from TTL in milliseconds.""" @@ -194,15 +195,35 @@ async def wait_for_update(self, task_id: str) -> None: if task_id not in self._tasks: raise ValueError(f"Task with ID {task_id} not found") - # Create a fresh 
event for waiting (anyio.Event can't be cleared) - self._update_events[task_id] = anyio.Event() - event = self._update_events[task_id] - await event.wait() + # If an update arrived while no waiter was registered, consume the pending flag and return. + if task_id in self._pending_updates: + self._pending_updates.discard(task_id) + return + + # Create a per-waiter event so multiple concurrent waiters each get woken. + event = anyio.Event() + if task_id not in self._update_events: + self._update_events[task_id] = [] + self._update_events[task_id].append(event) + try: + await event.wait() + finally: + # Clean up our event from the list (notify_update may already have popped the whole list). + try: + self._update_events[task_id].remove(event) + except (ValueError, KeyError): + pass async def notify_update(self, task_id: str) -> None: """Signal that a task has been updated.""" - if task_id in self._update_events: - self._update_events[task_id].set() + events = self._update_events.pop(task_id, []) + if events: + for event in events: + event.set() + else: + # No waiters yet; mark as pending so the next wait_for_update returns + # immediately instead of blocking. + self._pending_updates.add(task_id) # --- Testing/debugging helpers --- @@ -210,6 +231,7 @@ def cleanup(self) -> None: """Cleanup all tasks (useful for testing or graceful shutdown).""" self._tasks.clear() self._update_events.clear() + self._pending_updates.clear() def get_all_tasks(self) -> list[Task]: """Get all tasks (useful for debugging). 
Returns copies to prevent modification.""" diff --git a/tests/experimental/tasks/server/test_store.py b/tests/experimental/tasks/server/test_store.py index 0d431899c..d3d9e2644 100644 --- a/tests/experimental/tasks/server/test_store.py +++ b/tests/experimental/tasks/server/test_store.py @@ -3,6 +3,7 @@ from collections.abc import AsyncIterator from datetime import datetime, timedelta, timezone +import anyio import pytest from mcp.shared.exceptions import MCPError @@ -328,6 +329,45 @@ async def test_wait_for_update_nonexistent_raises(store: InMemoryTaskStore) -> N await store.wait_for_update("nonexistent-task-id") +@pytest.mark.anyio +async def test_wait_for_update_concurrent_waiters(store: InMemoryTaskStore) -> None: + """Two concurrent waiters for the same task must both wake up.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + woke: dict[str, bool] = {"a": False, "b": False} + + async def waiter(name: str) -> None: + await store.wait_for_update(task.task_id) + woke[name] = True + + async def updater() -> None: + await anyio.sleep(0.05) + await store.update_task(task.task_id, status="completed") + + with anyio.fail_after(2): + async with anyio.create_task_group() as tg: + tg.start_soon(waiter, "a") + await anyio.sleep(0.01) # ensure a registers first + tg.start_soon(waiter, "b") + tg.start_soon(updater) + + assert woke["a"], "waiter a should have been woken" + assert woke["b"], "waiter b should have been woken" + + +@pytest.mark.anyio +async def test_wait_for_update_notify_before_wait(store: InMemoryTaskStore) -> None: + """If notify fires before wait, the signal must not be lost.""" + task = await store.create_task(metadata=TaskMetadata(ttl=60000)) + + # Task completes before anyone waits + await store.update_task(task.task_id, status="completed") + + # wait_for_update should return immediately (pending update consumed) + with anyio.fail_after(1): + await store.wait_for_update(task.task_id) + + @pytest.mark.anyio async def 
test_cancel_task_succeeds_for_working_task(store: InMemoryTaskStore) -> None: """Test cancel_task helper succeeds for a working task."""