Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ def __init__(
self._session_exit_stacks = {}
self._component_name_hook = component_name_hook

async def __aenter__(self) -> Self: # pragma: no cover
async def __aenter__(self) -> Self:
# Enter the exit stack only if we created it ourselves
if self._owns_exit_stack:
if self._owns_exit_stack: # pragma: no branch
await self._exit_stack.__aenter__()
return self

Expand All @@ -158,22 +158,22 @@ async def __aexit__(
_exc_type: type[BaseException] | None,
_exc_val: BaseException | None,
_exc_tb: TracebackType | None,
) -> bool | None: # pragma: no cover
) -> bool | None:
"""Closes session exit stacks and main exit stack upon completion."""

# Only close the main exit stack if we created it
if self._owns_exit_stack:
if self._owns_exit_stack: # pragma: no branch
await self._exit_stack.aclose()

# Concurrently close session stacks.
async with anyio.create_task_group() as tg:
for exit_stack in self._session_exit_stacks.values():
tg.start_soon(exit_stack.aclose)
tg.start_soon(exit_stack.aclose) # pragma: no cover

@property
def sessions(self) -> list[mcp.ClientSession]:
"""Returns the list of sessions being managed."""
return list(self._sessions.keys()) # pragma: no cover
return list(self._sessions.keys())

@property
def prompts(self) -> dict[str, types.Prompt]:
Expand Down Expand Up @@ -323,7 +323,7 @@ async def _establish_session(
await self._exit_stack.enter_async_context(session_stack)

return result.server_info, session
except Exception: # pragma: no cover
except Exception:
# If anything during this setup fails, ensure the session-specific
# stack is closed.
await session_stack.aclose()
Expand Down
25 changes: 21 additions & 4 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,20 @@ async def _handle_reconnection(
# Try to reconnect again if we still have an event ID
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)

async def _handle_request_error(self, ctx: RequestContext, exc: Exception) -> None:
"""Report a request transport failure without crashing the transport task group."""
logger.debug("Error handling StreamableHTTP request", exc_info=True)

message = ctx.session_message.message
if isinstance(message, JSONRPCRequest):
error_data = ErrorData(code=INTERNAL_ERROR, message=f"Transport error: {exc}")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data))
with contextlib.suppress(anyio.BrokenResourceError, anyio.ClosedResourceError):
await ctx.read_stream_writer.send(error_msg)
else:
with contextlib.suppress(anyio.BrokenResourceError, anyio.ClosedResourceError):
await ctx.read_stream_writer.send(exc)

async def post_writer(
self,
client: httpx.AsyncClient,
Expand Down Expand Up @@ -468,10 +482,13 @@ async def _handle_message(session_message: SessionMessage) -> None:
)

async def handle_request_async():
if is_resumption:
await self._handle_resumption_request(ctx)
else:
await self._handle_post_request(ctx)
try:
if is_resumption:
await self._handle_resumption_request(ctx)
else:
await self._handle_post_request(ctx)
except Exception as exc:
await self._handle_request_error(ctx, exc)

# If this is a request, start a new task to handle it
if isinstance(message, JSONRPCRequest):
Expand Down
33 changes: 33 additions & 0 deletions tests/client/test_notification_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import json

import anyio
import httpx
import pytest
from starlette.applications import Starlette
Expand Down Expand Up @@ -152,6 +153,21 @@ async def test_http_error_status_sends_jsonrpc_error() -> None:
await session.list_tools()


async def test_transport_error_sends_jsonrpc_error() -> None:
"""Verify request transport errors unblock the pending request with an MCPError."""

async def raise_connect_error(request: httpx.Request) -> httpx.Response:
raise httpx.ConnectError("All connection attempts failed", request=request)

async with httpx.AsyncClient(transport=httpx.MockTransport(raise_connect_error)) as client:
async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
with pytest.raises(
MCPError, match="Transport error: All connection attempts failed"
): # pragma: no branch
await session.initialize()


async def test_http_error_on_notification_does_not_hang() -> None:
"""Verify HTTP errors on notifications are silently ignored.

Expand All @@ -168,6 +184,23 @@ async def test_http_error_on_notification_does_not_hang() -> None:
await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed"))


async def test_transport_error_on_notification_does_not_crash_transport() -> None:
"""Verify transport errors on notifications do not crash the transport task group."""

async def handle_request(request: httpx.Request) -> httpx.Response:
data = json.loads(request.content)
if data.get("method") == "initialize":
return httpx.Response(200, json={"jsonrpc": "2.0", "id": data["id"], "result": INIT_RESPONSE})
raise httpx.ConnectError("All connection attempts failed", request=request)

async with httpx.AsyncClient(transport=httpx.MockTransport(handle_request)) as client:
async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
await session.initialize()
await session.send_notification(RootsListChangedNotification(method="notifications/roots/list_changed"))
await anyio.sleep(0)


def _create_invalid_json_response_app() -> Starlette:
"""Create a server that returns invalid JSON for requests."""

Expand Down
21 changes: 21 additions & 0 deletions tests/client/test_session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,24 @@ async def test_client_session_group_establish_session_parameterized(
# 3. Assert returned values
assert returned_server_info is mock_initialize_result.server_info
assert returned_session is mock_entered_session


@pytest.mark.anyio
async def test_client_session_group_streamable_http_connect_error_is_catchable() -> None:
async def raise_connect_error(request: httpx.Request) -> httpx.Response:
raise httpx.ConnectError("All connection attempts failed", request=request)

def mock_client_factory(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
return httpx.AsyncClient(transport=httpx.MockTransport(raise_connect_error))

group = ClientSessionGroup()
async with group:
with mock.patch("mcp.client.session_group.create_mcp_http_client", side_effect=mock_client_factory):
with pytest.raises(MCPError, match="Transport error: All connection attempts failed"): # pragma: no branch
await group.connect_to_server(StreamableHttpParameters(url="http://localhost:3001/mcp/"))

assert group.sessions == []
Loading