diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 961021264..aa99c5b42 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -147,9 +147,9 @@ def __init__( self._session_exit_stacks = {} self._component_name_hook = component_name_hook - async def __aenter__(self) -> Self: # pragma: no cover + async def __aenter__(self) -> Self: # Enter the exit stack only if we created it ourselves - if self._owns_exit_stack: + if self._owns_exit_stack: # pragma: no branch await self._exit_stack.__aenter__() return self @@ -158,22 +158,22 @@ async def __aexit__( _exc_type: type[BaseException] | None, _exc_val: BaseException | None, _exc_tb: TracebackType | None, - ) -> bool | None: # pragma: no cover + ) -> bool | None: """Closes session exit stacks and main exit stack upon completion.""" # Only close the main exit stack if we created it - if self._owns_exit_stack: + if self._owns_exit_stack: # pragma: no branch await self._exit_stack.aclose() # Concurrently close session stacks. async with anyio.create_task_group() as tg: for exit_stack in self._session_exit_stacks.values(): - tg.start_soon(exit_stack.aclose) + tg.start_soon(exit_stack.aclose) # pragma: no cover @property def sessions(self) -> list[mcp.ClientSession]: """Returns the list of sessions being managed.""" - return list(self._sessions.keys()) # pragma: no cover + return list(self._sessions.keys()) @property def prompts(self) -> dict[str, types.Prompt]: @@ -323,7 +323,7 @@ async def _establish_session( await self._exit_stack.enter_async_context(session_stack) return result.server_info, session - except Exception: # pragma: no cover + except Exception: # If anything during this setup fails, ensure the session-specific # stack is closed. await session_stack.aclose() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9a119c633..e2404868b 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -429,6 +429,20 @@ async def _handle_reconnection( # Try to reconnect again if we still have an event ID await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1) + async def _handle_request_error(self, ctx: RequestContext, exc: Exception) -> None: + """Report a request transport failure without crashing the transport task group.""" + logger.debug("Error handling StreamableHTTP request", exc_info=True) + + message = ctx.session_message.message + if isinstance(message, JSONRPCRequest): + error_data = ErrorData(code=INTERNAL_ERROR, message=f"Transport error: {exc}") + error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) + with contextlib.suppress(anyio.BrokenResourceError, anyio.ClosedResourceError): + await ctx.read_stream_writer.send(error_msg) + else: + with contextlib.suppress(anyio.BrokenResourceError, anyio.ClosedResourceError): + await ctx.read_stream_writer.send(exc) + async def post_writer( self, client: httpx.AsyncClient, @@ -468,10 +482,13 @@ async def _handle_message(session_message: SessionMessage) -> None: ) async def handle_request_async(): - if is_resumption: - await self._handle_resumption_request(ctx) - else: - await self._handle_post_request(ctx) + try: + if is_resumption: + await self._handle_resumption_request(ctx) + else: + await self._handle_post_request(ctx) + except Exception as exc: + await self._handle_request_error(ctx, exc) # If this is a request, start a new task to handle it if isinstance(message, JSONRPCRequest): diff --git a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index 69c8afeb8..3a7e75cd2 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -6,6 +6,7 @@ import json +import anyio import httpx import pytest from starlette.applications import Starlette @@ -152,6 +153,21 @@ async def test_http_error_status_sends_jsonrpc_error() -> None: await session.list_tools() +async def test_transport_error_sends_jsonrpc_error() -> None: + """Verify request transport errors unblock the pending request with an MCPError.""" + + async def raise_connect_error(request: httpx.Request) -> httpx.Response: + raise httpx.ConnectError("All connection attempts failed", request=request) + + async with httpx.AsyncClient(transport=httpx.MockTransport(raise_connect_error)) as client: + async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + with pytest.raises( + MCPError, match="Transport error: All connection attempts failed" + ): # pragma: no branch + await session.initialize() + + async def test_http_error_on_notification_does_not_hang() -> None: """Verify HTTP errors on notifications are silently ignored. @@ -168,6 +184,23 @@ async def test_http_error_on_notification_does_not_hang() -> None: await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed")) +async def test_transport_error_on_notification_does_not_crash_transport() -> None: + """Verify transport errors on notifications do not crash the transport task group.""" + + async def handle_request(request: httpx.Request) -> httpx.Response: + data = json.loads(request.content) + if data.get("method") == "initialize": + return httpx.Response(200, json={"jsonrpc": "2.0", "id": data["id"], "result": INIT_RESPONSE}) + raise httpx.ConnectError("All connection attempts failed", request=request) + + async with httpx.AsyncClient(transport=httpx.MockTransport(handle_request)) as client: + async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed")) + await anyio.sleep(0) + + def _create_invalid_json_response_app() -> Starlette: """Create a server that returns invalid JSON for requests.""" diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f3..4dd5ad725 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -385,3 +385,24 @@ async def test_client_session_group_establish_session_parameterized( # 3. Assert returned values assert returned_server_info is mock_initialize_result.server_info assert returned_session is mock_entered_session + + +@pytest.mark.anyio +async def test_client_session_group_streamable_http_connect_error_is_catchable() -> None: + async def raise_connect_error(request: httpx.Request) -> httpx.Response: + raise httpx.ConnectError("All connection attempts failed", request=request) + + def mock_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient(transport=httpx.MockTransport(raise_connect_error)) + + group = ClientSessionGroup() + async with group: + with mock.patch("mcp.client.session_group.create_mcp_http_client", side_effect=mock_client_factory): + with pytest.raises(MCPError, match="Transport error: All connection attempts failed"): # pragma: no branch + await group.connect_to_server(StreamableHttpParameters(url="http://localhost:3001/mcp/")) + + assert group.sessions == []