From 9fba13a5b5f430af8dcded55d47ca95780254941 Mon Sep 17 00:00:00 2001 From: Genmin Date: Thu, 30 Apr 2026 21:40:16 -0700 Subject: [PATCH 1/4] fix oauth auth flow user agent forwarding --- src/mcp/client/auth/oauth2.py | 18 ++++++- tests/client/test_auth.py | 92 +++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 2 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 72309f577..2a39964dd 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -273,6 +273,12 @@ def __init__( self._validate_resource_url_callback = validate_resource_url self._initialized = False + def _copy_user_agent_header(self, request: httpx.Request, source_request: httpx.Request) -> httpx.Request: + user_agent = source_request.headers.get("User-Agent") + if user_agent and "User-Agent" not in request.headers: + request.headers["User-Agent"] = user_agent + return request + async def _handle_protected_resource_response(self, response: httpx.Response) -> bool: """Handle protected resource metadata discovery response. @@ -515,6 +521,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. if not self.context.is_token_valid() and self.context.can_refresh_token(): # Try to refresh token refresh_request = await self._refresh_token() # pragma: no cover + self._copy_user_agent_header(refresh_request, request) # pragma: no cover refresh_response = yield refresh_request # pragma: no cover if not await self._handle_refresh_response(refresh_response): # pragma: no cover @@ -539,6 +546,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. for url in prm_discovery_urls: # pragma: no branch discovery_request = create_oauth_metadata_request(url) + self._copy_user_agent_header(discovery_request, request) discovery_response = yield discovery_request # sending request @@ -565,6 +573,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers) for url in asm_discovery_urls: # pragma: no branch oauth_metadata_request = create_oauth_metadata_request(url) + self._copy_user_agent_header(oauth_metadata_request, request) oauth_metadata_response = yield oauth_metadata_request ok, asm = await handle_auth_metadata_response(oauth_metadata_response) @@ -604,13 +613,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.client_metadata, self.context.get_authorization_base_url(self.context.server_url), ) + self._copy_user_agent_header(registration_request, request) registration_response = yield registration_request client_information = await handle_registration_response(registration_response) self.context.client_info = client_information await self.context.storage.set_client_info(client_information) # Step 5: Perform authorization and complete token exchange - token_response = yield await self._perform_authorization() + token_request = await self._perform_authorization() + self._copy_user_agent_header(token_request, request) + token_response = yield token_request await self._handle_token_response(token_response) except Exception: # pragma: no cover logger.exception("OAuth flow error") @@ -635,7 +647,9 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. ) # Step 2b: Perform (re-)authorization and token exchange - token_response = yield await self._perform_authorization() + token_request = await self._perform_authorization() + self._copy_user_agent_header(token_request, request) + token_response = yield token_request await self._handle_token_response(token_response) except Exception: # pragma: no cover logger.exception("OAuth flow error") diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index bb0bce4c9..fea5fb157 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1589,6 +1589,98 @@ async def callback_handler() -> tuple[str, str | None]: except StopAsyncIteration: pass + @pytest.mark.anyio + async def test_oauth_flow_forwards_user_agent_to_generated_auth_requests( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """OAuth discovery, registration, and token requests should preserve the transport User-Agent.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + provider._initialized = True + provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) + + test_request = httpx.Request( + "POST", + "https://api.example.com/v1/mcp", + headers={"User-Agent": "custom-mcp-client/1.0"}, + ) + auth_flow = provider.async_auth_flow(test_request) + + try: + first_request = await auth_flow.__anext__() + assert first_request.headers["User-Agent"] == "custom-mcp-client/1.0" + + response = httpx.Response( + 401, + headers={ + "WWW-Authenticate": ( + 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + ) + }, + request=test_request, + ) + + discovery_request = await auth_flow.asend(response) + assert discovery_request.headers["User-Agent"] == "custom-mcp-client/1.0" + + discovery_response = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp", ' + b'"authorization_servers": ["https://auth.example.com"]}' + ), + request=discovery_request, + ) + + oauth_metadata_request = await auth_flow.asend(discovery_response) + assert oauth_metadata_request.headers["User-Agent"] == "custom-mcp-client/1.0" + + oauth_metadata_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register"}' + ), + request=oauth_metadata_request, + ) + + registration_request = await auth_flow.asend(oauth_metadata_response) + assert registration_request.headers["User-Agent"] == "custom-mcp-client/1.0" + + registration_response = httpx.Response( + 201, + content=( + b'{"client_id": "test_client", ' + b'"client_secret": "test_secret", ' + b'"redirect_uris": ["http://localhost:3030/callback"], ' + b'"token_endpoint_auth_method": "client_secret_post", ' + b'"grant_types": ["authorization_code"], ' + b'"response_types": ["code"]}' + ), + request=registration_request, + ) + + token_request = await auth_flow.asend(registration_response) + assert token_request.headers["User-Agent"] == "custom-mcp-client/1.0" + finally: + await auth_flow.aclose() + @pytest.mark.anyio async def test_legacy_server_with_different_prm_and_root_urls( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage From 03aac76ad35c188d9dd15400a78dd6a3e9133036 Mon Sep 17 00:00:00 2001 From: Genmin Date: Thu, 30 Apr 2026 21:48:38 -0700 Subject: [PATCH 2/4] test: remove covered oauth no-cover marker --- src/mcp/client/auth/oauth2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 2a39964dd..a4f847b6a 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -624,7 +624,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self._copy_user_agent_header(token_request, request) token_response = yield token_request await self._handle_token_response(token_response) - except Exception: # pragma: no cover + except Exception: logger.exception("OAuth flow error") raise From 5cca3b527aa93f0df0c6e01f2af0a6e5520d918a Mon Sep 17 00:00:00 2001 From: Genmin Date: Thu, 30 Apr 2026 22:04:54 -0700 Subject: [PATCH 3/4] test oauth flow exception logging --- tests/client/test_auth.py | 51 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index fea5fb157..3f9f14d8b 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1167,6 +1167,57 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide assert oauth_provider.context.current_tokens.access_token == "new_access_token" assert oauth_provider.context.token_expiry_time is not None + @pytest.mark.anyio + async def test_auth_flow_logs_and_reraises_oauth_errors(self, oauth_provider: OAuthClientProvider, caplog): + """OAuth flow failures should be logged and re-raised.""" + oauth_provider.context.current_tokens = None + oauth_provider.context.token_expiry_time = None + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="existing_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + oauth_provider._initialized = True + oauth_provider._perform_authorization = mock.AsyncMock(side_effect=RuntimeError("auth boom")) + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = oauth_provider.async_auth_flow(test_request) + + await auth_flow.__anext__() + + response = httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + }, + request=test_request, + ) + discovery_request = await auth_flow.asend(response) + + discovery_response = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}' + ), + request=discovery_request, + ) + oauth_metadata_request = await auth_flow.asend(discovery_response) + + oauth_metadata_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token"}' + ), + request=oauth_metadata_request, + ) + + with caplog.at_level("ERROR", logger="mcp.client.auth.oauth2"): + with pytest.raises(RuntimeError, match="auth boom"): + await auth_flow.asend(oauth_metadata_response) + + assert "OAuth flow error" in caplog.text + @pytest.mark.anyio async def test_auth_flow_no_unnecessary_retry_after_oauth( self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken From 78eafacac7c071dbade1dbbf83db97e343c278a8 Mon Sep 17 00:00:00 2001 From: Genmin Date: Thu, 30 Apr 2026 22:19:23 -0700 Subject: [PATCH 4/4] test: type caplog fixture --- tests/client/test_auth.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 3f9f14d8b..cb80b97d0 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1168,7 +1168,9 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide assert oauth_provider.context.token_expiry_time is not None @pytest.mark.anyio - async def test_auth_flow_logs_and_reraises_oauth_errors(self, oauth_provider: OAuthClientProvider, caplog): + async def test_auth_flow_logs_and_reraises_oauth_errors( + self, oauth_provider: OAuthClientProvider, caplog: pytest.LogCaptureFixture + ): """OAuth flow failures should be logged and re-raised.""" oauth_provider.context.current_tokens = None oauth_provider.context.token_expiry_time = None