diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index 434c07a1b..5e1a2076d 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -297,6 +297,16 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS, asyncProgressNotificationHandler(progressConsumersFinal)); + // Elicitation Complete Notification + List>> elicitationCompleteConsumersFinal = new ArrayList<>(); + elicitationCompleteConsumersFinal + .add((notification) -> Mono.fromRunnable(() -> logger.debug("Elicitation complete: {}", notification))); + if (!Utils.isEmpty(features.elicitationCompleteConsumers())) { + elicitationCompleteConsumersFinal.addAll(features.elicitationCompleteConsumers()); + } + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE, + asyncElicitationCompleteNotificationHandler(elicitationCompleteConsumersFinal)); + Function> postInitializationHook = init -> { if (init.initializeResult().capabilities().tools() == null || !enableCallToolSchemaCaching) { @@ -1039,6 +1049,20 @@ private NotificationHandler asyncProgressNotificationHandler( }; } + private NotificationHandler asyncElicitationCompleteNotificationHandler( + List>> elicitationCompleteConsumers) { + + return params -> { + McpSchema.ElicitationCompleteNotification notification = transport.unmarshalFrom(params, + new TypeRef() { + }); + + return Flux.fromIterable(elicitationCompleteConsumers) + .flatMap(consumer -> consumer.apply(notification)) + .then(); + }; + } + /** * This method is package-private and used for test only. Should not be called by user * code. diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java index 12f34e60a..fba502f80 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -4,6 +4,15 @@ package io.modelcontextprotocol.client; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.json.McpJsonDefaults; import io.modelcontextprotocol.json.schema.JsonSchemaValidator; @@ -20,15 +29,6 @@ import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; -import java.time.Duration; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; - /** * Factory class for creating Model Context Protocol (MCP) clients. MCP is a protocol that * enables AI models to interact with external tools and resources through a standardized @@ -185,6 +185,8 @@ class SyncSpec { private final List> progressConsumers = new ArrayList<>(); + private final List> elicitationCompleteConsumers = new ArrayList<>(); + private Function samplingHandler; private Function elicitationHandler; @@ -437,6 +439,22 @@ public SyncSpec progressConsumers(List> return this; } + /** + * Adds a consumer to be notified when an elicitation complete notification is + * received from the server. This allows the client to react when a URL-mode + * elicitation flow has been completed. + * @param elicitationCompleteConsumer A consumer that receives elicitation + * complete notifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if elicitationCompleteConsumer is null + */ + public SyncSpec elicitationCompleteConsumer( + Consumer elicitationCompleteConsumer) { + Assert.notNull(elicitationCompleteConsumer, "Elicitation complete consumer must not be null"); + this.elicitationCompleteConsumers.add(elicitationCompleteConsumer); + return this; + } + /** * Add a provider of {@link McpTransportContext}, providing a context before * calling any client operation. This allows to extract thread-locals and hand @@ -488,7 +506,7 @@ public McpSyncClient build() { McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, this.samplingHandler, - this.elicitationHandler, this.enableCallToolSchemaCaching); + this.elicitationHandler, this.elicitationCompleteConsumers, this.enableCallToolSchemaCaching); McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures); @@ -541,6 +559,8 @@ class AsyncSpec { private final List>> progressConsumers = new ArrayList<>(); + private final List>> elicitationCompleteConsumers = new ArrayList<>(); + private Function> samplingHandler; private Function> elicitationHandler; @@ -795,6 +815,23 @@ public AsyncSpec progressConsumers( return this; } + /** + * Adds a consumer to be notified when an elicitation complete notification is + * received from the server. This allows the client to react when a URL-mode + * elicitation flow has been completed. + * @param elicitationCompleteConsumer A function that receives elicitation + * complete notifications and returns a Mono indicating completion. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if elicitationCompleteConsumer is null + */ + public AsyncSpec elicitationCompleteConsumer( + Function> elicitationCompleteConsumer) { + Assert.notNull(elicitationCompleteConsumer, "Elicitation complete consumer must not be null"); + this.elicitationCompleteConsumers.add(elicitationCompleteConsumer); + return this; + } + /** * Sets the JSON schema validator to use for validating tool responses against * output schemas. @@ -833,7 +870,8 @@ public McpAsyncClient build() { new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, - this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching)); + this.samplingHandler, this.elicitationHandler, this.elicitationCompleteConsumers, + this.enableCallToolSchemaCaching)); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index 127d53337..a707011a6 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -62,6 +62,7 @@ class McpClientFeatures { * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param elicitationCompleteConsumers the elicitation complete consumers. * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, @@ -73,6 +74,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List>> progressConsumers, Function> samplingHandler, Function> elicitationHandler, + List>> elicitationCompleteConsumers, boolean enableCallToolSchemaCaching) { /** @@ -86,6 +88,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param elicitationCompleteConsumers the elicitation complete consumers. * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, @@ -98,6 +101,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List>> progressConsumers, Function> samplingHandler, Function> elicitationHandler, + List>> elicitationCompleteConsumers, boolean enableCallToolSchemaCaching) { Assert.notNull(clientInfo, "Client info must not be null"); @@ -117,6 +121,8 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; + this.elicitationCompleteConsumers = elicitationCompleteConsumers != null ? elicitationCompleteConsumers + : List.of(); this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; } @@ -134,7 +140,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c Function> elicitationHandler) { this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, - elicitationHandler, false); + elicitationHandler, List.of(), false); } /** @@ -182,6 +188,13 @@ public static Async fromSync(Sync syncSpec) { .subscribeOn(Schedulers.boundedElastic())); } + List>> elicitationCompleteConsumers = new ArrayList<>(); + for (Consumer consumer : syncSpec + .elicitationCompleteConsumers()) { + elicitationCompleteConsumers.add(n -> Mono.fromRunnable(() -> consumer.accept(n)) + .subscribeOn(Schedulers.boundedElastic())); + } + Function> samplingHandler = r -> Mono .fromCallable(() -> syncSpec.samplingHandler().apply(r)) .subscribeOn(Schedulers.boundedElastic()); @@ -193,7 +206,7 @@ public static Async fromSync(Sync syncSpec) { return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, progressConsumers, samplingHandler, elicitationHandler, - syncSpec.enableCallToolSchemaCaching); + elicitationCompleteConsumers, syncSpec.enableCallToolSchemaCaching); } } @@ -211,6 +224,7 @@ public static Async fromSync(Sync syncSpec) { * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param elicitationCompleteConsumers the elicitation complete consumers. * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, @@ -222,6 +236,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili List> progressConsumers, Function samplingHandler, Function elicitationHandler, + List> elicitationCompleteConsumers, boolean enableCallToolSchemaCaching) { /** @@ -237,6 +252,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. + * @param elicitationCompleteConsumers the elicitation complete consumers. * @param enableCallToolSchemaCaching whether to enable call tool schema caching. */ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities, @@ -248,6 +264,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl List> progressConsumers, Function samplingHandler, Function elicitationHandler, + List> elicitationCompleteConsumers, boolean enableCallToolSchemaCaching) { Assert.notNull(clientInfo, "Client info must not be null"); @@ -267,6 +284,8 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; + this.elicitationCompleteConsumers = elicitationCompleteConsumers != null ? elicitationCompleteConsumers + : List.of(); this.enableCallToolSchemaCaching = enableCallToolSchemaCaching; } @@ -283,7 +302,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl Function elicitationHandler) { this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler, - elicitationHandler, false); + elicitationHandler, List.of(), false); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 40a76045b..b5fc633c0 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -4,17 +4,15 @@ package io.modelcontextprotocol.server; -import io.modelcontextprotocol.common.McpTransportContext; import java.util.ArrayList; import java.util.Collections; +import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.json.TypeRef; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpLoggableSession; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; -import io.modelcontextprotocol.spec.McpSession; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -152,10 +150,31 @@ public Mono createElicitation(McpSchema.ElicitRequest el if (this.clientCapabilities.elicitation() == null) { return Mono.error(new IllegalStateException("Client must be configured with elicitation capabilities")); } + if ("url".equals(elicitRequest.mode()) && this.clientCapabilities.elicitation().url() == null) { + return Mono.error(new IllegalStateException("Client must be configured with URL elicitation capabilities")); + } return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest, ELICITATION_RESULT_TYPE_REF); } + /** + * Sends a notification to the client that an out-of-band URL elicitation interaction + * has completed. + * @param elicitationId The ID of the elicitation that completed + * @return A Mono that completes when the notification has been sent + */ + public Mono sendElicitationComplete(String elicitationId) { + if (this.clientCapabilities == null) { + return Mono + .error(new IllegalStateException("Client must be initialized. Call the initialize method first!")); + } + if (this.clientCapabilities.elicitation() == null || this.clientCapabilities.elicitation().url() == null) { + return Mono.error(new IllegalStateException("Client must be configured with URL elicitation capabilities")); + } + return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE, + new McpSchema.ElicitationCompleteNotification(elicitationId, null)); + } + /** * Retrieves the list of all roots provided by the client. * @return A Mono that emits the list of roots result. diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 0b9115b79..5d66fb6f6 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -100,6 +100,15 @@ public McpSchema.ElicitResult createElicitation(McpSchema.ElicitRequest elicitRe return this.exchange.createElicitation(elicitRequest).block(); } + /** + * Sends a notification to the client that an out-of-band URL elicitation interaction + * has completed. + * @param elicitationId The ID of the elicitation that completed + */ + public void sendElicitationComplete(String elicitationId) { + this.exchange.sendElicitationComplete(elicitationId).block(); + } + /** * Retrieves the list of all roots provided by the client. * @return The list of roots result. diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index a3ed2dbde..b3e117b27 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -10,6 +10,9 @@ import java.util.List; import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; @@ -17,11 +20,10 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; + import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.util.Assert; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Based on the JSON-RPC 2.0 @@ -106,6 +108,8 @@ private McpSchema() { // Elicitation Methods public static final String METHOD_ELICITATION_CREATE = "elicitation/create"; + public static final String METHOD_NOTIFICATION_ELICITATION_COMPLETE = "notifications/elicitation/complete"; + // --------------------------- // JSON-RPC Error Codes // --------------------------- @@ -144,6 +148,12 @@ public static final class ErrorCodes { */ public static final int RESOURCE_NOT_FOUND = -32002; + /** + * URL elicitation required. The server requires the client to complete a URL mode + * elicitation before the request can proceed. + */ + public static final int URL_ELICITATION_REQUIRED = -32042; + } /** @@ -2051,21 +2061,32 @@ public CreateMessageResult build() { * A request from the server to elicit additional information from the user via the * client. * + * @param mode The elicitation mode: "form", "url", or null (defaults to form) * @param message The message to present to the user * @param requestedSchema A restricted subset of JSON Schema. Only top-level - * properties are allowed, without nesting + * properties are allowed, without nesting. Only valid for form mode. + * @param url The URL for the user to visit (URL mode only) + * @param elicitationId A unique identifier for this elicitation (URL mode only) * @param meta See specification for notes on _meta usage */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record ElicitRequest( // @formatter:off + @JsonProperty("mode") String mode, @JsonProperty("message") String message, @JsonProperty("requestedSchema") Map requestedSchema, + @JsonProperty("url") String url, + @JsonProperty("elicitationId") String elicitationId, @JsonProperty("_meta") Map meta) implements Request { // @formatter:on - // backwards compatibility constructor + // backwards compatibility constructor (form mode, no meta) public ElicitRequest(String message, Map requestedSchema) { - this(message, requestedSchema, null); + this(null, message, requestedSchema, null, null, null); + } + + // backwards compatibility constructor (form mode, with meta) + public ElicitRequest(String message, Map requestedSchema, Map meta) { + this(null, message, requestedSchema, null, null, meta); } public static Builder builder() { @@ -2074,12 +2095,23 @@ public static Builder builder() { public static class Builder { + private String mode; + private String message; private Map requestedSchema; + private String url; + + private String elicitationId; + private Map meta; + public Builder mode(String mode) { + this.mode = mode; + return this; + } + public Builder message(String message) { this.message = message; return this; @@ -2090,6 +2122,16 @@ public Builder requestedSchema(Map requestedSchema) { return this; } + public Builder url(String url) { + this.url = url; + return this; + } + + public Builder elicitationId(String elicitationId) { + this.elicitationId = elicitationId; + return this; + } + public Builder meta(Map meta) { this.meta = meta; return this; @@ -2104,7 +2146,18 @@ public Builder progressToken(Object progressToken) { } public ElicitRequest build() { - return new ElicitRequest(message, requestedSchema, meta); + if ("url".equals(this.mode)) { + if (this.url == null) { + throw new IllegalArgumentException("url must be non-null when mode is 'url'"); + } + if (this.elicitationId == null) { + throw new IllegalArgumentException("elicitationId must be non-null when mode is 'url'"); + } + if (this.requestedSchema != null) { + throw new IllegalArgumentException("requestedSchema must not be set when mode is 'url'"); + } + } + return new ElicitRequest(mode, message, requestedSchema, url, elicitationId, meta); } } @@ -2263,6 +2316,20 @@ public ResourcesUpdatedNotification(String uri) { } } + /** + * A notification from the server to the client indicating that an out-of-band URL + * elicitation interaction has completed. + * + * @param elicitationId The ID of the elicitation that completed + * @param meta See specification for notes on _meta usage + */ + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record ElicitationCompleteNotification(// @formatter:off + @JsonProperty("elicitationId") String elicitationId, + @JsonProperty("_meta") Map meta) implements Notification { // @formatter:on + } + /** * The Model Context Protocol (MCP) provides a standardized way for servers to send * structured log messages to clients. Clients can control logging verbosity by diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java index e6161a59f..77cc4d5b5 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java @@ -4,31 +4,31 @@ package io.modelcontextprotocol.server; -import io.modelcontextprotocol.common.McpTransportContext; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.json.TypeRef; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import org.mockito.Mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import org.mockito.MockitoAnnotations; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; /** * Tests for {@link McpAsyncServerExchange}. @@ -461,6 +461,119 @@ void testCreateElicitationWithSessionError() { }); } + // --------------------------------------- + // URL Elicitation Capability Gate Tests + // --------------------------------------- + + @Test + void testCreateElicitationUrlModeWithoutUrlCapability() { + // Given - Create exchange with Elicitation() (no url capability) + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithElicitation, clientInfo, McpTransportContext.EMPTY); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .mode("url") + .message("Authorize GitHub") + .url("https://server.com/connect?id=abc-123") + .elicitationId("abc-123") + .build(); + + StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessage("Client must be configured with URL elicitation capabilities"); + }); + + // Verify that sendRequest was never called + verify(mockSession, never()).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), any(), any(TypeRef.class)); + } + + @Test + void testCreateElicitationFormModeWithEmptyElicitation() { + // Given - Create exchange with Elicitation() (empty = form only) + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithElicitation, clientInfo, McpTransportContext.EMPTY); + + McpSchema.ElicitRequest elicitRequest = McpSchema.ElicitRequest.builder() + .message("Please provide your name") + .build(); + + McpSchema.ElicitResult expectedResult = McpSchema.ElicitResult.builder() + .message(McpSchema.ElicitResult.Action.ACCEPT) + .build(); + + when(mockSession.sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), any(TypeRef.class))) + .thenReturn(Mono.just(expectedResult)); + + StepVerifier.create(exchangeWithElicitation.createElicitation(elicitRequest)).assertNext(result -> { + assertThat(result).isEqualTo(expectedResult); + }).verifyComplete(); + + // Verify that sendRequest was called + verify(mockSession, times(1)).sendRequest(eq(McpSchema.METHOD_ELICITATION_CREATE), eq(elicitRequest), + any(TypeRef.class)); + } + + @Test + void testSendElicitationCompleteWithoutUrlCapability() { + // Given - Create exchange with Elicitation() (no url capability) + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange exchangeWithElicitation = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithElicitation, clientInfo, McpTransportContext.EMPTY); + + StepVerifier.create(exchangeWithElicitation.sendElicitationComplete("abc")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessage("Client must be configured with URL elicitation capabilities"); + }); + + // Verify that sendNotification was never called + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE), any()); + } + + @Test + void testSendElicitationCompleteWithUrlCapability() { + // Given - Create exchange with Elicitation(form, url) + McpSchema.ClientCapabilities capabilitiesWithUrl = McpSchema.ClientCapabilities.builder() + .elicitation(true, true) + .build(); + + McpAsyncServerExchange exchangeWithUrl = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithUrl, clientInfo, McpTransportContext.EMPTY); + + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE), any())) + .thenReturn(Mono.empty()); + + StepVerifier.create(exchangeWithUrl.sendElicitationComplete("abc")).verifyComplete(); + + // Verify that sendNotification was called + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE), any()); + } + + @Test + void testSendElicitationCompleteWithNullCapabilities() { + // Given - Create exchange with null capabilities + McpAsyncServerExchange exchangeWithNullCapabilities = new McpAsyncServerExchange("testSessionId", mockSession, + null, clientInfo, McpTransportContext.EMPTY); + + StepVerifier.create(exchangeWithNullCapabilities.sendElicitationComplete("abc")).verifyErrorSatisfies(error -> { + assertThat(error).isInstanceOf(IllegalStateException.class) + .hasMessage("Client must be initialized. Call the initialize method first!"); + }); + + // Verify that sendNotification was never called + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE), any()); + } + // --------------------------------------- // Create Message Tests // --------------------------------------- diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java index fba733c9a..9e41d872b 100644 --- a/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java @@ -9,25 +9,25 @@ import java.util.List; import java.util.Map; -import io.modelcontextprotocol.common.McpTransportContext; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.json.TypeRef; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import reactor.core.publisher.Mono; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import org.mockito.Mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import org.mockito.MockitoAnnotations; + +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import reactor.core.publisher.Mono; /** * Tests for {@link McpSyncServerExchange}. @@ -460,6 +460,49 @@ void testCreateElicitationWithSessionError() { .hasMessage("Session communication error"); } + // --------------------------------------- + // URL Elicitation Capability Gate Tests + // --------------------------------------- + + @Test + void testSendElicitationCompleteWithoutUrlCapability() { + // Given - Create exchange with Elicitation() (no url capability) + McpSchema.ClientCapabilities capabilitiesWithElicitation = McpSchema.ClientCapabilities.builder() + .elicitation() + .build(); + + McpAsyncServerExchange asyncExchangeWithElicitation = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithElicitation, clientInfo, McpTransportContext.EMPTY); + McpSyncServerExchange exchangeWithElicitation = new McpSyncServerExchange(asyncExchangeWithElicitation); + + assertThatThrownBy(() -> exchangeWithElicitation.sendElicitationComplete("abc")) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Client must be configured with URL elicitation capabilities"); + + // Verify that sendNotification was never called + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE), any()); + } + + @Test + void testSendElicitationCompleteWithUrlCapability() { + // Given - Create exchange with Elicitation(form, url) + McpSchema.ClientCapabilities capabilitiesWithUrl = McpSchema.ClientCapabilities.builder() + .elicitation(true, true) + .build(); + + McpAsyncServerExchange asyncExchangeWithUrl = new McpAsyncServerExchange("testSessionId", mockSession, + capabilitiesWithUrl, clientInfo, McpTransportContext.EMPTY); + McpSyncServerExchange exchangeWithUrl = new McpSyncServerExchange(asyncExchangeWithUrl); + + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE), any())) + .thenReturn(Mono.empty()); + + exchangeWithUrl.sendElicitationComplete("abc"); + + // Verify that sendNotification was called + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE), any()); + } + // --------------------------------------- // Create Message Tests // --------------------------------------- diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java index 47a229afd..788975193 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/client/McpAsyncClientResponseHandlerTests.java @@ -8,6 +8,9 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import io.modelcontextprotocol.json.TypeRef; @@ -561,4 +564,69 @@ void testPingMessageRequestHandling() { asyncMcpClient.closeGracefully(); } + @Test + void testElicitationCompleteNotificationHandling() throws InterruptedException { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Track received notifications + AtomicReference received = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + // Create a consumer for elicitation complete notifications + Function> elicitationCompleteConsumer = notification -> Mono + .fromRunnable(() -> { + received.set(notification); + latch.countDown(); + }); + + // Create client with elicitation complete consumer + McpAsyncClient asyncMcpClient = McpClient.async(transport) + .elicitationCompleteConsumer(elicitationCompleteConsumer) + .build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Simulate server sending notifications/elicitation/complete notification + McpSchema.ElicitationCompleteNotification completeNotification = new McpSchema.ElicitationCompleteNotification( + "test-elicitation-id", null); + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE, completeNotification); + transport.simulateIncomingMessage(notification); + + // Wait for the consumer to be invoked + assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); + + // Verify the consumer received the correct elicitationId + assertThat(received.get()).isNotNull(); + assertThat(received.get().elicitationId()).isEqualTo("test-elicitation-id"); + + asyncMcpClient.closeGracefully(); + } + + @Test + void testElicitationCompleteNotificationWithoutConsumer() throws InterruptedException { + MockMcpClientTransport transport = initializationEnabledTransport(); + + // Create client WITHOUT an elicitation complete consumer + McpAsyncClient asyncMcpClient = McpClient.async(transport).build(); + + assertThat(asyncMcpClient.initialize().block()).isNotNull(); + + // Simulate server sending notifications/elicitation/complete notification + // This should be silently ignored without error + McpSchema.ElicitationCompleteNotification completeNotification = new McpSchema.ElicitationCompleteNotification( + "test-elicitation-id", null); + McpSchema.JSONRPCNotification notification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE, completeNotification); + transport.simulateIncomingMessage(notification); + + // Give the async pipeline time to process + Thread.sleep(500); + + // Verify no error occurred — client is still functional + assertThat(asyncMcpClient.isInitialized()).isTrue(); + + asyncMcpClient.closeGracefully(); + } + } diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index 09529f2e0..c8b9af056 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -4,12 +4,6 @@ package io.modelcontextprotocol.spec; -import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - import java.io.IOException; import java.util.Arrays; import java.util.Collections; @@ -17,11 +11,16 @@ import java.util.List; import java.util.Map; -import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Test; import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.spec.McpSchema.TextResourceContents; +import static io.modelcontextprotocol.util.McpJsonMapperUtils.JSON_MAPPER; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; import net.javacrumbs.jsonunit.core.Option; /** @@ -1725,6 +1724,204 @@ void testElicitationCapabilityBuilderFormOnly() throws Exception { assertThat(json).doesNotContain("\"url\""); } + // URL Elicitation Tests + + @Test + void testCreateUrlModeElicitRequest() throws Exception { + McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() + .mode("url") + .message("Please authorize with GitHub") + .elicitationId("elicit-abc-123") + .url("https://example.com/authorize") + .build(); + + String value = JSON_MAPPER.writeValueAsString(request); + + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .isObject() + .isEqualTo( + json(""" + {"mode":"url","message":"Please authorize with GitHub","elicitationId":"elicit-abc-123","url":"https://example.com/authorize"}""")); + } + + @Test + void testUrlModeElicitRequestRoundTrip() throws Exception { + McpSchema.ElicitRequest original = McpSchema.ElicitRequest.builder() + .mode("url") + .message("Please authorize") + .elicitationId("elicit-abc-123") + .url("https://example.com/authorize") + .build(); + + String json = JSON_MAPPER.writeValueAsString(original); + McpSchema.ElicitRequest deserialized = JSON_MAPPER.readValue(json, McpSchema.ElicitRequest.class); + + assertThat(deserialized.mode()).isEqualTo("url"); + assertThat(deserialized.message()).isEqualTo("Please authorize"); + assertThat(deserialized.elicitationId()).isEqualTo("elicit-abc-123"); + assertThat(deserialized.url()).isEqualTo("https://example.com/authorize"); + assertThat(deserialized.requestedSchema()).isNull(); + } + + @Test + void testFormModeElicitRequestBackwardCompatConstructor() throws Exception { + // Existing 2-arg constructor should still work and produce JSON without mode + // field + McpSchema.ElicitRequest request = new McpSchema.ElicitRequest("Enter your name", + Map.of("type", "object", "properties", Map.of("name", Map.of("type", "string")))); + + String value = JSON_MAPPER.writeValueAsString(request); + + // Should NOT contain "mode" field (backward compat) + assertThatJson(value).isObject().doesNotContainKey("mode"); + assertThatJson(value).isObject().containsKey("message"); + assertThatJson(value).isObject().containsKey("requestedSchema"); + } + + @Test + void testFormModeElicitRequestRoundTrip() throws Exception { + McpSchema.ElicitRequest original = new McpSchema.ElicitRequest("Enter your name", + Map.of("type", "object", "properties", Map.of("name", Map.of("type", "string")))); + + String json = JSON_MAPPER.writeValueAsString(original); + McpSchema.ElicitRequest deserialized = JSON_MAPPER.readValue(json, McpSchema.ElicitRequest.class); + + assertThat(deserialized.mode()).isNull(); + assertThat(deserialized.message()).isEqualTo("Enter your name"); + assertThat(deserialized.requestedSchema()).isNotNull(); + assertThat(deserialized.url()).isNull(); + assertThat(deserialized.elicitationId()).isNull(); + } + + @Test + void testExplicitFormModeElicitRequest() throws Exception { + McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() + .mode("form") + .message("Enter your name") + .requestedSchema(Map.of("type", "object", "properties", Map.of("name", Map.of("type", "string")))) + .build(); + + String value = JSON_MAPPER.writeValueAsString(request); + assertThatJson(value).isObject().containsEntry("mode", "form"); + + McpSchema.ElicitRequest deserialized = JSON_MAPPER.readValue(value, McpSchema.ElicitRequest.class); + assertThat(deserialized.mode()).isEqualTo("form"); + } + + @Test + void testElicitationCompleteNotificationSerialization() throws Exception { + McpSchema.ElicitationCompleteNotification notification = new McpSchema.ElicitationCompleteNotification( + "elicit-abc-123", null); + + String value = JSON_MAPPER.writeValueAsString(notification); + + assertThatJson(value).isObject().isEqualTo(json(""" + {"elicitationId":"elicit-abc-123"}""")); + } + + @Test + void testElicitationCompleteNotificationRoundTrip() throws Exception { + McpSchema.ElicitationCompleteNotification original = new McpSchema.ElicitationCompleteNotification( + "elicit-abc-123", null); + + String json = JSON_MAPPER.writeValueAsString(original); + McpSchema.ElicitationCompleteNotification deserialized = JSON_MAPPER.readValue(json, + McpSchema.ElicitationCompleteNotification.class); + + assertThat(deserialized.elicitationId()).isEqualTo("elicit-abc-123"); + } + + @Test + void testUrlModeElicitRequestBuilderRequiresUrl() { + assertThatThrownBy(() -> McpSchema.ElicitRequest.builder() + .mode("url") + .message("Authorize") + .elicitationId("abc") + // no url + .build()).isInstanceOf(IllegalArgumentException.class).hasMessageContaining("url must be non-null"); + } + + @Test + void testUrlModeElicitRequestBuilderRequiresElicitationId() { + assertThatThrownBy(() -> McpSchema.ElicitRequest.builder() + .mode("url") + .message("Authorize") + .url("https://example.com") + // no elicitationId + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("elicitationId must be non-null"); + } + + @Test + void testUrlModeElicitRequestBuilderRejectsRequestedSchema() { + assertThatThrownBy(() -> McpSchema.ElicitRequest.builder() + .mode("url") + .message("Authorize") + .url("https://example.com") + .elicitationId("abc") + .requestedSchema(Map.of("type", "object")) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("requestedSchema must not be set"); + } + + @Test + void testUrlModeElicitRequestBuilderSuccess() { + McpSchema.ElicitRequest request = McpSchema.ElicitRequest.builder() + .mode("url") + .message("Authorize") + .url("https://example.com") + .elicitationId("abc") + .build(); + + assertThat(request.mode()).isEqualTo("url"); + assertThat(request.url()).isEqualTo("https://example.com"); + assertThat(request.elicitationId()).isEqualTo("abc"); + } + + @Test + void testUrlElicitationRequiredErrorCode() { + assertThat(McpSchema.ErrorCodes.URL_ELICITATION_REQUIRED).isEqualTo(-32042); + } + + @Test + void testElicitationCompleteMethodConstant() { + assertThat(McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE).isEqualTo("notifications/elicitation/complete"); + } + + @Test + void testElicitRequestToleratesUnknownFields() throws Exception { + String json = """ + {"message":"Enter your name","requestedSchema":{"type":"object","properties":{}},"futureField":"ignored"}"""; + + McpSchema.ElicitRequest request = JSON_MAPPER.readValue(json, McpSchema.ElicitRequest.class); + + assertThat(request.message()).isEqualTo("Enter your name"); + assertThat(request.requestedSchema()).isNotNull(); + } + + @Test + void testElicitationCompleteNotificationDeserializesWithoutMeta() throws Exception { + String json = """ + {"elicitationId":"abc-123"}"""; + + McpSchema.ElicitationCompleteNotification notification = JSON_MAPPER.readValue(json, + McpSchema.ElicitationCompleteNotification.class); + + assertThat(notification.elicitationId()).isEqualTo("abc-123"); + assertThat(notification.meta()).isNull(); + } + + @Test + void testElicitationCompleteNotificationToleratesUnknownFields() throws Exception { + String json = """ + {"elicitationId":"abc-123","futureField":42}"""; + + McpSchema.ElicitationCompleteNotification notification = JSON_MAPPER.readValue(json, + McpSchema.ElicitationCompleteNotification.class); + + assertThat(notification.elicitationId()).isEqualTo("abc-123"); + } + // Progress Notification Tests @Test