diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 72309f577..a4f847b6a 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -273,6 +273,12 @@ def __init__( self._validate_resource_url_callback = validate_resource_url self._initialized = False + def _copy_user_agent_header(self, request: httpx.Request, source_request: httpx.Request) -> httpx.Request: + user_agent = source_request.headers.get("User-Agent") + if user_agent and "User-Agent" not in request.headers: + request.headers["User-Agent"] = user_agent + return request + async def _handle_protected_resource_response(self, response: httpx.Response) -> bool: """Handle protected resource metadata discovery response. @@ -515,6 +521,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. if not self.context.is_token_valid() and self.context.can_refresh_token(): # Try to refresh token refresh_request = await self._refresh_token() # pragma: no cover + self._copy_user_agent_header(refresh_request, request) # pragma: no cover refresh_response = yield refresh_request # pragma: no cover if not await self._handle_refresh_response(refresh_response): # pragma: no cover @@ -539,6 +546,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. for url in prm_discovery_urls: # pragma: no branch discovery_request = create_oauth_metadata_request(url) + self._copy_user_agent_header(discovery_request, request) discovery_response = yield discovery_request # sending request @@ -565,6 +573,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers) for url in asm_discovery_urls: # pragma: no branch oauth_metadata_request = create_oauth_metadata_request(url) + self._copy_user_agent_header(oauth_metadata_request, request) oauth_metadata_response = yield oauth_metadata_request ok, asm = await handle_auth_metadata_response(oauth_metadata_response) @@ -604,15 +613,18 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.client_metadata, self.context.get_authorization_base_url(self.context.server_url), ) + self._copy_user_agent_header(registration_request, request) registration_response = yield registration_request client_information = await handle_registration_response(registration_response) self.context.client_info = client_information await self.context.storage.set_client_info(client_information) # Step 5: Perform authorization and complete token exchange - token_response = yield await self._perform_authorization() + token_request = await self._perform_authorization() + self._copy_user_agent_header(token_request, request) + token_response = yield token_request await self._handle_token_response(token_response) - except Exception: # pragma: no cover + except Exception: logger.exception("OAuth flow error") raise @@ -635,7 +647,9 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. ) # Step 2b: Perform (re-)authorization and token exchange - token_response = yield await self._perform_authorization() + token_request = await self._perform_authorization() + self._copy_user_agent_header(token_request, request) + token_response = yield token_request await self._handle_token_response(token_response) except Exception: # pragma: no cover logger.exception("OAuth flow error") diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index bb0bce4c9..cb80b97d0 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1167,6 +1167,59 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide assert oauth_provider.context.current_tokens.access_token == "new_access_token" assert oauth_provider.context.token_expiry_time is not None + @pytest.mark.anyio + async def test_auth_flow_logs_and_reraises_oauth_errors( + self, oauth_provider: OAuthClientProvider, caplog: pytest.LogCaptureFixture + ): + """OAuth flow failures should be logged and re-raised.""" + oauth_provider.context.current_tokens = None + oauth_provider.context.token_expiry_time = None + oauth_provider.context.client_info = OAuthClientInformationFull( + client_id="existing_client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + oauth_provider._initialized = True + oauth_provider._perform_authorization = mock.AsyncMock(side_effect=RuntimeError("auth boom")) + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = oauth_provider.async_auth_flow(test_request) + + await auth_flow.__anext__() + + response = httpx.Response( + 401, + headers={ + "WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + }, + request=test_request, + ) + discovery_request = await auth_flow.asend(response) + + discovery_response = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}' + ), + request=discovery_request, + ) + oauth_metadata_request = await auth_flow.asend(discovery_response) + + oauth_metadata_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token"}' + ), + request=oauth_metadata_request, + ) + + with caplog.at_level("ERROR", logger="mcp.client.auth.oauth2"): + with pytest.raises(RuntimeError, match="auth boom"): + await auth_flow.asend(oauth_metadata_response) + + assert "OAuth flow error" in caplog.text + @pytest.mark.anyio async def test_auth_flow_no_unnecessary_retry_after_oauth( self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken @@ -1589,6 +1642,98 @@ async def callback_handler() -> tuple[str, str | None]: except StopAsyncIteration: pass + @pytest.mark.anyio + async def test_oauth_flow_forwards_user_agent_to_generated_auth_requests( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """OAuth discovery, registration, and token requests should preserve the transport User-Agent.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + provider._initialized = True + provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) + + test_request = httpx.Request( + "POST", + "https://api.example.com/v1/mcp", + headers={"User-Agent": "custom-mcp-client/1.0"}, + ) + auth_flow = provider.async_auth_flow(test_request) + + try: + first_request = await auth_flow.__anext__() + assert first_request.headers["User-Agent"] == "custom-mcp-client/1.0" + + response = httpx.Response( + 401, + headers={ + "WWW-Authenticate": ( + 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + ) + }, + request=test_request, + ) + + discovery_request = await auth_flow.asend(response) + assert discovery_request.headers["User-Agent"] == "custom-mcp-client/1.0" + + discovery_response = httpx.Response( + 200, + content=( + b'{"resource": "https://api.example.com/v1/mcp", ' + b'"authorization_servers": ["https://auth.example.com"]}' + ), + request=discovery_request, + ) + + oauth_metadata_request = await auth_flow.asend(discovery_response) + assert oauth_metadata_request.headers["User-Agent"] == "custom-mcp-client/1.0" + + oauth_metadata_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register"}' + ), + request=oauth_metadata_request, + ) + + registration_request = await auth_flow.asend(oauth_metadata_response) + assert registration_request.headers["User-Agent"] == "custom-mcp-client/1.0" + + registration_response = httpx.Response( + 201, + content=( + b'{"client_id": "test_client", ' + b'"client_secret": "test_secret", ' + b'"redirect_uris": ["http://localhost:3030/callback"], ' + b'"token_endpoint_auth_method": "client_secret_post", ' + b'"grant_types": ["authorization_code"], ' + b'"response_types": ["code"]}' + ), + request=registration_request, + ) + + token_request = await auth_flow.asend(registration_response) + assert token_request.headers["User-Agent"] == "custom-mcp-client/1.0" + finally: + await auth_flow.aclose() + @pytest.mark.anyio async def test_legacy_server_with_different_prm_and_root_urls( self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage