Skip to content
31 changes: 29 additions & 2 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
INVALID_PARAMS,
REQUEST_TIMEOUT,
CancelledNotification,
CancelledNotificationParams,
ClientNotification,
ClientRequest,
ClientResult,
Expand Down Expand Up @@ -141,7 +142,7 @@ async def respond(self, response: SendResultT | ErrorData) -> None:

async def cancel(self) -> None:
"""Cancel this request and mark it as completed."""
if not self._entered: # pragma: no cover
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
if not self._cancel_scope: # pragma: no cover
raise RuntimeError("No active cancel scope")
Expand Down Expand Up @@ -292,9 +293,13 @@ async def send_request(
with anyio.fail_after(timeout):
response_or_error = await response_stream_reader.receive()
except TimeoutError:
await self._send_cancelled_notification(request_id, "request timed out")
class_name = request.__class__.__name__
message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds."
raise MCPError(code=REQUEST_TIMEOUT, message=message)
except anyio.get_cancelled_exc_class():
await self._send_cancelled_notification(request_id, "request cancelled by caller")
raise

if isinstance(response_or_error, JSONRPCError):
raise MCPError.from_jsonrpc_error(response_or_error)
Expand Down Expand Up @@ -418,7 +423,7 @@ async def _handle_session_message(message: SessionMessage) -> None:
await self._handle_incoming(notification)
except Exception:
# For other validation errors, log and continue
logging.warning( # pragma: no cover
logging.warning(
f"Failed to validate notification:. Message was: {message.message}",
exc_info=True,
)
Expand Down Expand Up @@ -540,6 +545,28 @@ async def send_progress_notification(
) -> None:
"""Sends a progress notification for a request that is currently being processed."""

async def _send_cancelled_notification(self, request_id: RequestId, reason: str) -> None:
"""Send a cancellation notification to the remote side (best-effort).

Uses a shielded cancel scope with a timeout so the notification is
delivered even when called from inside a cancelled task, but does not
block shutdown if the write stream is unavailable.
"""
try:
with anyio.CancelScope(shield=True):
with anyio.fail_after(2):
notification = CancelledNotification(
method="notifications/cancelled",
params=CancelledNotificationParams(request_id=request_id, reason=reason),
)
jsonrpc_notification = JSONRPCNotification(
jsonrpc="2.0",
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
)
await self._write_stream.send(SessionMessage(message=jsonrpc_notification))
except Exception:
logging.warning("Failed to send cancellation notification for request %s", request_id)

async def _handle_incoming(
self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception
) -> None:
Expand Down
92 changes: 92 additions & 0 deletions tests/shared/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,95 @@ async def make_request(client_session: ClientSession):
# Pending request completed successfully
assert len(result_holder) == 1
assert isinstance(result_holder[0], EmptyResult)


@pytest.mark.anyio
async def test_send_request_sends_cancelled_notification_on_timeout():
"""Client must send notifications/cancelled when a request times out.

Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/2507.
"""
cancelled_notifications: list[SessionMessage] = []

async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, _ = server_streams

async def mock_server():
async for message in server_read: # pragma: no branch
assert isinstance(message, SessionMessage)
if isinstance(message.message, types.JSONRPCNotification):
cancelled_notifications.append(message)

async with (
anyio.create_task_group() as tg,
ClientSession(
read_stream=client_read,
write_stream=client_write,
read_timeout_seconds=0.1,
) as client_session,
):
tg.start_soon(mock_server)

with pytest.raises(MCPError, match="Timed out"):
await client_session.send_ping()

await anyio.sleep(0.05)
tg.cancel_scope.cancel()

assert len(cancelled_notifications) == 1
notif = cancelled_notifications[0].message
assert isinstance(notif, types.JSONRPCNotification)
assert notif.method == "notifications/cancelled"
assert notif.params is not None
assert notif.params.get("reason") == "request timed out"


@pytest.mark.anyio
async def test_send_request_sends_cancelled_notification_on_caller_cancel():
"""Client must send notifications/cancelled when caller cancels the request.

Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/2507.
"""
cancelled_notifications: list[SessionMessage] = []
ev_request_sent = anyio.Event()

async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, _ = server_streams

async def mock_server():
async for message in server_read: # pragma: no branch
assert isinstance(message, SessionMessage)
if isinstance(message.message, JSONRPCRequest):
ev_request_sent.set()
if isinstance(message.message, types.JSONRPCNotification):
cancelled_notifications.append(message)

async def make_request(client_session: ClientSession):
await client_session.send_ping()

async with (
anyio.create_task_group() as tg,
ClientSession(
read_stream=client_read,
write_stream=client_write,
) as client_session,
):
tg.start_soon(mock_server)

async with anyio.create_task_group() as request_tg:
request_tg.start_soon(make_request, client_session)
with anyio.fail_after(1):
await ev_request_sent.wait()
request_tg.cancel_scope.cancel()

await anyio.sleep(0.05)
tg.cancel_scope.cancel()

assert len(cancelled_notifications) == 1
notif = cancelled_notifications[0].message
assert isinstance(notif, types.JSONRPCNotification)
assert notif.method == "notifications/cancelled"
assert notif.params is not None
assert notif.params.get("reason") == "request cancelled by caller"
Loading