Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,16 @@ public class McpAsyncClient {
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS,
asyncProgressNotificationHandler(progressConsumersFinal));

// Elicitation Complete Notification
List<Function<McpSchema.ElicitationCompleteNotification, Mono<Void>>> elicitationCompleteConsumersFinal = new ArrayList<>();
elicitationCompleteConsumersFinal
.add((notification) -> Mono.fromRunnable(() -> logger.debug("Elicitation complete: {}", notification)));
if (!Utils.isEmpty(features.elicitationCompleteConsumers())) {
elicitationCompleteConsumersFinal.addAll(features.elicitationCompleteConsumers());
}
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE,
asyncElicitationCompleteNotificationHandler(elicitationCompleteConsumersFinal));

Function<Initialization, Mono<Void>> postInitializationHook = init -> {

if (init.initializeResult().capabilities().tools() == null || !enableCallToolSchemaCaching) {
Expand Down Expand Up @@ -1039,6 +1049,20 @@ private NotificationHandler asyncProgressNotificationHandler(
};
}

private NotificationHandler asyncElicitationCompleteNotificationHandler(
List<Function<McpSchema.ElicitationCompleteNotification, Mono<Void>>> elicitationCompleteConsumers) {

return params -> {
McpSchema.ElicitationCompleteNotification notification = transport.unmarshalFrom(params,
new TypeRef<McpSchema.ElicitationCompleteNotification>() {
});

return Flux.fromIterable(elicitationCompleteConsumers)
.flatMap(consumer -> consumer.apply(notification))
.then();
};
}

/**
* This method is package-private and used for test only. Should not be called by user
* code.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@

package io.modelcontextprotocol.client;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.json.McpJsonDefaults;
import io.modelcontextprotocol.json.schema.JsonSchemaValidator;
Expand All @@ -20,15 +29,6 @@
import io.modelcontextprotocol.util.Assert;
import reactor.core.publisher.Mono;

import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

/**
* Factory class for creating Model Context Protocol (MCP) clients. MCP is a protocol that
* enables AI models to interact with external tools and resources through a standardized
Expand Down Expand Up @@ -185,6 +185,8 @@ class SyncSpec {

private final List<Consumer<McpSchema.ProgressNotification>> progressConsumers = new ArrayList<>();

private final List<Consumer<McpSchema.ElicitationCompleteNotification>> elicitationCompleteConsumers = new ArrayList<>();

private Function<CreateMessageRequest, CreateMessageResult> samplingHandler;

private Function<ElicitRequest, ElicitResult> elicitationHandler;
Expand Down Expand Up @@ -437,6 +439,22 @@ public SyncSpec progressConsumers(List<Consumer<McpSchema.ProgressNotification>>
return this;
}

/**
* Adds a consumer to be notified when an elicitation complete notification is
* received from the server. This allows the client to react when a URL-mode
* elicitation flow has been completed.
* @param elicitationCompleteConsumer A consumer that receives elicitation
* complete notifications. Must not be null.
* @return This builder instance for method chaining
* @throws IllegalArgumentException if elicitationCompleteConsumer is null
*/
public SyncSpec elicitationCompleteConsumer(
Consumer<McpSchema.ElicitationCompleteNotification> elicitationCompleteConsumer) {
Assert.notNull(elicitationCompleteConsumer, "Elicitation complete consumer must not be null");
this.elicitationCompleteConsumers.add(elicitationCompleteConsumer);
return this;
}

/**
* Add a provider of {@link McpTransportContext}, providing a context before
* calling any client operation. This allows to extract thread-locals and hand
Expand Down Expand Up @@ -488,7 +506,7 @@ public McpSyncClient build() {
McpClientFeatures.Sync syncFeatures = new McpClientFeatures.Sync(this.clientInfo, this.capabilities,
this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, this.samplingHandler,
this.elicitationHandler, this.enableCallToolSchemaCaching);
this.elicitationHandler, this.elicitationCompleteConsumers, this.enableCallToolSchemaCaching);

McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures);

Expand Down Expand Up @@ -541,6 +559,8 @@ class AsyncSpec {

private final List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers = new ArrayList<>();

private final List<Function<McpSchema.ElicitationCompleteNotification, Mono<Void>>> elicitationCompleteConsumers = new ArrayList<>();

private Function<CreateMessageRequest, Mono<CreateMessageResult>> samplingHandler;

private Function<ElicitRequest, Mono<ElicitResult>> elicitationHandler;
Expand Down Expand Up @@ -795,6 +815,23 @@ public AsyncSpec progressConsumers(
return this;
}

/**
* Adds a consumer to be notified when an elicitation complete notification is
* received from the server. This allows the client to react when a URL-mode
* elicitation flow has been completed.
* @param elicitationCompleteConsumer A function that receives elicitation
* complete notifications and returns a Mono indicating completion. Must not be
* null.
* @return This builder instance for method chaining
* @throws IllegalArgumentException if elicitationCompleteConsumer is null
*/
public AsyncSpec elicitationCompleteConsumer(
Function<McpSchema.ElicitationCompleteNotification, Mono<Void>> elicitationCompleteConsumer) {
Assert.notNull(elicitationCompleteConsumer, "Elicitation complete consumer must not be null");
this.elicitationCompleteConsumers.add(elicitationCompleteConsumer);
return this;
}

/**
* Sets the JSON schema validator to use for validating tool responses against
* output schemas.
Expand Down Expand Up @@ -833,7 +870,8 @@ public McpAsyncClient build() {
new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots,
this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers,
this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers,
this.samplingHandler, this.elicitationHandler, this.enableCallToolSchemaCaching));
this.samplingHandler, this.elicitationHandler, this.elicitationCompleteConsumers,
this.enableCallToolSchemaCaching));
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class McpClientFeatures {
* @param progressConsumers the progress consumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
* @param elicitationCompleteConsumers the elicitation complete consumers.
* @param enableCallToolSchemaCaching whether to enable call tool schema caching.
*/
record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
Expand All @@ -73,6 +74,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers,
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler,
Function<McpSchema.ElicitRequest, Mono<McpSchema.ElicitResult>> elicitationHandler,
List<Function<McpSchema.ElicitationCompleteNotification, Mono<Void>>> elicitationCompleteConsumers,
boolean enableCallToolSchemaCaching) {

/**
Expand All @@ -86,6 +88,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
* @param progressConsumers the progress consumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
* @param elicitationCompleteConsumers the elicitation complete consumers.
* @param enableCallToolSchemaCaching whether to enable call tool schema caching.
*/
public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
Expand All @@ -98,6 +101,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
List<Function<McpSchema.ProgressNotification, Mono<Void>>> progressConsumers,
Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler,
Function<McpSchema.ElicitRequest, Mono<McpSchema.ElicitResult>> elicitationHandler,
List<Function<McpSchema.ElicitationCompleteNotification, Mono<Void>>> elicitationCompleteConsumers,
boolean enableCallToolSchemaCaching) {

Assert.notNull(clientInfo, "Client info must not be null");
Expand All @@ -117,6 +121,8 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
this.progressConsumers = progressConsumers != null ? progressConsumers : List.of();
this.samplingHandler = samplingHandler;
this.elicitationHandler = elicitationHandler;
this.elicitationCompleteConsumers = elicitationCompleteConsumers != null ? elicitationCompleteConsumers
: List.of();
this.enableCallToolSchemaCaching = enableCallToolSchemaCaching;
}

