Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions src/mcp/shared/experimental/tasks/in_memory_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ class InMemoryTaskStore(TaskStore):
def __init__(self, page_size: int = 10) -> None:
self._tasks: dict[str, StoredTask] = {}
self._page_size = page_size
self._update_events: dict[str, anyio.Event] = {}
self._update_events: dict[str, list[anyio.Event]] = {}
self._pending_updates: set[str] = set()

def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
"""Calculate expiry time from TTL in milliseconds."""
Expand Down Expand Up @@ -194,22 +195,43 @@ async def wait_for_update(self, task_id: str) -> None:
if task_id not in self._tasks:
raise ValueError(f"Task with ID {task_id} not found")

# Create a fresh event for waiting (anyio.Event can't be cleared)
self._update_events[task_id] = anyio.Event()
event = self._update_events[task_id]
await event.wait()
# If an update arrived before we started waiting, consume it and return.
if task_id in self._pending_updates:
self._pending_updates.discard(task_id)
return

# Create a per-waiter event so multiple concurrent waiters each get woken.
event = anyio.Event()
if task_id not in self._update_events:
self._update_events[task_id] = []
self._update_events[task_id].append(event)
try:
await event.wait()
finally:
# Clean up our event from the list (may already be removed by notify).
try:
self._update_events[task_id].remove(event)
except (ValueError, KeyError):
pass

async def notify_update(self, task_id: str) -> None:
"""Signal that a task has been updated."""
if task_id in self._update_events:
self._update_events[task_id].set()
events = self._update_events.pop(task_id, [])
if events:
for event in events:
event.set()
else:
# No waiters yet; mark as pending so the next wait_for_update returns
# immediately instead of blocking.
self._pending_updates.add(task_id)

# --- Testing/debugging helpers ---

def cleanup(self) -> None:
"""Cleanup all tasks (useful for testing or graceful shutdown)."""
self._tasks.clear()
self._update_events.clear()
self._pending_updates.clear()

def get_all_tasks(self) -> list[Task]:
"""Get all tasks (useful for debugging). Returns copies to prevent modification."""
Expand Down
40 changes: 40 additions & 0 deletions tests/experimental/tasks/server/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import AsyncIterator
from datetime import datetime, timedelta, timezone

import anyio
import pytest

from mcp.shared.exceptions import MCPError
Expand Down Expand Up @@ -328,6 +329,45 @@ async def test_wait_for_update_nonexistent_raises(store: InMemoryTaskStore) -> N
await store.wait_for_update("nonexistent-task-id")


@pytest.mark.anyio
async def test_wait_for_update_concurrent_waiters(store: InMemoryTaskStore) -> None:
"""Two concurrent waiters for the same task must both wake up."""
task = await store.create_task(metadata=TaskMetadata(ttl=60000))

woke: dict[str, bool] = {"a": False, "b": False}

async def waiter(name: str) -> None:
await store.wait_for_update(task.task_id)
woke[name] = True

async def updater() -> None:
await anyio.sleep(0.05)
await store.update_task(task.task_id, status="completed")

with anyio.fail_after(2):
async with anyio.create_task_group() as tg:
tg.start_soon(waiter, "a")
await anyio.sleep(0.01) # ensure a registers first
tg.start_soon(waiter, "b")
tg.start_soon(updater)

assert woke["a"], "waiter a should have been woken"
assert woke["b"], "waiter b should have been woken"


@pytest.mark.anyio
async def test_wait_for_update_notify_before_wait(store: InMemoryTaskStore) -> None:
"""If notify fires before wait, the signal must not be lost."""
task = await store.create_task(metadata=TaskMetadata(ttl=60000))

# Task completes before anyone waits
await store.update_task(task.task_id, status="completed")

# wait_for_update should return immediately (pending update consumed)
with anyio.fail_after(1):
await store.wait_for_update(task.task_id)


@pytest.mark.anyio
async def test_cancel_task_succeeds_for_working_task(store: InMemoryTaskStore) -> None:
"""Test cancel_task helper succeeds for a working task."""
Expand Down
Loading