diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 243eef5ae..fb63cf39f 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -23,6 +23,7 @@ INVALID_PARAMS, REQUEST_TIMEOUT, CancelledNotification, + CancelledNotificationParams, ClientNotification, ClientRequest, ClientResult, @@ -141,7 +142,7 @@ async def respond(self, response: SendResultT | ErrorData) -> None: async def cancel(self) -> None: """Cancel this request and mark it as completed.""" - if not self._entered: # pragma: no cover + if not self._entered: raise RuntimeError("RequestResponder must be used as a context manager") if not self._cancel_scope: # pragma: no cover raise RuntimeError("No active cancel scope") @@ -292,9 +293,13 @@ async def send_request( with anyio.fail_after(timeout): response_or_error = await response_stream_reader.receive() except TimeoutError: + await self._send_cancelled_notification(request_id, "request timed out") class_name = request.__class__.__name__ message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." raise MCPError(code=REQUEST_TIMEOUT, message=message) + except anyio.get_cancelled_exc_class(): + await self._send_cancelled_notification(request_id, "request cancelled by caller") + raise if isinstance(response_or_error, JSONRPCError): raise MCPError.from_jsonrpc_error(response_or_error) @@ -418,7 +423,7 @@ async def _handle_session_message(message: SessionMessage) -> None: await self._handle_incoming(notification) except Exception: # For other validation errors, log and continue - logging.warning( # pragma: no cover + logging.warning( f"Failed to validate notification:. Message was: {message.message}", exc_info=True, ) @@ -540,6 +545,28 @@ async def send_progress_notification( ) -> None: """Sends a progress notification for a request that is currently being processed.""" + async def _send_cancelled_notification(self, request_id: RequestId, reason: str) -> None: + """Send a cancellation notification to the remote side (best-effort). + + Uses a shielded cancel scope with a timeout so the notification is + delivered even when called from inside a cancelled task, but does not + block shutdown if the write stream is unavailable. + """ + try: + with anyio.CancelScope(shield=True): + with anyio.fail_after(2): + notification = CancelledNotification( + method="notifications/cancelled", + params=CancelledNotificationParams(request_id=request_id, reason=reason), + ) + jsonrpc_notification = JSONRPCNotification( + jsonrpc="2.0", + **notification.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + await self._write_stream.send(SessionMessage(message=jsonrpc_notification)) + except Exception: + logging.warning("Failed to send cancellation notification for request %s", request_id) + async def _handle_incoming( self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception ) -> None: diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index d7c6cc3b5..e0ed5ca11 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -416,3 +416,95 @@ async def make_request(client_session: ClientSession): # Pending request completed successfully assert len(result_holder) == 1 assert isinstance(result_holder[0], EmptyResult) + + +@pytest.mark.anyio +async def test_send_request_sends_cancelled_notification_on_timeout(): + """Client must send notifications/cancelled when a request times out. + + Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/2507. + """ + cancelled_notifications: list[SessionMessage] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, _ = server_streams + + async def mock_server(): + async for message in server_read: # pragma: no branch + assert isinstance(message, SessionMessage) + if isinstance(message.message, types.JSONRPCNotification): + cancelled_notifications.append(message) + + async with ( + anyio.create_task_group() as tg, + ClientSession( + read_stream=client_read, + write_stream=client_write, + read_timeout_seconds=0.1, + ) as client_session, + ): + tg.start_soon(mock_server) + + with pytest.raises(MCPError, match="Timed out"): + await client_session.send_ping() + + await anyio.sleep(0.05) + tg.cancel_scope.cancel() + + assert len(cancelled_notifications) == 1 + notif = cancelled_notifications[0].message + assert isinstance(notif, types.JSONRPCNotification) + assert notif.method == "notifications/cancelled" + assert notif.params is not None + assert notif.params.get("reason") == "request timed out" + + +@pytest.mark.anyio +async def test_send_request_sends_cancelled_notification_on_caller_cancel(): + """Client must send notifications/cancelled when caller cancels the request. + + Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/2507. + """ + cancelled_notifications: list[SessionMessage] = [] + ev_request_sent = anyio.Event() + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, _ = server_streams + + async def mock_server(): + async for message in server_read: # pragma: no branch + assert isinstance(message, SessionMessage) + if isinstance(message.message, JSONRPCRequest): + ev_request_sent.set() + if isinstance(message.message, types.JSONRPCNotification): + cancelled_notifications.append(message) + + async def make_request(client_session: ClientSession): + await client_session.send_ping() + + async with ( + anyio.create_task_group() as tg, + ClientSession( + read_stream=client_read, + write_stream=client_write, + ) as client_session, + ): + tg.start_soon(mock_server) + + async with anyio.create_task_group() as request_tg: + request_tg.start_soon(make_request, client_session) + with anyio.fail_after(1): + await ev_request_sent.wait() + request_tg.cancel_scope.cancel() + + await anyio.sleep(0.05) + tg.cancel_scope.cancel() + + assert len(cancelled_notifications) == 1 + notif = cancelled_notifications[0].message + assert isinstance(notif, types.JSONRPCNotification) + assert notif.method == "notifications/cancelled" + assert notif.params is not None + assert notif.params.get("reason") == "request cancelled by caller"