Expand All @@ -134,7 +140,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c
Function<McpSchema.ElicitRequest, Mono<McpSchema.ElicitResult>> elicitationHandler) {
this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers,
resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler,
elicitationHandler, false);
elicitationHandler, List.of(), false);
}

/**
Expand Down Expand Up @@ -182,6 +188,13 @@ public static Async fromSync(Sync syncSpec) {
.subscribeOn(Schedulers.boundedElastic()));
}

List<Function<McpSchema.ElicitationCompleteNotification, Mono<Void>>> elicitationCompleteConsumers = new ArrayList<>();
for (Consumer<McpSchema.ElicitationCompleteNotification> consumer : syncSpec
.elicitationCompleteConsumers()) {
elicitationCompleteConsumers.add(n -> Mono.<Void>fromRunnable(() -> consumer.accept(n))
.subscribeOn(Schedulers.boundedElastic()));
}

Function<McpSchema.CreateMessageRequest, Mono<McpSchema.CreateMessageResult>> samplingHandler = r -> Mono
.fromCallable(() -> syncSpec.samplingHandler().apply(r))
.subscribeOn(Schedulers.boundedElastic());
Expand All @@ -193,7 +206,7 @@ public static Async fromSync(Sync syncSpec) {
return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(),
toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers,
loggingConsumers, progressConsumers, samplingHandler, elicitationHandler,
syncSpec.enableCallToolSchemaCaching);
elicitationCompleteConsumers, syncSpec.enableCallToolSchemaCaching);
}
}

Expand All @@ -211,6 +224,7 @@ public static Async fromSync(Sync syncSpec) {
* @param progressConsumers the progress consumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
* @param elicitationCompleteConsumers the elicitation complete consumers.
* @param enableCallToolSchemaCaching whether to enable call tool schema caching.
*/
public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
Expand All @@ -222,6 +236,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili
List<Consumer<McpSchema.ProgressNotification>> progressConsumers,
Function<McpSchema.CreateMessageRequest, McpSchema.CreateMessageResult> samplingHandler,
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler,
List<Consumer<McpSchema.ElicitationCompleteNotification>> elicitationCompleteConsumers,
boolean enableCallToolSchemaCaching) {

/**
Expand All @@ -237,6 +252,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili
* @param progressConsumers the progress consumers.
* @param samplingHandler the sampling handler.
* @param elicitationHandler the elicitation handler.
* @param elicitationCompleteConsumers the elicitation complete consumers.
* @param enableCallToolSchemaCaching whether to enable call tool schema caching.
*/
public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities clientCapabilities,
Expand All @@ -248,6 +264,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
List<Consumer<McpSchema.ProgressNotification>> progressConsumers,
Function<McpSchema.CreateMessageRequest, McpSchema.CreateMessageResult> samplingHandler,
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler,
List<Consumer<McpSchema.ElicitationCompleteNotification>> elicitationCompleteConsumers,
boolean enableCallToolSchemaCaching) {

Assert.notNull(clientInfo, "Client info must not be null");
Expand All @@ -267,6 +284,8 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
this.progressConsumers = progressConsumers != null ? progressConsumers : List.of();
this.samplingHandler = samplingHandler;
this.elicitationHandler = elicitationHandler;
this.elicitationCompleteConsumers = elicitationCompleteConsumers != null ? elicitationCompleteConsumers
: List.of();
this.enableCallToolSchemaCaching = enableCallToolSchemaCaching;
}

