From 632ef64cd12d1aacee88aaa61e85e0091f38d661 Mon Sep 17 00:00:00 2001 From: Henry Lee Date: Fri, 1 May 2026 22:03:07 +0800 Subject: [PATCH] feat: allow overriding SSE messages endpoint --- src/mcp/client/sse.py | 20 ++++++++- tests/shared/test_sse.py | 93 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 110 insertions(+), 3 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 193204a15..f6443a74e 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -27,6 +27,20 @@ def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None: return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0] +def _resolve_endpoint_url(sse_url: str, endpoint_data: str, messages_url: str | None = None) -> str: + if messages_url is None: + return urljoin(sse_url, endpoint_data) + + endpoint_url = urljoin(sse_url, messages_url) + endpoint_query = urlparse(endpoint_data).query + if endpoint_query: + endpoint_parsed = urlparse(endpoint_url) + query = "&".join(filter(None, [endpoint_parsed.query, endpoint_query])) + endpoint_url = endpoint_parsed._replace(query=query).geturl() + + return endpoint_url + + @asynccontextmanager async def sse_client( url: str, @@ -36,6 +50,7 @@ async def sse_client( httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, on_session_created: Callable[[str], None] | None = None, + messages_url: str | None = None, ): """Client transport for SSE. @@ -50,6 +65,9 @@ async def sse_client( httpx_client_factory: Factory function for creating the HTTPX client. auth: Optional HTTPX authentication handler. on_session_created: Optional callback invoked with the session ID when received. + messages_url: Optional message endpoint URL to use instead of deriving it + from the SSE endpoint event. Relative URLs are resolved against `url`, + and any session query parameters from the endpoint event are preserved. """ logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") async with httpx_client_factory( @@ -68,7 +86,7 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): logger.debug(f"Received SSE event: {sse.event}") match sse.event: case "endpoint": - endpoint_url = urljoin(url, sse.data) + endpoint_url = _resolve_endpoint_url(url, sse.data, messages_url) logger.debug(f"Received endpoint URL: {endpoint_url}") url_parsed = urlparse(url) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 5629a5707..09a2b397d 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -2,7 +2,7 @@ import multiprocessing import socket from collections.abc import AsyncGenerator, Generator -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, MagicMock, Mock, patch from urllib.parse import urlparse @@ -20,11 +20,12 @@ import mcp.client.sse from mcp import types from mcp.client.session import ClientSession -from mcp.client.sse import _extract_session_id_from_endpoint, sse_client +from mcp.client.sse import _extract_session_id_from_endpoint, _resolve_endpoint_url, sse_client from mcp.server import Server, ServerRequestContext from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import MCPError +from mcp.shared.message import SessionMessage from mcp.types import ( CallToolRequestParams, CallToolResult, @@ -229,6 +230,50 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non assert _extract_session_id_from_endpoint(endpoint_url) == expected +@pytest.mark.parametrize( + ("sse_url", "endpoint_data", "messages_url", "expected"), + [ + ( + "https://example.com/api/v1/sse", + "/v1/messages/?session_id=abc123", + None, + "https://example.com/v1/messages/?session_id=abc123", + ), + ( + "https://example.com/api/v1/sse", + "/v1/messages/?session_id=abc123", + "https://example.com/api/v1/messages/", + "https://example.com/api/v1/messages/?session_id=abc123", + ), + ( + "https://example.com/api/v1/sse", + "/v1/messages/?session_id=abc123", + "/api/v1/messages/", + "https://example.com/api/v1/messages/?session_id=abc123", + ), + ( + "https://example.com/api/v1/sse", + "/v1/messages/?session_id=abc123", + "https://example.com/api/v1/messages/?tenant=blue", + "https://example.com/api/v1/messages/?tenant=blue&session_id=abc123", + ), + ( + "https://example.com/api/v1/sse", + "/v1/messages/", + "https://example.com/api/v1/messages/", + "https://example.com/api/v1/messages/", + ), + ], +) +def test_resolve_endpoint_url_with_messages_url_override( + sse_url: str, + endpoint_data: str, + messages_url: str | None, + expected: str, +) -> None: + assert _resolve_endpoint_url(sse_url, endpoint_data, messages_url) == expected + + @pytest.mark.anyio async def test_sse_client_on_session_created_not_called_when_no_session_id( server: None, server_url: str, monkeypatch: pytest.MonkeyPatch @@ -249,6 +294,50 @@ def mock_extract(url: str) -> None: callback_mock.assert_not_called() +@pytest.mark.anyio +async def test_sse_client_uses_messages_url_override() -> None: + async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: + yield ServerSentEvent(event="endpoint", data="/v1/messages/?session_id=abc123") + await anyio.sleep_forever() + + mock_event_source = MagicMock() + mock_event_source.aiter_sse.return_value = mock_aiter_sse() + mock_event_source.response = MagicMock() + mock_event_source.response.raise_for_status = MagicMock() + + mock_aconnect_sse = MagicMock() + mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source) + mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None) + + mock_client = MagicMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client.post = AsyncMock(return_value=MagicMock(status_code=200, raise_for_status=MagicMock())) + + def mock_httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + _ = (headers, timeout, auth) + return cast(httpx.AsyncClient, mock_client) + + with patch("mcp.client.sse.aconnect_sse", return_value=mock_aconnect_sse): + async with sse_client( + "https://example.com/api/v1/sse", + httpx_client_factory=mock_httpx_client_factory, + messages_url="https://example.com/api/v1/messages/", + ) as (_, write_stream): + message = types.JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + await write_stream.send(SessionMessage(message)) + with anyio.fail_after(1): # pragma: no branch + while not mock_client.post.await_count: + await anyio.sleep(0.01) + + mock_client.post.assert_awaited() + assert mock_client.post.await_args.args[0] == "https://example.com/api/v1/messages/?session_id=abc123" + + @pytest.fixture async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: