diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidator.java index 5ee9b85fd..4be5875db 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidator.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidator.java @@ -10,8 +10,8 @@ /** * Default {@link SseMessageEndpointValidator} that validates the {@code message} endpoint - * advertised by an SSE server. Message endpoints must be a relative URI, without path - * traversal or authority. + * advertised by an SSE server. Message endpoints must either have the same origin as the + * SSE uri, or be a relative uri. * * @author Daniel Garnier-Moiroux */ @@ -30,16 +30,19 @@ public void validate(URI sseUri, String messageEndpoint) throws InvalidSseMessag messageEndpoint); } - if (endpointUri.isAbsolute()) { - // Exclude absolute URIs e.g. https://example.com/mcp - throw new InvalidSseMessageEndpointException("messageEndpoint must be a relative path, not an absolute URI", - messageEndpoint); - } + if (endpointUri.isAbsolute() || endpointUri.getRawAuthority() != null) { + String scheme = endpointUri.getScheme(); + String host = endpointUri.getHost(); + int port = endpointUri.getPort(); - if (endpointUri.getRawAuthority() != null) { - // Exclude network paths e.g. //example.com/mcp - throw new InvalidSseMessageEndpointException( - "messageEndpoint must be a relative path and must not contain an authority", messageEndpoint); + boolean sameScheme = scheme != null && scheme.equalsIgnoreCase(sseUri.getScheme()); + boolean sameHost = host != null && host.equalsIgnoreCase(sseUri.getHost()); + boolean samePort = port == sseUri.getPort(); + + if (!sameScheme || !sameHost || !samePort) { + throw new InvalidSseMessageEndpointException( + "messageEndpoint must be a relative path or a same-origin URI", messageEndpoint); + } } // Exclude path-traversal diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidatorTests.java index f1fc82850..cf2e045a1 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidatorTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/client/transport/DefaultSseMessageEndpointValidatorTests.java @@ -25,7 +25,7 @@ class DefaultSseMessageEndpointValidatorTests { private final DefaultSseMessageEndpointValidator validator = new DefaultSseMessageEndpointValidator(); @ParameterizedTest - @ValueSource(strings = { "/messages", "messages?session=abc", "/" }) + @ValueSource(strings = { "/messages", "messages?session=abc", "/", "https://mcp.example.com/messages" }) void valid(String endpoint) { assertThatCode(() -> validator.validate(SSE_URI, endpoint)).doesNotThrowAnyException(); } @@ -41,20 +41,20 @@ void invalidEmpty(String endpoint) { @ParameterizedTest @ValueSource(strings = { "/foo/../bar", "/foo/./bar", "../bar", "./bar", "/foo/%2E%2E/bar", "/foo/%2e/bar" }) void invalidPathTraversal(String endpoint) { - assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)).hasMessageContaining("path-traversal") + assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)) + .hasMessageContaining("must not contain path-traversal segments") .asInstanceOf(type(InvalidSseMessageEndpointException.class)) .extracting(InvalidSseMessageEndpointException::getMessageEndpoint) .isEqualTo(endpoint); } @ParameterizedTest - @ValueSource(strings = { "https://mcp.example.com/messages", "https://127.0.0.1/messages", - "https://mcp.example.com:8443/messages", "http://localhost:1234/messages", "file:///etc/passwd", - "gopher://mcp.example.com/_test" }) + @ValueSource(strings = { "https://127.0.0.1/messages", "https://mcp.example.com:8443/messages", + "http://localhost:1234/messages", "file:///etc/passwd", "gopher://mcp.example.com/_test" }) void invalidAbsoluteUris(String endpoint) { - // Even an absolute URI on the same origin must be rejected: the contract - // is that the messageEndpoint is a path-only relative reference. - assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)).hasMessageContaining("must be a relative path") + // Absolute URIs must be same-origin. + assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)) + .hasMessageContaining("must be a relative path or a same-origin URI") .asInstanceOf(type(InvalidSseMessageEndpointException.class)) .extracting(InvalidSseMessageEndpointException::getMessageEndpoint) .isEqualTo(endpoint); @@ -62,11 +62,12 @@ void invalidAbsoluteUris(String endpoint) { } @ParameterizedTest - @ValueSource(strings = { "//example/messages", "//user:secret@example/messages" }) + @ValueSource(strings = { "//example/messages", "//user:secret@example/messages", "//mcp.example.com/messages" }) void invalidNetworkReference(String endpoint) { // `//host/...` introduces an authority and is therefore not a pure path. + // It is missing a scheme, so it fails same-origin check. assertThatThrownBy(() -> validator.validate(SSE_URI, endpoint)) - .hasMessageContaining("must not contain an authority") + .hasMessageContaining("must be a relative path or a same-origin URI") .asInstanceOf(type(InvalidSseMessageEndpointException.class)) .extracting(InvalidSseMessageEndpointException::getMessageEndpoint) .isEqualTo(endpoint);