Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,12 @@ def __init__(
self._validate_resource_url_callback = validate_resource_url
self._initialized = False

def _copy_user_agent_header(self, request: httpx.Request, source_request: httpx.Request) -> httpx.Request:
user_agent = source_request.headers.get("User-Agent")
if user_agent and "User-Agent" not in request.headers:
request.headers["User-Agent"] = user_agent
return request

async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
"""Handle protected resource metadata discovery response.

Expand Down Expand Up @@ -515,6 +521,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
if not self.context.is_token_valid() and self.context.can_refresh_token():
# Try to refresh token
refresh_request = await self._refresh_token() # pragma: no cover
self._copy_user_agent_header(refresh_request, request) # pragma: no cover
refresh_response = yield refresh_request # pragma: no cover

if not await self._handle_refresh_response(refresh_response): # pragma: no cover
Expand All @@ -539,6 +546,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.

for url in prm_discovery_urls: # pragma: no branch
discovery_request = create_oauth_metadata_request(url)
self._copy_user_agent_header(discovery_request, request)

discovery_response = yield discovery_request # sending request

Expand All @@ -565,6 +573,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
for url in asm_discovery_urls: # pragma: no branch
oauth_metadata_request = create_oauth_metadata_request(url)
self._copy_user_agent_header(oauth_metadata_request, request)
oauth_metadata_response = yield oauth_metadata_request

ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
Expand Down Expand Up @@ -604,15 +613,18 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
self.context.client_metadata,
self.context.get_authorization_base_url(self.context.server_url),
)
self._copy_user_agent_header(registration_request, request)
registration_response = yield registration_request
client_information = await handle_registration_response(registration_response)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)

# Step 5: Perform authorization and complete token exchange
token_response = yield await self._perform_authorization()
token_request = await self._perform_authorization()
self._copy_user_agent_header(token_request, request)
token_response = yield token_request
await self._handle_token_response(token_response)
except Exception: # pragma: no cover
except Exception:
logger.exception("OAuth flow error")
raise

Expand All @@ -635,7 +647,9 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
)

# Step 2b: Perform (re-)authorization and token exchange
token_response = yield await self._perform_authorization()
token_request = await self._perform_authorization()
self._copy_user_agent_header(token_request, request)
token_response = yield token_request
await self._handle_token_response(token_response)
except Exception: # pragma: no cover
logger.exception("OAuth flow error")
Expand Down
145 changes: 145 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,59 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide
assert oauth_provider.context.current_tokens.access_token == "new_access_token"
assert oauth_provider.context.token_expiry_time is not None

@pytest.mark.anyio
async def test_auth_flow_logs_and_reraises_oauth_errors(
self, oauth_provider: OAuthClientProvider, caplog: pytest.LogCaptureFixture
):
"""OAuth flow failures should be logged and re-raised."""
oauth_provider.context.current_tokens = None
oauth_provider.context.token_expiry_time = None
oauth_provider.context.client_info = OAuthClientInformationFull(
client_id="existing_client",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)
oauth_provider._initialized = True
oauth_provider._perform_authorization = mock.AsyncMock(side_effect=RuntimeError("auth boom"))

test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
auth_flow = oauth_provider.async_auth_flow(test_request)

await auth_flow.__anext__()

response = httpx.Response(
401,
headers={
"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
},
request=test_request,
)
discovery_request = await auth_flow.asend(response)

discovery_response = httpx.Response(
200,
content=(
b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}'
),
request=discovery_request,
)
oauth_metadata_request = await auth_flow.asend(discovery_response)

oauth_metadata_response = httpx.Response(
200,
content=(
b'{"issuer": "https://auth.example.com", '
b'"authorization_endpoint": "https://auth.example.com/authorize", '
b'"token_endpoint": "https://auth.example.com/token"}'
),
request=oauth_metadata_request,
)

with caplog.at_level("ERROR", logger="mcp.client.auth.oauth2"):
with pytest.raises(RuntimeError, match="auth boom"):
await auth_flow.asend(oauth_metadata_response)

assert "OAuth flow error" in caplog.text

@pytest.mark.anyio
async def test_auth_flow_no_unnecessary_retry_after_oauth(
self, oauth_provider: OAuthClientProvider, mock_storage: MockTokenStorage, valid_tokens: OAuthToken
Expand Down Expand Up @@ -1589,6 +1642,98 @@ async def callback_handler() -> tuple[str, str | None]:
except StopAsyncIteration:
pass

@pytest.mark.anyio
async def test_oauth_flow_forwards_user_agent_to_generated_auth_requests(
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
):
"""OAuth discovery, registration, and token requests should preserve the transport User-Agent."""

async def redirect_handler(url: str) -> None:
pass # pragma: no cover

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state" # pragma: no cover

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)
provider._initialized = True
provider._perform_authorization_code_grant = mock.AsyncMock(
return_value=("test_auth_code", "test_code_verifier")
)

test_request = httpx.Request(
"POST",
"https://api.example.com/v1/mcp",
headers={"User-Agent": "custom-mcp-client/1.0"},
)
auth_flow = provider.async_auth_flow(test_request)

try:
first_request = await auth_flow.__anext__()
assert first_request.headers["User-Agent"] == "custom-mcp-client/1.0"

response = httpx.Response(
401,
headers={
"WWW-Authenticate": (
'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
)
},
request=test_request,
)

discovery_request = await auth_flow.asend(response)
assert discovery_request.headers["User-Agent"] == "custom-mcp-client/1.0"

discovery_response = httpx.Response(
200,
content=(
b'{"resource": "https://api.example.com/v1/mcp", '
b'"authorization_servers": ["https://auth.example.com"]}'
),
request=discovery_request,
)

oauth_metadata_request = await auth_flow.asend(discovery_response)
assert oauth_metadata_request.headers["User-Agent"] == "custom-mcp-client/1.0"

oauth_metadata_response = httpx.Response(
200,
content=(
b'{"issuer": "https://auth.example.com", '
b'"authorization_endpoint": "https://auth.example.com/authorize", '
b'"token_endpoint": "https://auth.example.com/token", '
b'"registration_endpoint": "https://auth.example.com/register"}'
),
request=oauth_metadata_request,
)

registration_request = await auth_flow.asend(oauth_metadata_response)
assert registration_request.headers["User-Agent"] == "custom-mcp-client/1.0"

registration_response = httpx.Response(
201,
content=(
b'{"client_id": "test_client", '
b'"client_secret": "test_secret", '
b'"redirect_uris": ["http://localhost:3030/callback"], '
b'"token_endpoint_auth_method": "client_secret_post", '
b'"grant_types": ["authorization_code"], '
b'"response_types": ["code"]}'
),
request=registration_request,
)

token_request = await auth_flow.asend(registration_response)
assert token_request.headers["User-Agent"] == "custom-mcp-client/1.0"
finally:
await auth_flow.aclose()

@pytest.mark.anyio
async def test_legacy_server_with_different_prm_and_root_urls(
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
Expand Down
Loading