Expand All @@ -283,7 +302,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl
Function<McpSchema.ElicitRequest, McpSchema.ElicitResult> elicitationHandler) {
this(clientInfo, clientCapabilities, roots, toolsChangeConsumers, resourcesChangeConsumers,
resourcesUpdateConsumers, promptsChangeConsumers, loggingConsumers, List.of(), samplingHandler,
elicitationHandler, false);
elicitationHandler, List.of(), false);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,15 @@

package io.modelcontextprotocol.server;

import io.modelcontextprotocol.common.McpTransportContext;
import java.util.ArrayList;
import java.util.Collections;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpLoggableSession;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification;
import io.modelcontextprotocol.spec.McpSession;
import io.modelcontextprotocol.util.Assert;
import reactor.core.publisher.Mono;

Expand Down Expand Up @@ -152,10 +150,31 @@ public Mono<McpSchema.ElicitResult> createElicitation(McpSchema.ElicitRequest el
if (this.clientCapabilities.elicitation() == null) {
return Mono.error(new IllegalStateException("Client must be configured with elicitation capabilities"));
}
if ("url".equals(elicitRequest.mode()) && this.clientCapabilities.elicitation().url() == null) {
return Mono.error(new IllegalStateException("Client must be configured with URL elicitation capabilities"));
}
return this.session.sendRequest(McpSchema.METHOD_ELICITATION_CREATE, elicitRequest,
ELICITATION_RESULT_TYPE_REF);
}

/**
* Sends a notification to the client that an out-of-band URL elicitation interaction
* has completed.
* @param elicitationId The ID of the elicitation that completed
* @return A Mono that completes when the notification has been sent
*/
public Mono<Void> sendElicitationComplete(String elicitationId) {
if (this.clientCapabilities == null) {
return Mono
.error(new IllegalStateException("Client must be initialized. Call the initialize method first!"));
}
if (this.clientCapabilities.elicitation() == null || this.clientCapabilities.elicitation().url() == null) {
return Mono.error(new IllegalStateException("Client must be configured with URL elicitation capabilities"));
}
return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_ELICITATION_COMPLETE,
new McpSchema.ElicitationCompleteNotification(elicitationId, null));
}

/**
* Retrieves the list of all roots provided by the client.
* @return A Mono that emits the list of roots result.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ public McpSchema.ElicitResult createElicitation(McpSchema.ElicitRequest elicitRe
return this.exchange.createElicitation(elicitRequest).block();
}

/**
* Sends a notification to the client that an out-of-band URL elicitation interaction
* has completed.
* @param elicitationId The ID of the elicitation that completed
*/
public void sendElicitationComplete(String elicitationId) {
this.exchange.sendElicitationComplete(elicitationId).block();
}

/**
* Retrieves the list of all roots provided by the client.
* @return The list of roots result.
Expand Down
Loading