diff --git a/dotnet/src/Client.cs b/dotnet/src/Client.cs index 692d511f7..e785d6dcd 100644 --- a/dotnet/src/Client.cs +++ b/dotnet/src/Client.cs @@ -70,6 +70,7 @@ public sealed partial class CopilotClient : IDisposable, IAsyncDisposable private bool _disposed; private readonly int? _optionsPort; private readonly string? _optionsHost; + private readonly string? _effectiveConnectionToken; private int? _actualPort; private int? _negotiatedProtocolVersion; private List? _modelsCache; @@ -140,6 +141,22 @@ public CopilotClient(CopilotClientOptions? options = null) throw new ArgumentException("GitHubToken and UseLoggedInUser cannot be used with CliUrl (external server manages its own auth)"); } + if (_options.TcpConnectionToken is not null) + { + if (_options.TcpConnectionToken.Length == 0) + { + throw new ArgumentException("TcpConnectionToken must be a non-empty string"); + } + if (_options.UseStdio && string.IsNullOrEmpty(_options.CliUrl)) + { + throw new ArgumentException("TcpConnectionToken cannot be used with UseStdio = true"); + } + } + + var sdkSpawnsCli = !_options.UseStdio && string.IsNullOrEmpty(_options.CliUrl); + _effectiveConnectionToken = _options.TcpConnectionToken + ?? (sdkSpawnsCli ? Guid.NewGuid().ToString() : null); + _logger = _options.Logger ?? NullLogger.Instance; _onListModels = _options.OnListModels; @@ -216,7 +233,7 @@ async Task StartCoreAsync(CancellationToken ct) else { // Child process (stdio or TCP) - var (cliProcess, portOrNull, stderrBuffer) = await StartCliServerAsync(_options, _logger, ct); + var (cliProcess, portOrNull, stderrBuffer) = await StartCliServerAsync(_options, _effectiveConnectionToken, _logger, ct); _actualPort = portOrNull; result = ConnectToServerAsync(cliProcess, portOrNull is null ? null : "localhost", portOrNull, stderrBuffer, ct); } @@ -1122,10 +1139,23 @@ private void ConfigureSessionFsHandlers(CopilotSession session, Func( - connection.Rpc, "ping", [new PingRequest()], connection.StderrBuffer, cancellationToken); + int? serverVersion; + try + { + var connectResponse = await InvokeRpcAsync( + connection.Rpc, "connect", [new ConnectRequest { Token = _effectiveConnectionToken }], connection.StderrBuffer, cancellationToken); + serverVersion = (int)connectResponse.ProtocolVersion; + } + catch (RemoteRpcException ex) when (ex.ErrorCode == RemoteRpcException.MethodNotFoundErrorCode) + { + // Legacy server without `connect`; fall back to `ping`. A token, if any, + // is silently dropped — the legacy server can't enforce one. + var pingResponse = await InvokeRpcAsync( + connection.Rpc, "ping", [new PingRequest()], connection.StderrBuffer, cancellationToken); + serverVersion = pingResponse.ProtocolVersion; + } - if (!pingResponse.ProtocolVersion.HasValue) + if (!serverVersion.HasValue) { throw new InvalidOperationException( $"SDK protocol version mismatch: SDK supports versions {MinProtocolVersion}-{maxVersion}, " + @@ -1133,19 +1163,18 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio $"Please update your server to ensure compatibility."); } - var serverVersion = pingResponse.ProtocolVersion.Value; - if (serverVersion < MinProtocolVersion || serverVersion > maxVersion) + if (serverVersion.Value < MinProtocolVersion || serverVersion.Value > maxVersion) { throw new InvalidOperationException( $"SDK protocol version mismatch: SDK supports versions {MinProtocolVersion}-{maxVersion}, " + - $"but server reports version {serverVersion}. " + + $"but server reports version {serverVersion.Value}. " + $"Please update your SDK or server to ensure compatibility."); } - _negotiatedProtocolVersion = serverVersion; + _negotiatedProtocolVersion = serverVersion.Value; } - private static async Task<(Process Process, int? DetectedLocalhostTcpPort, StringBuilder StderrBuffer)> StartCliServerAsync(CopilotClientOptions options, ILogger logger, CancellationToken cancellationToken) + private static async Task<(Process Process, int? DetectedLocalhostTcpPort, StringBuilder StderrBuffer)> StartCliServerAsync(CopilotClientOptions options, string? connectionToken, ILogger logger, CancellationToken cancellationToken) { // Use explicit path, COPILOT_CLI_PATH env var (from options.Environment or process env), or bundled CLI - no PATH fallback var envCliPath = options.Environment is not null && options.Environment.TryGetValue("COPILOT_CLI_PATH", out var envValue) ? envValue @@ -1221,6 +1250,11 @@ private async Task VerifyProtocolVersionAsync(Connection connection, Cancellatio startInfo.Environment["COPILOT_SDK_AUTH_TOKEN"] = options.GitHubToken; } + if (!string.IsNullOrEmpty(connectionToken)) + { + startInfo.Environment["COPILOT_CONNECTION_TOKEN"] = connectionToken; + } + // Set telemetry environment variables if configured if (options.Telemetry is { } telemetry) { diff --git a/dotnet/src/Generated/Rpc.cs b/dotnet/src/Generated/Rpc.cs index 9d1a76558..214b8235f 100644 --- a/dotnet/src/Generated/Rpc.cs +++ b/dotnet/src/Generated/Rpc.cs @@ -46,6 +46,30 @@ internal sealed class PingRequest public string? Message { get; set; } } +/// RPC data type for Connect operations. +internal sealed class ConnectResult +{ + /// Always true on success. + [JsonPropertyName("ok")] + public bool Ok { get; set; } + + /// Server protocol version number. + [JsonPropertyName("protocolVersion")] + public long ProtocolVersion { get; set; } + + /// Server package version. + [JsonPropertyName("version")] + public string Version { get; set; } = string.Empty; +} + +/// RPC data type for Connect operations. +internal sealed class ConnectRequest +{ + /// Connection token; required when the server was started with COPILOT_CONNECTION_TOKEN. + [JsonPropertyName("token")] + public string? Token { get; set; } +} + /// Billing information. public sealed class ModelBilling { @@ -483,6 +507,14 @@ internal sealed class SessionsForkRequest public string? ToEventId { get; set; } } +/// RPC data type for SessionSuspend operations. +internal sealed class SessionSuspendRequest +{ + /// Target session identifier. + [JsonPropertyName("sessionId")] + public string SessionId { get; set; } = string.Empty; +} + /// RPC data type for Log operations. public sealed class LogResult { @@ -3122,6 +3154,13 @@ public async Task PingAsync(string? message = null, CancellationToke return await CopilotClient.InvokeRpcAsync(_rpc, "ping", [request], cancellationToken); } + /// Calls "connect". + internal async Task ConnectAsync(string? token = null, CancellationToken cancellationToken = default) + { + var request = new ConnectRequest { Token = token }; + return await CopilotClient.InvokeRpcAsync(_rpc, "connect", [request], cancellationToken); + } + /// Models APIs. public ServerModelsApi Models { get; } @@ -3445,6 +3484,13 @@ internal SessionRpc(JsonRpc rpc, string sessionId) /// Usage APIs. public UsageApi Usage { get; } + /// Calls "session.suspend". + public async Task SuspendAsync(CancellationToken cancellationToken = default) + { + var request = new SessionSuspendRequest { SessionId = _sessionId }; + await CopilotClient.InvokeRpcAsync(_rpc, "session.suspend", [request], cancellationToken); + } + /// Calls "session.log". public async Task LogAsync(string message, SessionLogLevel? level = null, bool? ephemeral = null, string? url = null, CancellationToken cancellationToken = default) { @@ -4237,6 +4283,8 @@ public static void RegisterClientSessionApiHandlers(JsonRpc rpc, Func internal sealed class RemoteRpcException(string message, int errorCode, Exception? innerException = null) : Exception(message, innerException) { + /// JSON-RPC 2.0 reserved error code: requested method does not exist. + public const int MethodNotFoundErrorCode = -32601; + public int ErrorCode { get; } = errorCode; } diff --git a/dotnet/src/Types.cs b/dotnet/src/Types.cs index 3ecc483bf..3015096a8 100644 --- a/dotnet/src/Types.cs +++ b/dotnet/src/Types.cs @@ -70,6 +70,7 @@ protected CopilotClientOptions(CopilotClientOptions? other) OnListModels = other.OnListModels; SessionFs = other.SessionFs; SessionIdleTimeoutSeconds = other.SessionIdleTimeoutSeconds; + TcpConnectionToken = other.TcpConnectionToken; } /// @@ -175,6 +176,13 @@ public string? GithubToken /// public int? SessionIdleTimeoutSeconds { get; set; } + /// + /// Connection token for the headless CLI server (TCP only). When the SDK spawns its own + /// CLI in TCP mode and this is omitted, a GUID is generated automatically so the loopback + /// listener is safe by default. Cannot be combined with = true. + /// + public string? TcpConnectionToken { get; set; } + /// /// Creates a shallow clone of this instance. /// diff --git a/dotnet/test/ConnectionTokenTests.cs b/dotnet/test/ConnectionTokenTests.cs new file mode 100644 index 000000000..fac00af4c --- /dev/null +++ b/dotnet/test/ConnectionTokenTests.cs @@ -0,0 +1,144 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +using GitHub.Copilot.SDK.Test.Harness; +using Xunit; + +namespace GitHub.Copilot.SDK.Test; + +/// +/// Custom fixture that spawns a CLI in TCP mode with an explicit connection token, so +/// sibling clients can attempt to connect to the same port with the right/wrong/no token. +/// +public class ConnectionTokenTestFixture : IAsyncLifetime +{ + public E2ETestContext Ctx { get; private set; } = null!; + public CopilotClient GoodClient { get; private set; } = null!; + public int Port { get; private set; } + + public const string Token = "right-token"; + + public async Task InitializeAsync() + { + Ctx = await E2ETestContext.CreateAsync(); + GoodClient = Ctx.CreateClient(useStdio: false, options: new CopilotClientOptions + { + TcpConnectionToken = Token, + }); + + await GoodClient.StartAsync(); + Port = GoodClient.ActualPort + ?? throw new InvalidOperationException("GoodClient is not using TCP mode; ActualPort is null"); + } + + public async Task DisposeAsync() + { + if (GoodClient is not null) + { + await GoodClient.ForceStopAsync(); + } + + await Ctx.DisposeAsync(); + } +} + +public class ConnectionTokenTests : IClassFixture +{ + private readonly ConnectionTokenTestFixture _fixture; + + public ConnectionTokenTests(ConnectionTokenTestFixture fixture) + { + _fixture = fixture; + } + + [Fact] + public async Task Connects_With_The_Matching_Token() + { + var pong = await _fixture.GoodClient.PingAsync("hi"); + Assert.Equal("pong: hi", pong.Message); + } + + [Fact] + public async Task Rejects_A_Wrong_Token() + { + var wrongClient = new CopilotClient(new CopilotClientOptions + { + CliUrl = $"localhost:{_fixture.Port}", + TcpConnectionToken = "wrong", + }); + + try + { + var ex = await Assert.ThrowsAnyAsync(() => wrongClient.StartAsync()); + Assert.Contains("AUTHENTICATION_FAILED", GetFullMessage(ex)); + } + finally + { + try { await wrongClient.ForceStopAsync(); } catch { } + } + } + + [Fact] + public async Task Rejects_A_Missing_Token_When_One_Is_Required() + { + var noTokenClient = new CopilotClient(new CopilotClientOptions + { + CliUrl = $"localhost:{_fixture.Port}", + }); + + try + { + var ex = await Assert.ThrowsAnyAsync(() => noTokenClient.StartAsync()); + Assert.Contains("AUTHENTICATION_FAILED", GetFullMessage(ex)); + } + finally + { + try { await noTokenClient.ForceStopAsync(); } catch { } + } + } + + private static string GetFullMessage(Exception ex) + { + var messages = new List(); + for (var cur = ex; cur is not null; cur = cur.InnerException) + { + messages.Add(cur.Message); + } + return string.Join(" | ", messages); + } +} + +/// +/// When the SDK spawns its own CLI in TCP mode without an explicit token, it auto-generates +/// a GUID and round-trips it through the spawned CLI. +/// +public class ConnectionTokenAutoGeneratedTests : IAsyncLifetime +{ + private E2ETestContext _ctx = null!; + private CopilotClient _client = null!; + + public async Task InitializeAsync() + { + _ctx = await E2ETestContext.CreateAsync(); + _client = _ctx.CreateClient(useStdio: false); + } + + public async Task DisposeAsync() + { + if (_client is not null) + { + await _client.ForceStopAsync(); + } + + await _ctx.DisposeAsync(); + } + + [Fact] + public async Task The_SDK_Auto_Generated_Guid_Round_Trips_Through_The_Spawned_CLI() + { + await _client.StartAsync(); + var pong = await _client.PingAsync("hi"); + Assert.Equal("pong: hi", pong.Message); + } +} diff --git a/dotnet/test/MultiClientCommandsElicitationTests.cs b/dotnet/test/MultiClientCommandsElicitationTests.cs index c5571b43e..b3a31567f 100644 --- a/dotnet/test/MultiClientCommandsElicitationTests.cs +++ b/dotnet/test/MultiClientCommandsElicitationTests.cs @@ -18,10 +18,15 @@ public class MultiClientCommandsElicitationFixture : IAsyncLifetime public E2ETestContext Ctx { get; private set; } = null!; public CopilotClient Client1 { get; private set; } = null!; + public const string SharedToken = "multi-client-cmd-shared-token"; + public async Task InitializeAsync() { Ctx = await E2ETestContext.CreateAsync(); - Client1 = Ctx.CreateClient(useStdio: false); + Client1 = Ctx.CreateClient(useStdio: false, options: new CopilotClientOptions + { + TcpConnectionToken = SharedToken, + }); } public async Task DisposeAsync() @@ -80,6 +85,7 @@ public async Task InitializeAsync() _client2 = new CopilotClient(new CopilotClientOptions { CliUrl = $"localhost:{port}", + TcpConnectionToken = MultiClientCommandsElicitationFixture.SharedToken, }); } @@ -221,6 +227,7 @@ public async Task Capabilities_Changed_Fires_When_Elicitation_Provider_Disconnec _client3 = new CopilotClient(new CopilotClientOptions { CliUrl = $"localhost:{port}", + TcpConnectionToken = MultiClientCommandsElicitationFixture.SharedToken, }); // Client3 joins WITH elicitation handler diff --git a/dotnet/test/MultiClientTests.cs b/dotnet/test/MultiClientTests.cs index 2a262466e..a9ec8abb0 100644 --- a/dotnet/test/MultiClientTests.cs +++ b/dotnet/test/MultiClientTests.cs @@ -21,10 +21,15 @@ public class MultiClientTestFixture : IAsyncLifetime public E2ETestContext Ctx { get; private set; } = null!; public CopilotClient Client1 { get; private set; } = null!; + public const string SharedToken = "multi-client-shared-token"; + public async Task InitializeAsync() { Ctx = await E2ETestContext.CreateAsync(); - Client1 = Ctx.CreateClient(useStdio: false); + Client1 = Ctx.CreateClient(useStdio: false, options: new CopilotClientOptions + { + TcpConnectionToken = SharedToken, + }); } public async Task DisposeAsync() @@ -78,6 +83,7 @@ public async Task InitializeAsync() _client2 = new CopilotClient(new CopilotClientOptions { CliUrl = $"localhost:{port}", + TcpConnectionToken = MultiClientTestFixture.SharedToken, }); } @@ -336,6 +342,7 @@ public async Task Disconnecting_Client_Removes_Its_Tools() _client2 = new CopilotClient(new CopilotClientOptions { CliUrl = $"localhost:{port}", + TcpConnectionToken = MultiClientTestFixture.SharedToken, }); // Now only stable_tool should be available diff --git a/dotnet/test/SessionFsTests.cs b/dotnet/test/SessionFsTests.cs index a007a6c30..46f539aa8 100644 --- a/dotnet/test/SessionFsTests.cs +++ b/dotnet/test/SessionFsTests.cs @@ -94,7 +94,7 @@ public async Task Should_Reject_SetProvider_When_Sessions_Already_Exist() var providerRoot = CreateProviderRoot(); try { - await using var client1 = CreateSessionFsClient(providerRoot, useStdio: false); + await using var client1 = CreateSessionFsClient(providerRoot, useStdio: false, tcpConnectionToken: "session-fs-shared-token"); var createSessionFsHandler = (Func)(s => new TestSessionFsHandler(s.SessionId, providerRoot)); _ = await client1.CreateSessionAsync(new SessionConfig @@ -113,6 +113,7 @@ public async Task Should_Reject_SetProvider_When_Sessions_Already_Exist() CliUrl = $"localhost:{port}", LogLevel = "error", SessionFs = SessionFsConfig, + TcpConnectionToken = "session-fs-shared-token", }); try @@ -291,7 +292,7 @@ public async Task Should_Persist_Plan_Md_Via_SessionFs() } } - private CopilotClient CreateSessionFsClient(string providerRoot, bool useStdio = true) + private CopilotClient CreateSessionFsClient(string providerRoot, bool useStdio = true, string? tcpConnectionToken = null) { Directory.CreateDirectory(providerRoot); return Ctx.CreateClient( @@ -299,6 +300,7 @@ private CopilotClient CreateSessionFsClient(string providerRoot, bool useStdio = options: new CopilotClientOptions { SessionFs = SessionFsConfig, + TcpConnectionToken = tcpConnectionToken, }); } diff --git a/go/client.go b/go/client.go index b05479336..d5689d1e2 100644 --- a/go/client.go +++ b/go/client.go @@ -112,11 +112,18 @@ type Client struct { processErrorPtr *error osProcess atomic.Pointer[os.Process] negotiatedProtocolVersion int - onListModels func(ctx context.Context) ([]ModelInfo, error) + // effectiveConnectionToken is the token sent in `connect`; auto-generated when + // the SDK spawns its own CLI in TCP mode. + effectiveConnectionToken string + onListModels func(ctx context.Context) ([]ModelInfo, error) // RPC provides typed server-scoped RPC methods. // This field is nil until the client is connected via Start(). RPC *rpc.ServerRpc + + // internalRPC provides SDK-internal RPC methods (handshake helpers etc.). + // Lowercase = not exported; external callers cannot reach it. + internalRPC *rpc.InternalServerRpc } // NewClient creates a new Copilot CLI client with the given options. @@ -163,6 +170,11 @@ func NewClient(options *ClientOptions) *Client { panic("GitHubToken and UseLoggedInUser cannot be used with CLIUrl (external server manages its own auth)") } + // Validate token vs stdio + if options.TCPConnectionToken != "" && options.UseStdio != nil && *options.UseStdio { + panic("TCPConnectionToken cannot be used with UseStdio: true") + } + // Parse CLIUrl if provided if options.CLIUrl != "" { host, port := parseCliUrl(options.CLIUrl) @@ -233,6 +245,14 @@ func NewClient(options *ClientOptions) *Client { } } + // Resolve the effective connection token: explicit value if set; else if the SDK + // spawns its own CLI in TCP mode, generate a UUID; otherwise empty. + if options != nil && options.TCPConnectionToken != "" { + client.effectiveConnectionToken = options.TCPConnectionToken + } else if !client.useStdio && !client.isExternalServer { + client.effectiveConnectionToken = uuid.NewString() + } + client.options = opts return client } @@ -425,6 +445,7 @@ func (c *Client) Stop() error { } c.RPC = nil + c.internalRPC = nil return errors.Join(errs...) } @@ -496,6 +517,7 @@ func (c *Client) ForceStop() { } c.RPC = nil + c.internalRPC = nil } func (c *Client) ensureConnected(ctx context.Context) error { @@ -1324,25 +1346,49 @@ func (c *Client) ListModels(ctx context.Context) ([]ModelInfo, error) { // minProtocolVersion is the minimum protocol version this SDK can communicate with. const minProtocolVersion = 2 -// verifyProtocolVersion verifies that the server's protocol version is within the supported range -// and stores the negotiated version. +// verifyProtocolVersion sends the `connect` handshake (carrying the optional token) and +// verifies the server's protocol version. Falls back to `ping` against legacy servers +// that don't implement `connect`. func (c *Client) verifyProtocolVersion(ctx context.Context) error { + if c.client == nil { + return fmt.Errorf("client not connected") + } maxVersion := GetSdkProtocolVersion() - pingResult, err := c.Ping(ctx, "") + + var serverVersion *int + tokenPtr := (*string)(nil) + if c.effectiveConnectionToken != "" { + t := c.effectiveConnectionToken + tokenPtr = &t + } + connectResult, err := c.internalRPC.Connect(ctx, &rpc.ConnectRequest{Token: tokenPtr}) if err != nil { - return err + var rpcErr *jsonrpc2.Error + if errors.As(err, &rpcErr) && rpcErr.Code == jsonrpc2.ErrMethodNotFound.Code { + // Legacy server without `connect`; fall back to `ping`. A token, if any, + // is silently dropped — the legacy server can't enforce one. + pingResult, perr := c.Ping(ctx, "") + if perr != nil { + return perr + } + serverVersion = pingResult.ProtocolVersion + } else { + return err + } + } else { + v := int(connectResult.ProtocolVersion) + serverVersion = &v } - if pingResult.ProtocolVersion == nil { + if serverVersion == nil { return fmt.Errorf("SDK protocol version mismatch: SDK supports versions %d-%d, but server does not report a protocol version. Please update your server to ensure compatibility", minProtocolVersion, maxVersion) } - serverVersion := *pingResult.ProtocolVersion - if serverVersion < minProtocolVersion || serverVersion > maxVersion { - return fmt.Errorf("SDK protocol version mismatch: SDK supports versions %d-%d, but server reports version %d. Please update your SDK or server to ensure compatibility", minProtocolVersion, maxVersion, serverVersion) + if *serverVersion < minProtocolVersion || *serverVersion > maxVersion { + return fmt.Errorf("SDK protocol version mismatch: SDK supports versions %d-%d, but server reports version %d. Please update your SDK or server to ensure compatibility", minProtocolVersion, maxVersion, *serverVersion) } - c.negotiatedProtocolVersion = serverVersion + c.negotiatedProtocolVersion = *serverVersion return nil } @@ -1415,6 +1461,10 @@ func (c *Client) startCLIServer(ctx context.Context) error { c.process.Env = append(c.process.Env, "COPILOT_SDK_AUTH_TOKEN="+c.options.GitHubToken) } + if c.effectiveConnectionToken != "" { + c.process.Env = append(c.process.Env, "COPILOT_CONNECTION_TOKEN="+c.effectiveConnectionToken) + } + if c.options.Telemetry != nil { t := c.options.Telemetry c.process.Env = append(c.process.Env, "COPILOT_OTEL_ENABLED=true") @@ -1470,6 +1520,7 @@ func (c *Client) startCLIServer(ctx context.Context) error { }() }) c.RPC = rpc.NewServerRpc(c.client) + c.internalRPC = rpc.NewInternalServerRpc(c.client) c.setupNotificationHandler() c.client.Start() @@ -1595,6 +1646,7 @@ func (c *Client) connectViaTcp(ctx context.Context) error { }() }) c.RPC = rpc.NewServerRpc(c.client) + c.internalRPC = rpc.NewInternalServerRpc(c.client) c.setupNotificationHandler() c.client.Start() diff --git a/go/internal/e2e/connection_token_test.go b/go/internal/e2e/connection_token_test.go new file mode 100644 index 000000000..269c5ae5a --- /dev/null +++ b/go/internal/e2e/connection_token_test.go @@ -0,0 +1,114 @@ +package e2e + +import ( + "fmt" + "strings" + "testing" + + copilot "github.com/github/copilot-sdk/go" + "github.com/github/copilot-sdk/go/internal/e2e/testharness" +) + +func TestConnectionToken(t *testing.T) { + t.Run("explicit token round-trips successfully", func(t *testing.T) { + ctx := testharness.NewTestContext(t) + client := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = "right-token" + }) + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Start failed: %v", err) + } + + resp, err := client.Ping(t.Context(), "hi") + if err != nil { + t.Fatalf("Ping failed: %v", err) + } + if resp.Message != "pong: hi" { + t.Errorf("expected message 'pong: hi', got %q", resp.Message) + } + }) + + t.Run("auto-generated token round-trips successfully", func(t *testing.T) { + ctx := testharness.NewTestContext(t) + client := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + }) + t.Cleanup(func() { client.ForceStop() }) + + if err := client.Start(t.Context()); err != nil { + t.Fatalf("Start failed: %v", err) + } + + resp, err := client.Ping(t.Context(), "hi") + if err != nil { + t.Fatalf("Ping failed: %v", err) + } + if resp.Message != "pong: hi" { + t.Errorf("expected message 'pong: hi', got %q", resp.Message) + } + }) + + t.Run("sibling client with wrong token is rejected", func(t *testing.T) { + ctx := testharness.NewTestContext(t) + good := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = "right-token" + }) + t.Cleanup(func() { good.ForceStop() }) + + if err := good.Start(t.Context()); err != nil { + t.Fatalf("good client Start failed: %v", err) + } + port := good.ActualPort() + if port == 0 { + t.Fatalf("expected non-zero port from TCP mode client") + } + + bad := copilot.NewClient(&copilot.ClientOptions{ + CLIUrl: fmt.Sprintf("localhost:%d", port), + TCPConnectionToken: "wrong", + }) + t.Cleanup(func() { bad.ForceStop() }) + + err := bad.Start(t.Context()) + if err == nil { + t.Fatalf("expected sibling client with wrong token to fail") + } + if !strings.Contains(err.Error(), "AUTHENTICATION_FAILED") { + t.Errorf("expected AUTHENTICATION_FAILED error, got: %v", err) + } + }) + + t.Run("sibling client with no token is rejected", func(t *testing.T) { + ctx := testharness.NewTestContext(t) + good := ctx.NewClient(func(opts *copilot.ClientOptions) { + opts.UseStdio = copilot.Bool(false) + opts.TCPConnectionToken = "right-token" + }) + t.Cleanup(func() { good.ForceStop() }) + + if err := good.Start(t.Context()); err != nil { + t.Fatalf("good client Start failed: %v", err) + } + port := good.ActualPort() + if port == 0 { + t.Fatalf("expected non-zero port from TCP mode client") + } + + none := copilot.NewClient(&copilot.ClientOptions{ + CLIUrl: fmt.Sprintf("localhost:%d", port), + }) + t.Cleanup(func() { none.ForceStop() }) + + err := none.Start(t.Context()) + if err == nil { + t.Fatalf("expected sibling client with no token to fail") + } + if !strings.Contains(err.Error(), "AUTHENTICATION_FAILED") { + t.Errorf("expected AUTHENTICATION_FAILED error, got: %v", err) + } + }) +} diff --git a/go/rpc/generated_rpc.go b/go/rpc/generated_rpc.go index ab1c897a7..6127c7396 100644 --- a/go/rpc/generated_rpc.go +++ b/go/rpc/generated_rpc.go @@ -26,6 +26,8 @@ type RPCTypes struct { AuthInfoType AuthInfoType `json:"AuthInfoType"` CommandsHandlePendingCommandRequest CommandsHandlePendingCommandRequest `json:"CommandsHandlePendingCommandRequest"` CommandsHandlePendingCommandResult CommandsHandlePendingCommandResult `json:"CommandsHandlePendingCommandResult"` + ConnectRequest ConnectRequest `json:"ConnectRequest"` + ConnectResult ConnectResult `json:"ConnectResult"` CurrentModel CurrentModel `json:"CurrentModel"` DiscoveredMCPServer DiscoveredMCPServer `json:"DiscoveredMcpServer"` DiscoveredMCPServerSource MCPServerSource `json:"DiscoveredMcpServerSource"` @@ -201,6 +203,7 @@ type RPCTypes struct { SkillsEnableRequest SkillsEnableRequest `json:"SkillsEnableRequest"` SkillsEnableResult SkillsEnableResult `json:"SkillsEnableResult"` SkillsReloadResult SkillsReloadResult `json:"SkillsReloadResult"` + SuspendResult SuspendResult `json:"SuspendResult"` TaskAgentInfo TaskAgentInfo `json:"TaskAgentInfo"` TaskAgentInfoExecutionMode TaskInfoExecutionMode `json:"TaskAgentInfoExecutionMode"` TaskAgentInfoStatus TaskInfoStatus `json:"TaskAgentInfoStatus"` @@ -347,6 +350,22 @@ type CommandsHandlePendingCommandResult struct { Success bool `json:"success"` } +// Internal: ConnectRequest is an internal SDK API and is not part of the public surface. +type ConnectRequest struct { + // Connection token; required when the server was started with COPILOT_CONNECTION_TOKEN + Token *string `json:"token,omitempty"` +} + +// Internal: ConnectResult is an internal SDK API and is not part of the public surface. +type ConnectResult struct { + // Always true on success + Ok bool `json:"ok"` + // Server protocol version number + ProtocolVersion int64 `json:"protocolVersion"` + // Server package version + Version string `json:"version"` +} + type CurrentModel struct { // Currently active model identifier ModelID *string `json:"modelId,omitempty"` @@ -1543,6 +1562,9 @@ type SkillsEnableResult struct { type SkillsReloadResult struct { } +type SuspendResult struct { +} + type TaskAgentInfo struct { // ISO 8601 timestamp when the current active period began ActiveStartedAt *time.Time `json:"activeStartedAt,omitempty"` @@ -2741,6 +2763,35 @@ func NewServerRpc(client *jsonrpc2.Client) *ServerRpc { return r } +type internalServerApi struct { + client *jsonrpc2.Client +} + +// InternalServerRpc provides internal SDK server-scoped RPC methods (handshake helpers etc.). Not part of the public API. +type InternalServerRpc struct { + common internalServerApi // Reuse a single struct instead of allocating one for each service on the heap. + +} + +// Internal: Connect is part of the SDK's internal handshake/plumbing; external callers should not use it. +func (a *InternalServerRpc) Connect(ctx context.Context, params *ConnectRequest) (*ConnectResult, error) { + raw, err := a.common.client.Request("connect", params) + if err != nil { + return nil, err + } + var result ConnectResult + if err := json.Unmarshal(raw, &result); err != nil { + return nil, err + } + return &result, nil +} + +func NewInternalServerRpc(client *jsonrpc2.Client) *InternalServerRpc { + r := &InternalServerRpc{} + r.common = internalServerApi{client: client} + return r +} + type sessionApi struct { client *jsonrpc2.Client sessionID string @@ -3639,6 +3690,19 @@ type SessionRpc struct { Usage *UsageApi } +func (a *SessionRpc) Suspend(ctx context.Context) (*SuspendResult, error) { + req := map[string]any{"sessionId": a.common.sessionID} + raw, err := a.common.client.Request("session.suspend", req) + if err != nil { + return nil, err + } + var result SuspendResult + if err := json.Unmarshal(raw, &result); err != nil { + return nil, err + } + return &result, nil +} + func (a *SessionRpc) Log(ctx context.Context, params *LogRequest) (*LogResult, error) { req := map[string]any{"sessionId": a.common.sessionID} if params != nil { diff --git a/go/types.go b/go/types.go index e43bf2ed2..13272b5ba 100644 --- a/go/types.go +++ b/go/types.go @@ -30,6 +30,11 @@ type ClientOptions struct { // UseStdio controls whether to use stdio transport instead of TCP. // Default: nil (use default = true, i.e. stdio). Use Bool(false) to explicitly select TCP. UseStdio *bool + // TCPConnectionToken is the token sent in the `connect` handshake when using TCP transport. + // Only meaningful in TCP mode. When the SDK spawns its own CLI in TCP mode and this is + // empty, an auto-generated UUID is used so the loopback listener is safe by default. + // Combining this with UseStdio=true is rejected (stdio is pre-authenticated by transport). + TCPConnectionToken string // CLIUrl is the URL of an existing Copilot CLI server to connect to over TCP // Format: "host:port", "http://host:port", or just "port" (defaults to localhost) // Examples: "localhost:8080", "http://127.0.0.1:9000", "8080" diff --git a/nodejs/src/client.ts b/nodejs/src/client.ts index 931ce59a4..76f3413d2 100644 --- a/nodejs/src/client.ts +++ b/nodejs/src/client.ts @@ -20,11 +20,13 @@ import { dirname, join } from "node:path"; import { fileURLToPath } from "node:url"; import { createMessageConnection, + ErrorCodes, MessageConnection, + ResponseError, StreamMessageReader, StreamMessageWriter, } from "vscode-jsonrpc/node.js"; -import { createServerRpc, registerClientSessionApiHandlers } from "./generated/rpc.js"; +import { createServerRpc, createInternalServerRpc, registerClientSessionApiHandlers } from "./generated/rpc.js"; import { getSdkProtocolVersion } from "./sdkProtocolVersion.js"; import { CopilotSession, NO_RESULT_PERMISSION_V2_ERROR } from "./session.js"; import { createSessionFsAdapter } from "./sessionFsProvider.js"; @@ -221,6 +223,7 @@ export class CopilotClient { | "telemetry" | "onGetTraceContext" | "sessionFs" + | "tcpConnectionToken" > > & { cliPath?: string; @@ -231,6 +234,8 @@ export class CopilotClient { }; private isExternalServer: boolean = false; private forceStopping: boolean = false; + /** Token sent in `connect`; auto-generated when the SDK spawns its own CLI in TCP mode. */ + private effectiveConnectionToken?: string; private onListModels?: () => Promise | ModelInfo[]; private onGetTraceContext?: TraceContextProvider; private modelsCache: ModelInfo[] | null = null; @@ -241,6 +246,7 @@ export class CopilotClient { Set<(event: SessionLifecycleEvent) => void> > = new Map(); private _rpc: ReturnType | null = null; + private _internalRpc: ReturnType | null = null; private processExitPromise: Promise | null = null; // Rejects when CLI process exits private negotiatedProtocolVersion: number | null = null; /** Connection-level session filesystem config, set via constructor option. */ @@ -260,6 +266,20 @@ export class CopilotClient { return this._rpc; } + /** + * Internal RPC surface (e.g. handshake helpers). Not part of the public API. + * @internal + */ + private get internalRpc(): ReturnType { + if (!this.connection) { + throw new Error("Client is not connected. Call start() first."); + } + if (!this._internalRpc) { + this._internalRpc = createInternalServerRpc(this.connection); + } + return this._internalRpc; + } + /** * Creates a new CopilotClient instance. * @@ -300,6 +320,23 @@ export class CopilotClient { ); } + if (options.tcpConnectionToken !== undefined) { + if ( + typeof options.tcpConnectionToken !== "string" || + options.tcpConnectionToken.length === 0 + ) { + throw new Error("tcpConnectionToken must be a non-empty string"); + } + if (options.useStdio === true) { + throw new Error("tcpConnectionToken cannot be used with useStdio: true"); + } + } + + const willUseStdio = options.cliUrl ? false : (options.useStdio ?? true); + const sdkSpawnsCli = !willUseStdio && !options.cliUrl && !options.isChildProcess; + this.effectiveConnectionToken = + options.tcpConnectionToken ?? (sdkSpawnsCli ? randomUUID() : undefined); + if (options.sessionFs) { this.validateSessionFsConfig(options.sessionFs); } @@ -1064,22 +1101,32 @@ export class CopilotClient { } /** - * Verify that the server's protocol version is within the supported range - * and store the negotiated version. + * Send the `connect` handshake (carrying the optional token) and verify the + * server's protocol version. Falls back to `ping` against legacy servers + * that don't implement `connect`. */ private async verifyProtocolVersion(): Promise { + if (!this.connection) { + throw new Error("Client not connected"); + } const maxVersion = getSdkProtocolVersion(); + const raceAgainstExit = (p: Promise): Promise => + this.processExitPromise ? Promise.race([p, this.processExitPromise]) : p; - // Race ping against process exit to detect early CLI failures - let pingResult: Awaited>; - if (this.processExitPromise) { - pingResult = await Promise.race([this.ping(), this.processExitPromise]); - } else { - pingResult = await this.ping(); + let serverVersion: number | undefined; + try { + const result = await raceAgainstExit(this.internalRpc.connect({ token: this.effectiveConnectionToken })); + serverVersion = result.protocolVersion; + } catch (err) { + if (err instanceof ResponseError && err.code === ErrorCodes.MethodNotFound) { + // Legacy server without `connect`; fall back to `ping`. A token, if any, + // is silently dropped — the legacy server can't enforce one. + serverVersion = (await raceAgainstExit(this.ping())).protocolVersion; + } else { + throw err; + } } - const serverVersion = pingResult.protocolVersion; - if (serverVersion === undefined) { throw new Error( `SDK protocol version mismatch: SDK supports versions ${MIN_PROTOCOL_VERSION}-${maxVersion}, but server does not report a protocol version. ` + @@ -1437,6 +1484,10 @@ export class CopilotClient { envWithoutNodeDebug.COPILOT_SDK_AUTH_TOKEN = this.options.gitHubToken; } + if (this.effectiveConnectionToken) { + envWithoutNodeDebug.COPILOT_CONNECTION_TOKEN = this.effectiveConnectionToken; + } + if (!this.options.cliPath) { throw new Error( "Path to Copilot CLI is required. Please provide it via the cliPath option, or use cliUrl to rely on a remote CLI." diff --git a/nodejs/src/generated/rpc.ts b/nodejs/src/generated/rpc.ts index 42cdc039b..6836324ab 100644 --- a/nodejs/src/generated/rpc.ts +++ b/nodejs/src/generated/rpc.ts @@ -395,6 +395,30 @@ export interface CommandsHandlePendingCommandResult { success: boolean; } +/** @internal */ +export interface ConnectRequest { + /** + * Connection token; required when the server was started with COPILOT_CONNECTION_TOKEN + */ + token?: string; +} + +/** @internal */ +export interface ConnectResult { + /** + * Always true on success + */ + ok: true; + /** + * Server protocol version number + */ + protocolVersion: number; + /** + * Server package version + */ + version: string; +} + export interface CurrentModel { /** * Currently active model identifier @@ -2521,9 +2545,23 @@ export function createServerRpc(connection: MessageConnection) { }; } +/** + * Create typed server-scoped RPC methods that are part of the SDK's internal + * surface (e.g. handshake helpers). Not exported on the public client API. + * @internal + */ +export function createInternalServerRpc(connection: MessageConnection) { + return { + connect: async (params: ConnectRequest): Promise => + connection.sendRequest("connect", params), + }; +} + /** Create typed session-scoped RPC methods. */ export function createSessionRpc(connection: MessageConnection, sessionId: string) { return { + suspend: async (): Promise => + connection.sendRequest("session.suspend", { sessionId }), auth: { getStatus: async (): Promise => connection.sendRequest("session.auth.getStatus", { sessionId }), diff --git a/nodejs/src/types.ts b/nodejs/src/types.ts index 93f2360fa..17a87a277 100644 --- a/nodejs/src/types.ts +++ b/nodejs/src/types.ts @@ -194,6 +194,14 @@ export interface CopilotClientOptions { * @default undefined (disabled) */ sessionIdleTimeoutSeconds?: number; + + /** + * Connection token for the headless CLI server (TCP only). When the SDK + * spawns its own CLI in TCP mode and this is omitted, a UUID is generated + * automatically so the loopback listener is safe by default. Rejected with + * `useStdio: true` (stdio is pre-authenticated by transport). + */ + tcpConnectionToken?: string; } /** diff --git a/nodejs/test/e2e/commands.test.ts b/nodejs/test/e2e/commands.test.ts index ea97f0ba0..9047012c3 100644 --- a/nodejs/test/e2e/commands.test.ts +++ b/nodejs/test/e2e/commands.test.ts @@ -9,15 +9,16 @@ import { createSdkTestContext } from "./harness/sdkTestContext.js"; describe("Commands", async () => { // Use TCP mode so a second client can connect to the same CLI process - const ctx = await createSdkTestContext({ useStdio: false }); + const tcpConnectionToken = "commands-test-token"; + const ctx = await createSdkTestContext({ useStdio: false, copilotClientOptions: { tcpConnectionToken } }); const client1 = ctx.copilotClient; // Trigger connection so we can read the port const initSession = await client1.createSession({ onPermissionRequest: approveAll }); await initSession.disconnect(); - const actualPort = (client1 as unknown as { actualPort: number }).actualPort; - const client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}` }); + const { actualPort } = client1 as unknown as { actualPort: number }; + const client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}`, tcpConnectionToken }); afterAll(async () => { await client2.stop(); diff --git a/nodejs/test/e2e/connection_token.test.ts b/nodejs/test/e2e/connection_token.test.ts new file mode 100644 index 000000000..50813778c --- /dev/null +++ b/nodejs/test/e2e/connection_token.test.ts @@ -0,0 +1,49 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + *--------------------------------------------------------------------------------------------*/ + +import { afterAll, describe, expect, it } from "vitest"; +import { CopilotClient } from "../../src/index.js"; +import { createSdkTestContext } from "./harness/sdkTestContext.js"; + +describe("Connection token", async () => { + const ctx = await createSdkTestContext({ + useStdio: false, + copilotClientOptions: { tcpConnectionToken: "right-token" }, + }); + const goodClient = ctx.copilotClient; + await goodClient.start(); + const port = (goodClient as unknown as { actualPort: number }).actualPort; + + const wrongClient = new CopilotClient({ + cliUrl: `localhost:${port}`, + tcpConnectionToken: "wrong", + }); + const noTokenClient = new CopilotClient({ cliUrl: `localhost:${port}` }); + + afterAll(async () => { + await wrongClient.forceStop(); + await noTokenClient.forceStop(); + }); + + it("connects with the matching token", async () => { + await expect(goodClient.ping("hi")).resolves.toMatchObject({ message: "pong: hi" }); + }); + + it("rejects a wrong token", async () => { + await expect(wrongClient.start()).rejects.toThrow(/AUTHENTICATION_FAILED/); + }); + + it("rejects a missing token when one is required", async () => { + await expect(noTokenClient.start()).rejects.toThrow(/AUTHENTICATION_FAILED/); + }); +}); + +describe("Connection token (auto-generated)", async () => { + const { copilotClient } = await createSdkTestContext({ useStdio: false }); + + it("the SDK-auto-generated UUID round-trips through the spawned CLI", async () => { + await copilotClient.start(); + await expect(copilotClient.ping("hi")).resolves.toMatchObject({ message: "pong: hi" }); + }); +}); diff --git a/nodejs/test/e2e/multi-client.test.ts b/nodejs/test/e2e/multi-client.test.ts index f23ae4459..d5accafd1 100644 --- a/nodejs/test/e2e/multi-client.test.ts +++ b/nodejs/test/e2e/multi-client.test.ts @@ -10,7 +10,8 @@ import { createSdkTestContext } from "./harness/sdkTestContext"; describe("Multi-client broadcast", async () => { // Use TCP mode so a second client can connect to the same CLI process - const ctx = await createSdkTestContext({ useStdio: false }); + const tcpConnectionToken = "multi-client-test-token"; + const ctx = await createSdkTestContext({ useStdio: false, copilotClientOptions: { tcpConnectionToken } }); const client1 = ctx.copilotClient; // Trigger connection so we can read the port @@ -18,7 +19,7 @@ describe("Multi-client broadcast", async () => { await initSession.disconnect(); const actualPort = (client1 as unknown as { actualPort: number }).actualPort; - let client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}` }); + let client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}`, tcpConnectionToken }); afterAll(async () => { await client2.stop(); @@ -297,7 +298,7 @@ describe("Multi-client broadcast", async () => { process.removeListener("unhandledRejection", suppressDisposed); // Recreate client2 for cleanup in afterAll (but don't rejoin the session) - client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}` }); + client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}`, tcpConnectionToken }); // Now only stable_tool should be available const afterResponse = await session1.sendAndWait({ diff --git a/nodejs/test/e2e/session_fs.test.ts b/nodejs/test/e2e/session_fs.test.ts index f6af24d34..37d34bf97 100644 --- a/nodejs/test/e2e/session_fs.test.ts +++ b/nodejs/test/e2e/session_fs.test.ts @@ -87,14 +87,15 @@ describe("Session Fs", async () => { }); it("should reject setProvider when sessions already exist", async () => { + const tcpConnectionToken = "session-fs-test-token"; const client = new CopilotClient({ useStdio: false, // Use TCP so we can connect from a second client + tcpConnectionToken, env, }); await client.createSession({ onPermissionRequest: approveAll, createSessionFsHandler }); - // Get the port the first client's runtime is listening on - const port = (client as unknown as { actualPort: number }).actualPort; + const { actualPort: port } = client as unknown as { actualPort: number }; // Second client tries to connect with a session fs — should fail // because sessions already exist on the runtime. @@ -102,6 +103,7 @@ describe("Session Fs", async () => { env, logLevel: "error", cliUrl: `localhost:${port}`, + tcpConnectionToken, sessionFs: sessionFsConfig, }); onTestFinished(() => client2.forceStop()); diff --git a/nodejs/test/e2e/ui_elicitation.test.ts b/nodejs/test/e2e/ui_elicitation.test.ts index ced735d88..302366937 100644 --- a/nodejs/test/e2e/ui_elicitation.test.ts +++ b/nodejs/test/e2e/ui_elicitation.test.ts @@ -53,15 +53,16 @@ describe("UI Elicitation Callback", async () => { describe("UI Elicitation Multi-Client Capabilities", async () => { // Use TCP mode so a second client can connect to the same CLI process - const ctx = await createSdkTestContext({ useStdio: false }); + const tcpConnectionToken = "ui-elicitation-test-token"; + const ctx = await createSdkTestContext({ useStdio: false, copilotClientOptions: { tcpConnectionToken } }); const client1 = ctx.copilotClient; // Trigger connection so we can read the port const initSession = await client1.createSession({ onPermissionRequest: approveAll }); await initSession.disconnect(); - const actualPort = (client1 as unknown as { actualPort: number }).actualPort; - const client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}` }); + const { actualPort } = client1 as unknown as { actualPort: number }; + const client2 = new CopilotClient({ cliUrl: `localhost:${actualPort}`, tcpConnectionToken }); afterAll(async () => { await client2.stop(); @@ -134,7 +135,7 @@ describe("UI Elicitation Multi-Client Capabilities", async () => { }); // Use a dedicated client so we can stop it without affecting shared client2 - const client3 = new CopilotClient({ cliUrl: `localhost:${actualPort}` }); + const client3 = new CopilotClient({ cliUrl: `localhost:${actualPort}`, tcpConnectionToken }); // Client3 joins WITH elicitation handler await client3.resumeSession(session1.sessionId, { diff --git a/python/copilot/client.py b/python/copilot/client.py index 40ea71b83..081e50cd2 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -29,12 +29,14 @@ from types import TracebackType from typing import Any, Literal, TypedDict, cast, overload -from ._jsonrpc import JsonRpcClient, ProcessExitedError +from ._jsonrpc import JsonRpcClient, JsonRpcError, ProcessExitedError from ._sdk_protocol_version import get_sdk_protocol_version from ._telemetry import get_trace_context, trace_context from .generated.rpc import ( ClientSessionApiHandlers, + ConnectRequest, ServerRpc, + _InternalServerRpc, register_client_session_api_handlers, ) from .generated.session_events import ( @@ -126,6 +128,14 @@ class SubprocessConfig: use_stdio: bool = True """Use stdio transport (``True``, default) or TCP (``False``).""" + tcp_connection_token: str | None = None + """Connection token for the headless CLI server (TCP only). + + Only meaningful when ``use_stdio=False``. When the SDK spawns the CLI in TCP mode and + this is omitted, a UUID is generated automatically so the loopback listener is safe by + default. Combining this with ``use_stdio=True`` raises :class:`ValueError`. + """ + port: int = 0 """TCP port for the CLI server (only when ``use_stdio=False``). 0 means random.""" @@ -173,6 +183,10 @@ class ExternalServerConfig: _: KW_ONLY + tcp_connection_token: str | None = None + """Connection token sent in the ``connect`` handshake. Required when the server was + started with a token; ignored by legacy servers without ``connect`` support.""" + session_fs: SessionFsConfig | None = None """Connection-level session filesystem provider configuration.""" @@ -883,9 +897,17 @@ def __init__( if isinstance(config, ExternalServerConfig): self._actual_host, actual_port = self._parse_cli_url(config.url) self._actual_port: int | None = actual_port + self._effective_connection_token: str | None = config.tcp_connection_token else: self._actual_port = None + if config.tcp_connection_token is not None and config.use_stdio: + raise ValueError("tcp_connection_token cannot be used with use_stdio=True") + if config.use_stdio: + self._effective_connection_token = None + else: + self._effective_connection_token = config.tcp_connection_token or uuid.uuid4().hex + # Resolve CLI path: explicit > COPILOT_CLI_PATH env var > bundled binary effective_env = config.env if config.env is not None else os.environ if config.cli_path is None: @@ -2151,11 +2173,27 @@ def _dispatch_lifecycle_event(self, event: SessionLifecycleEvent) -> None: pass # Ignore handler errors async def _verify_protocol_version(self) -> None: - """Verify that the server's protocol version is within the supported range - and store the negotiated version.""" + """Send the ``connect`` handshake (with the optional token) and verify + the server's protocol version. Falls back to ``ping`` for legacy servers + that don't implement ``connect``.""" + if not self._client: + raise RuntimeError("Client not connected") max_version = get_sdk_protocol_version() - ping_result = await self.ping() - server_version = ping_result.protocolVersion + + server_version: int | None + try: + connect_result = await _InternalServerRpc(self._client).connect( + ConnectRequest(token=self._effective_connection_token) + ) + server_version = connect_result.protocol_version + except JsonRpcError as err: + if err.code == -32601: + # Legacy server without `connect`; fall back to `ping`. A token, if any, + # is silently dropped — the legacy server can't enforce one. + ping_result = await self.ping() + server_version = ping_result.protocolVersion + else: + raise if server_version is None: raise RuntimeError( @@ -2307,6 +2345,9 @@ async def _start_cli_server(self) -> None: if cfg.github_token: env["COPILOT_SDK_AUTH_TOKEN"] = cfg.github_token + if self._effective_connection_token: + env["COPILOT_CONNECTION_TOKEN"] = self._effective_connection_token + # Set OpenTelemetry environment variables if telemetry config is provided telemetry = cfg.telemetry if telemetry is not None: diff --git a/python/copilot/generated/rpc.py b/python/copilot/generated/rpc.py index dac331aa6..fc3eb7bdf 100644 --- a/python/copilot/generated/rpc.py +++ b/python/copilot/generated/rpc.py @@ -245,6 +245,51 @@ def to_dict(self) -> dict: result["success"] = from_bool(self.success) return result +# Internal: this type is an internal SDK API and is not part of the public surface. +@dataclass +class ConnectRequest: + token: str | None = None + """Connection token; required when the server was started with COPILOT_CONNECTION_TOKEN""" + + @staticmethod + def from_dict(obj: Any) -> 'ConnectRequest': + assert isinstance(obj, dict) + token = from_union([from_str, from_none], obj.get("token")) + return ConnectRequest(token) + + def to_dict(self) -> dict: + result: dict = {} + if self.token is not None: + result["token"] = from_union([from_str, from_none], self.token) + return result + +# Internal: this type is an internal SDK API and is not part of the public surface. +@dataclass +class ConnectResult: + ok: bool + """Always true on success""" + + protocol_version: int + """Server protocol version number""" + + version: str + """Server package version""" + + @staticmethod + def from_dict(obj: Any) -> 'ConnectResult': + assert isinstance(obj, dict) + ok = from_bool(obj.get("ok")) + protocol_version = from_int(obj.get("protocolVersion")) + version = from_str(obj.get("version")) + return ConnectResult(ok, protocol_version, version) + + def to_dict(self) -> dict: + result: dict = {} + result["ok"] = from_bool(self.ok) + result["protocolVersion"] = from_int(self.protocol_version) + result["version"] = from_str(self.version) + return result + @dataclass class CurrentModel: model_id: str | None = None @@ -5562,6 +5607,8 @@ class RPC: auth_info_type: AuthInfoType commands_handle_pending_command_request: CommandsHandlePendingCommandRequest commands_handle_pending_command_result: CommandsHandlePendingCommandResult + connect_request: ConnectRequest + connect_result: ConnectResult current_model: CurrentModel discovered_mcp_server: DiscoveredMCPServer discovered_mcp_server_source: MCPServerSource @@ -5788,6 +5835,8 @@ def from_dict(obj: Any) -> 'RPC': auth_info_type = AuthInfoType(obj.get("AuthInfoType")) commands_handle_pending_command_request = CommandsHandlePendingCommandRequest.from_dict(obj.get("CommandsHandlePendingCommandRequest")) commands_handle_pending_command_result = CommandsHandlePendingCommandResult.from_dict(obj.get("CommandsHandlePendingCommandResult")) + connect_request = ConnectRequest.from_dict(obj.get("ConnectRequest")) + connect_result = ConnectResult.from_dict(obj.get("ConnectResult")) current_model = CurrentModel.from_dict(obj.get("CurrentModel")) discovered_mcp_server = DiscoveredMCPServer.from_dict(obj.get("DiscoveredMcpServer")) discovered_mcp_server_source = MCPServerSource(obj.get("DiscoveredMcpServerSource")) @@ -5998,7 +6047,7 @@ def from_dict(obj: Any) -> 'RPC': workspaces_list_files_result = WorkspacesListFilesResult.from_dict(obj.get("WorkspacesListFilesResult")) workspaces_read_file_request = WorkspacesReadFileRequest.from_dict(obj.get("WorkspacesReadFileRequest")) workspaces_read_file_result = WorkspacesReadFileResult.from_dict(obj.get("WorkspacesReadFileResult")) - return RPC(account_get_quota_request, account_get_quota_result, account_quota_snapshot, agent_get_current_result, agent_info, agent_list, agent_reload_result, agent_select_request, agent_select_result, auth_info_type, commands_handle_pending_command_request, commands_handle_pending_command_result, current_model, discovered_mcp_server, discovered_mcp_server_source, discovered_mcp_server_type, embedded_blob_resource_contents, embedded_text_resource_contents, extension, extension_list, extensions_disable_request, extensions_enable_request, extension_source, extension_status, external_tool_result, external_tool_text_result_for_llm, external_tool_text_result_for_llm_content, external_tool_text_result_for_llm_content_audio, external_tool_text_result_for_llm_content_image, external_tool_text_result_for_llm_content_resource, external_tool_text_result_for_llm_content_resource_details, external_tool_text_result_for_llm_content_resource_link, external_tool_text_result_for_llm_content_resource_link_icon, external_tool_text_result_for_llm_content_resource_link_icon_theme, external_tool_text_result_for_llm_content_terminal, external_tool_text_result_for_llm_content_text, filter_mapping, filter_mapping_string, filter_mapping_value, fleet_start_request, fleet_start_result, handle_pending_tool_call_request, handle_pending_tool_call_result, history_compact_context_window, history_compact_result, history_truncate_request, history_truncate_result, instructions_get_sources_result, instructions_sources, instructions_sources_location, instructions_sources_type, log_request, log_result, mcp_config_add_request, mcp_config_disable_request, mcp_config_enable_request, mcp_config_list, mcp_config_remove_request, mcp_config_update_request, mcp_disable_request, mcp_discover_request, mcp_discover_result, mcp_enable_request, mcp_oauth_login_request, mcp_oauth_login_result, mcp_server, mcp_server_config, mcp_server_config_http, mcp_server_config_http_oauth_grant_type, mcp_server_config_http_type, mcp_server_config_local, mcp_server_config_local_type, mcp_server_list, mcp_server_source, mcp_server_status, model, model_billing, model_capabilities, model_capabilities_limits, model_capabilities_limits_vision, model_capabilities_override, model_capabilities_override_limits, model_capabilities_override_limits_vision, model_capabilities_override_supports, model_capabilities_supports, model_list, model_policy, models_list_request, model_switch_to_request, model_switch_to_result, mode_set_request, name_get_result, name_set_request, permission_decision, permission_decision_approve_for_location, permission_decision_approve_for_location_approval, permission_decision_approve_for_location_approval_commands, permission_decision_approve_for_location_approval_custom_tool, permission_decision_approve_for_location_approval_mcp, permission_decision_approve_for_location_approval_mcp_sampling, permission_decision_approve_for_location_approval_memory, permission_decision_approve_for_location_approval_read, permission_decision_approve_for_location_approval_write, permission_decision_approve_for_session, permission_decision_approve_for_session_approval, permission_decision_approve_for_session_approval_commands, permission_decision_approve_for_session_approval_custom_tool, permission_decision_approve_for_session_approval_mcp, permission_decision_approve_for_session_approval_mcp_sampling, permission_decision_approve_for_session_approval_memory, permission_decision_approve_for_session_approval_read, permission_decision_approve_for_session_approval_write, permission_decision_approve_once, permission_decision_approve_permanently, permission_decision_reject, permission_decision_request, permission_decision_user_not_available, permission_request_result, permissions_reset_session_approvals_request, permissions_reset_session_approvals_result, permissions_set_approve_all_request, permissions_set_approve_all_result, ping_request, ping_result, plan_read_result, plan_update_request, plugin, plugin_list, server_skill, server_skill_list, session_auth_status, session_fs_append_file_request, session_fs_error, session_fs_error_code, session_fs_exists_request, session_fs_exists_result, session_fs_mkdir_request, session_fs_readdir_request, session_fs_readdir_result, session_fs_readdir_with_types_entry, session_fs_readdir_with_types_entry_type, session_fs_readdir_with_types_request, session_fs_readdir_with_types_result, session_fs_read_file_request, session_fs_read_file_result, session_fs_rename_request, session_fs_rm_request, session_fs_set_provider_conventions, session_fs_set_provider_request, session_fs_set_provider_result, session_fs_stat_request, session_fs_stat_result, session_fs_write_file_request, session_log_level, session_mode, sessions_fork_request, sessions_fork_result, shell_exec_request, shell_exec_result, shell_kill_request, shell_kill_result, shell_kill_signal, skill, skill_list, skills_config_set_disabled_skills_request, skills_disable_request, skills_discover_request, skills_enable_request, task_agent_info, task_agent_info_execution_mode, task_agent_info_status, task_info, task_list, tasks_cancel_request, tasks_cancel_result, task_shell_info, task_shell_info_attachment_mode, task_shell_info_execution_mode, task_shell_info_status, tasks_promote_to_background_request, tasks_promote_to_background_result, tasks_remove_request, tasks_remove_result, tasks_start_agent_request, tasks_start_agent_result, tool, tool_list, tools_list_request, ui_elicitation_array_any_of_field, ui_elicitation_array_any_of_field_items, ui_elicitation_array_any_of_field_items_any_of, ui_elicitation_array_enum_field, ui_elicitation_array_enum_field_items, ui_elicitation_field_value, ui_elicitation_request, ui_elicitation_response, ui_elicitation_response_action, ui_elicitation_response_content, ui_elicitation_result, ui_elicitation_schema, ui_elicitation_schema_property, ui_elicitation_schema_property_boolean, ui_elicitation_schema_property_number, ui_elicitation_schema_property_number_type, ui_elicitation_schema_property_string, ui_elicitation_schema_property_string_format, ui_elicitation_string_enum_field, ui_elicitation_string_one_of_field, ui_elicitation_string_one_of_field_one_of, ui_handle_pending_elicitation_request, usage_get_metrics_result, usage_metrics_code_changes, usage_metrics_model_metric, usage_metrics_model_metric_requests, usage_metrics_model_metric_token_detail, usage_metrics_model_metric_usage, usage_metrics_token_detail, workspaces_create_file_request, workspaces_get_workspace_result, workspaces_list_files_result, workspaces_read_file_request, workspaces_read_file_result) + return RPC(account_get_quota_request, account_get_quota_result, account_quota_snapshot, agent_get_current_result, agent_info, agent_list, agent_reload_result, agent_select_request, agent_select_result, auth_info_type, commands_handle_pending_command_request, commands_handle_pending_command_result, connect_request, connect_result, current_model, discovered_mcp_server, discovered_mcp_server_source, discovered_mcp_server_type, embedded_blob_resource_contents, embedded_text_resource_contents, extension, extension_list, extensions_disable_request, extensions_enable_request, extension_source, extension_status, external_tool_result, external_tool_text_result_for_llm, external_tool_text_result_for_llm_content, external_tool_text_result_for_llm_content_audio, external_tool_text_result_for_llm_content_image, external_tool_text_result_for_llm_content_resource, external_tool_text_result_for_llm_content_resource_details, external_tool_text_result_for_llm_content_resource_link, external_tool_text_result_for_llm_content_resource_link_icon, external_tool_text_result_for_llm_content_resource_link_icon_theme, external_tool_text_result_for_llm_content_terminal, external_tool_text_result_for_llm_content_text, filter_mapping, filter_mapping_string, filter_mapping_value, fleet_start_request, fleet_start_result, handle_pending_tool_call_request, handle_pending_tool_call_result, history_compact_context_window, history_compact_result, history_truncate_request, history_truncate_result, instructions_get_sources_result, instructions_sources, instructions_sources_location, instructions_sources_type, log_request, log_result, mcp_config_add_request, mcp_config_disable_request, mcp_config_enable_request, mcp_config_list, mcp_config_remove_request, mcp_config_update_request, mcp_disable_request, mcp_discover_request, mcp_discover_result, mcp_enable_request, mcp_oauth_login_request, mcp_oauth_login_result, mcp_server, mcp_server_config, mcp_server_config_http, mcp_server_config_http_oauth_grant_type, mcp_server_config_http_type, mcp_server_config_local, mcp_server_config_local_type, mcp_server_list, mcp_server_source, mcp_server_status, model, model_billing, model_capabilities, model_capabilities_limits, model_capabilities_limits_vision, model_capabilities_override, model_capabilities_override_limits, model_capabilities_override_limits_vision, model_capabilities_override_supports, model_capabilities_supports, model_list, model_policy, models_list_request, model_switch_to_request, model_switch_to_result, mode_set_request, name_get_result, name_set_request, permission_decision, permission_decision_approve_for_location, permission_decision_approve_for_location_approval, permission_decision_approve_for_location_approval_commands, permission_decision_approve_for_location_approval_custom_tool, permission_decision_approve_for_location_approval_mcp, permission_decision_approve_for_location_approval_mcp_sampling, permission_decision_approve_for_location_approval_memory, permission_decision_approve_for_location_approval_read, permission_decision_approve_for_location_approval_write, permission_decision_approve_for_session, permission_decision_approve_for_session_approval, permission_decision_approve_for_session_approval_commands, permission_decision_approve_for_session_approval_custom_tool, permission_decision_approve_for_session_approval_mcp, permission_decision_approve_for_session_approval_mcp_sampling, permission_decision_approve_for_session_approval_memory, permission_decision_approve_for_session_approval_read, permission_decision_approve_for_session_approval_write, permission_decision_approve_once, permission_decision_approve_permanently, permission_decision_reject, permission_decision_request, permission_decision_user_not_available, permission_request_result, permissions_reset_session_approvals_request, permissions_reset_session_approvals_result, permissions_set_approve_all_request, permissions_set_approve_all_result, ping_request, ping_result, plan_read_result, plan_update_request, plugin, plugin_list, server_skill, server_skill_list, session_auth_status, session_fs_append_file_request, session_fs_error, session_fs_error_code, session_fs_exists_request, session_fs_exists_result, session_fs_mkdir_request, session_fs_readdir_request, session_fs_readdir_result, session_fs_readdir_with_types_entry, session_fs_readdir_with_types_entry_type, session_fs_readdir_with_types_request, session_fs_readdir_with_types_result, session_fs_read_file_request, session_fs_read_file_result, session_fs_rename_request, session_fs_rm_request, session_fs_set_provider_conventions, session_fs_set_provider_request, session_fs_set_provider_result, session_fs_stat_request, session_fs_stat_result, session_fs_write_file_request, session_log_level, session_mode, sessions_fork_request, sessions_fork_result, shell_exec_request, shell_exec_result, shell_kill_request, shell_kill_result, shell_kill_signal, skill, skill_list, skills_config_set_disabled_skills_request, skills_disable_request, skills_discover_request, skills_enable_request, task_agent_info, task_agent_info_execution_mode, task_agent_info_status, task_info, task_list, tasks_cancel_request, tasks_cancel_result, task_shell_info, task_shell_info_attachment_mode, task_shell_info_execution_mode, task_shell_info_status, tasks_promote_to_background_request, tasks_promote_to_background_result, tasks_remove_request, tasks_remove_result, tasks_start_agent_request, tasks_start_agent_result, tool, tool_list, tools_list_request, ui_elicitation_array_any_of_field, ui_elicitation_array_any_of_field_items, ui_elicitation_array_any_of_field_items_any_of, ui_elicitation_array_enum_field, ui_elicitation_array_enum_field_items, ui_elicitation_field_value, ui_elicitation_request, ui_elicitation_response, ui_elicitation_response_action, ui_elicitation_response_content, ui_elicitation_result, ui_elicitation_schema, ui_elicitation_schema_property, ui_elicitation_schema_property_boolean, ui_elicitation_schema_property_number, ui_elicitation_schema_property_number_type, ui_elicitation_schema_property_string, ui_elicitation_schema_property_string_format, ui_elicitation_string_enum_field, ui_elicitation_string_one_of_field, ui_elicitation_string_one_of_field_one_of, ui_handle_pending_elicitation_request, usage_get_metrics_result, usage_metrics_code_changes, usage_metrics_model_metric, usage_metrics_model_metric_requests, usage_metrics_model_metric_token_detail, usage_metrics_model_metric_usage, usage_metrics_token_detail, workspaces_create_file_request, workspaces_get_workspace_result, workspaces_list_files_result, workspaces_read_file_request, workspaces_read_file_result) def to_dict(self) -> dict: result: dict = {} @@ -6014,6 +6063,8 @@ def to_dict(self) -> dict: result["AuthInfoType"] = to_enum(AuthInfoType, self.auth_info_type) result["CommandsHandlePendingCommandRequest"] = to_class(CommandsHandlePendingCommandRequest, self.commands_handle_pending_command_request) result["CommandsHandlePendingCommandResult"] = to_class(CommandsHandlePendingCommandResult, self.commands_handle_pending_command_result) + result["ConnectRequest"] = to_class(ConnectRequest, self.connect_request) + result["ConnectResult"] = to_class(ConnectResult, self.connect_result) result["CurrentModel"] = to_class(CurrentModel, self.current_model) result["DiscoveredMcpServer"] = to_class(DiscoveredMCPServer, self.discovered_mcp_server) result["DiscoveredMcpServerSource"] = to_enum(MCPServerSource, self.discovered_mcp_server_source) @@ -6381,6 +6432,17 @@ async def ping(self, params: PingRequest, *, timeout: float | None = None) -> Pi return PingResult.from_dict(await self._client.request("ping", params_dict, **_timeout_kwargs(timeout))) +class _InternalServerRpc: + """Internal SDK server-scoped RPC methods (handshake helpers etc.). Not part of the public API.""" + def __init__(self, client: "JsonRpcClient"): + self._client = client + + async def connect(self, params: ConnectRequest, *, timeout: float | None = None) -> ConnectResult: + """:meta private: Internal SDK API; not part of the public surface.""" + params_dict = {k: v for k, v in params.to_dict().items() if v is not None} + return ConnectResult.from_dict(await self._client.request("connect", params_dict, **_timeout_kwargs(timeout))) + + class AuthApi: def __init__(self, client: "JsonRpcClient", session_id: str): self._client = client @@ -6763,6 +6825,9 @@ def __init__(self, client: "JsonRpcClient", session_id: str): self.history = HistoryApi(client, session_id) self.usage = UsageApi(client, session_id) + async def suspend(self, *, timeout: float | None = None) -> None: + await self._client.request("session.suspend", {"sessionId": self._session_id}, **_timeout_kwargs(timeout)) + async def log(self, params: LogRequest, *, timeout: float | None = None) -> LogResult: params_dict: dict[str, Any] = {k: v for k, v in params.to_dict().items() if v is not None} params_dict["sessionId"] = self._session_id diff --git a/python/e2e/test_connection_token.py b/python/e2e/test_connection_token.py new file mode 100644 index 000000000..d8d4bc526 --- /dev/null +++ b/python/e2e/test_connection_token.py @@ -0,0 +1,163 @@ +"""E2E Connection Token Tests + +Tests for the optional TCP ``connect`` token handshake. Mirrors the Node SDK's +``connection_token.test.ts``. +""" + +import os +import shutil +import tempfile + +import pytest +import pytest_asyncio + +from copilot import CopilotClient +from copilot.client import ExternalServerConfig, SubprocessConfig +from copilot.session import PermissionHandler + +from .testharness.proxy import CapiProxy + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +class ConnectionTokenContext: + """Spawns a TCP CLI server with an explicit connection token.""" + + def __init__(self, token: str | None): + self.token = token + self.cli_path: str = "" + self.home_dir: str = "" + self.work_dir: str = "" + self.proxy_url: str = "" + self._proxy: CapiProxy | None = None + self._client: CopilotClient | None = None + + async def setup(self): + from .testharness.context import get_cli_path_for_tests + + self.cli_path = get_cli_path_for_tests() + self.home_dir = tempfile.mkdtemp(prefix="copilot-token-config-") + self.work_dir = tempfile.mkdtemp(prefix="copilot-token-work-") + + self._proxy = CapiProxy() + self.proxy_url = await self._proxy.start() + + github_token = ( + "fake-token-for-e2e-tests" if os.environ.get("GITHUB_ACTIONS") == "true" else None + ) + + self._client = CopilotClient( + SubprocessConfig( + cli_path=self.cli_path, + cwd=self.work_dir, + env=self.get_env(), + use_stdio=False, + tcp_connection_token=self.token, + github_token=github_token, + ) + ) + + # Trigger the spawn + connect handshake so the server is listening. + await self._client.start() + + async def teardown(self): + if self._client: + try: + await self._client.stop() + except Exception: + pass + self._client = None + if self._proxy: + await self._proxy.stop(skip_writing_cache=True) + self._proxy = None + if self.home_dir and os.path.exists(self.home_dir): + shutil.rmtree(self.home_dir, ignore_errors=True) + if self.work_dir and os.path.exists(self.work_dir): + shutil.rmtree(self.work_dir, ignore_errors=True) + + def get_env(self) -> dict: + env = os.environ.copy() + env.update( + { + "COPILOT_API_URL": self.proxy_url, + "COPILOT_HOME": self.home_dir, + "XDG_CONFIG_HOME": self.home_dir, + "XDG_STATE_HOME": self.home_dir, + } + ) + return env + + @property + def client(self) -> CopilotClient: + if not self._client: + raise RuntimeError("Context not set up") + return self._client + + +@pytest_asyncio.fixture(scope="module", loop_scope="module") +async def explicit_token_ctx(): + ctx = ConnectionTokenContext(token="right-token") + await ctx.setup() + yield ctx + await ctx.teardown() + + +@pytest_asyncio.fixture(scope="module", loop_scope="module") +async def auto_token_ctx(): + ctx = ConnectionTokenContext(token=None) + await ctx.setup() + yield ctx + await ctx.teardown() + + +class TestConnectionToken: + async def test_explicit_token_round_trips(self, explicit_token_ctx: ConnectionTokenContext): + """Client started with an explicit token can ping successfully.""" + # Sanity-check that the token was forwarded to the spawned CLI and the + # `connect` handshake succeeded; a real ping must round-trip. + response = await explicit_token_ctx.client.ping("hi") + assert response.message == "pong: hi" + + # Bonus: a fresh session round-trip also exercises the live connection. + session = await explicit_token_ctx.client.create_session( + on_permission_request=PermissionHandler.approve_all + ) + await session.disconnect() + + async def test_auto_generated_token_round_trips(self, auto_token_ctx: ConnectionTokenContext): + """When the SDK spawns its own CLI in TCP mode without an explicit token, + the auto-generated UUID is forwarded and the `connect` handshake succeeds.""" + response = await auto_token_ctx.client.ping("hi") + assert response.message == "pong: hi" + + async def test_wrong_token_is_rejected(self, explicit_token_ctx: ConnectionTokenContext): + """A sibling client connecting with the wrong token is rejected.""" + port = explicit_token_ctx.client.actual_port + assert port is not None + + wrong = CopilotClient( + ExternalServerConfig(url=f"localhost:{port}", tcp_connection_token="wrong") + ) + try: + with pytest.raises(Exception, match="AUTHENTICATION_FAILED"): + await wrong.start() + finally: + try: + await wrong.force_stop() + except Exception: + pass + + async def test_missing_token_is_rejected(self, explicit_token_ctx: ConnectionTokenContext): + """A sibling client with no token is rejected when the server requires one.""" + port = explicit_token_ctx.client.actual_port + assert port is not None + + no_token = CopilotClient(ExternalServerConfig(url=f"localhost:{port}")) + try: + with pytest.raises(Exception, match="AUTHENTICATION_FAILED"): + await no_token.start() + finally: + try: + await no_token.force_stop() + except Exception: + pass diff --git a/scripts/codegen/csharp.ts b/scripts/codegen/csharp.ts index 9c8332c09..48636efa8 100644 --- a/scripts/codegen/csharp.ts +++ b/scripts/codegen/csharp.ts @@ -983,6 +983,16 @@ function emitRpcClass( resolveObjectSchema(schema, rpcDefinitions) ?? resolveSchema(schema, rpcDefinitions) ?? schema; + // Visibility is driven by the JSON Schema definition itself (set via + // `.asInternal()` on the originating Zod schema). The runtime schema + // generator enforces that no public method references an internal type, + // so it's safe to upgrade callers' default to internal here. + if ( + (schema as Record).visibility === "internal" || + (effectiveSchema as Record).visibility === "internal" + ) { + visibility = "internal"; + } const schemaKey = stableStringify(effectiveSchema); const existingSchema = emittedRpcClassSchemas.get(className); if (existingSchema) { @@ -1169,13 +1179,15 @@ function emitServerInstanceMethod( groupDeprecated: boolean ): void { const methodName = toPascalCase(name); + const isInternal = method.visibility === "internal"; + const methodVisibility = isInternal ? "internal" : "public"; const resultSchema = getMethodResultSchema(method); let resultClassName = !isVoidSchema(resultSchema) ? resultTypeName(method) : ""; if (!isVoidSchema(resultSchema) && method.stability === "experimental") { experimentalRpcTypes.add(resultClassName); } if (isObjectSchema(resultSchema)) { - const resultClass = emitRpcClass(resultClassName, resultSchema!, "public", classes); + const resultClass = emitRpcClass(resultClassName, resultSchema!, methodVisibility, classes); if (resultClass) classes.push(resultClass); } else if (!isVoidSchema(resultSchema)) { resultClassName = emitNonObjectResultType(resultClassName, resultSchema!, classes); @@ -1227,7 +1239,7 @@ function emitServerInstanceMethod( sigParams.push("CancellationToken cancellationToken = default"); const taskType = !isVoidSchema(resultSchema) ? `Task<${resultClassName}>` : "Task"; - lines.push(`${indent}public async ${taskType} ${methodName}Async(${sigParams.join(", ")})`); + lines.push(`${indent}${methodVisibility} async ${taskType} ${methodName}Async(${sigParams.join(", ")})`); lines.push(`${indent}{`); if (requestClassName && bodyAssignments.length > 0) { lines.push(`${indent} var request = new ${requestClassName} { ${bodyAssignments.join(", ")} };`); @@ -1275,13 +1287,15 @@ function emitSessionRpcClasses(node: Record, classes: string[]) function emitSessionMethod(key: string, method: RpcMethod, lines: string[], classes: string[], indent: string, groupExperimental: boolean, groupDeprecated: boolean): void { const methodName = toPascalCase(key); + const isInternal = method.visibility === "internal"; + const methodVisibility = isInternal ? "internal" : "public"; const resultSchema = getMethodResultSchema(method); let resultClassName = !isVoidSchema(resultSchema) ? resultTypeName(method) : ""; if (!isVoidSchema(resultSchema) && method.stability === "experimental") { experimentalRpcTypes.add(resultClassName); } if (isObjectSchema(resultSchema)) { - const resultClass = emitRpcClass(resultClassName, resultSchema!, "public", classes); + const resultClass = emitRpcClass(resultClassName, resultSchema!, methodVisibility, classes); if (resultClass) classes.push(resultClass); } else if (!isVoidSchema(resultSchema)) { resultClassName = emitNonObjectResultType(resultClassName, resultSchema!, classes); @@ -1327,7 +1341,7 @@ function emitSessionMethod(key: string, method: RpcMethod, lines: string[], clas sigParams.push("CancellationToken cancellationToken = default"); const taskType = !isVoidSchema(resultSchema) ? `Task<${resultClassName}>` : "Task"; - lines.push(`${indent}public async ${taskType} ${methodName}Async(${sigParams.join(", ")})`); + lines.push(`${indent}${methodVisibility} async ${taskType} ${methodName}Async(${sigParams.join(", ")})`); lines.push(`${indent}{`, `${indent} var request = new ${requestClassName} { ${bodyAssignments.join(", ")} };`); if (!isVoidSchema(resultSchema)) { lines.push(`${indent} return await CopilotClient.InvokeRpcAsync<${resultClassName}>(_rpc, "${method.rpcMethod}", [request], cancellationToken);`, `${indent}}`); diff --git a/scripts/codegen/go.ts b/scripts/codegen/go.ts index c1acc4980..b3506bf29 100644 --- a/scripts/codegen/go.ts +++ b/scripts/codegen/go.ts @@ -13,6 +13,7 @@ import { FetchingJSONSchemaStore, InputData, JSONSchemaInput, quicktype } from " import { promisify } from "util"; import { cloneSchemaForCodegen, + filterNodeByVisibility, fixNullableRequiredRefsInApiSchema, getApiSchemaPath, getRpcSchemaTypeName, @@ -25,6 +26,7 @@ import { getNullableInner, isRpcMethod, postProcessSchema, + stripBooleanLiterals, writeGeneratedFile, collectDefinitionCollections, resolveObjectSchema, @@ -1084,7 +1086,7 @@ async function generateRpc(schemaPath?: string): Promise { const singleSchema: JSONSchema7 = { $schema: "http://json-schema.org/draft-07/schema#", type: "object", - definitions: allDefinitions as Record, + definitions: stripBooleanLiterals(allDefinitions) as Record, properties: Object.fromEntries( Object.keys(allDefinitions).map((name) => [name, { $ref: `#/definitions/${name}` }]) ), @@ -1160,6 +1162,21 @@ async function generateRpc(schemaPath?: string): Promise { `// Deprecated: ${typeName} is deprecated and will be removed in a future version.\n$1` ); } + + // Annotate internal data types (driven by the JSON Schema definition's + // `visibility: "internal"` flag, set via `.asInternal()` on the Zod source). + const internalTypeNames = new Set(); + for (const [name, def] of Object.entries(allDefinitions)) { + if (def && typeof def === "object" && (def as Record).visibility === "internal") { + internalTypeNames.add(name); + } + } + for (const typeName of internalTypeNames) { + qtCode = qtCode.replace( + new RegExp(`^(type ${typeName} struct)`, "m"), + `// Internal: ${typeName} is an internal SDK API and is not part of the public surface.\n$1` + ); + } // Remove trailing blank lines from quicktype output before appending qtCode = qtCode.replace(/\n+$/, ""); // Replace interface{} with any (quicktype emits the pre-1.18 form) @@ -1195,12 +1212,18 @@ async function generateRpc(schemaPath?: string): Promise { // Emit ServerRpc if (schema.server) { - emitRpcWrapper(lines, schema.server, false, resolveType, fieldNames); + const publicNode = filterNodeByVisibility(schema.server, "public"); + if (publicNode) emitRpcWrapper(lines, publicNode, false, resolveType, fieldNames, ""); + const internalNode = filterNodeByVisibility(schema.server, "internal"); + if (internalNode) emitRpcWrapper(lines, internalNode, false, resolveType, fieldNames, "Internal"); } // Emit SessionRpc if (schema.session) { - emitRpcWrapper(lines, schema.session, true, resolveType, fieldNames); + const publicNode = filterNodeByVisibility(schema.session, "public"); + if (publicNode) emitRpcWrapper(lines, publicNode, true, resolveType, fieldNames, ""); + const internalNode = filterNodeByVisibility(schema.session, "internal"); + if (internalNode) emitRpcWrapper(lines, internalNode, true, resolveType, fieldNames, "Internal"); } if (schema.clientSession) { @@ -1256,13 +1279,17 @@ function emitApiGroup( } } -function emitRpcWrapper(lines: string[], node: Record, isSession: boolean, resolveType: (name: string) => string, fieldNames: Map>): void { +function emitRpcWrapper(lines: string[], node: Record, isSession: boolean, resolveType: (name: string) => string, fieldNames: Map>, classPrefix: string = ""): void { const groups = Object.entries(node).filter(([, v]) => typeof v === "object" && v !== null && !isRpcMethod(v)); const topLevelMethods = Object.entries(node).filter(([, v]) => isRpcMethod(v)); - const wrapperName = isSession ? "SessionRpc" : "ServerRpc"; + const wrapperName = classPrefix + (isSession ? "SessionRpc" : "ServerRpc"); const apiSuffix = "Api"; - const serviceName = isSession ? "sessionApi" : "serverApi"; + // Lowercase the prefix so the unexported service struct stays unexported in Go. + const prefixLower = classPrefix ? classPrefix.charAt(0).toLowerCase() + classPrefix.slice(1) : ""; + const serviceName = prefixLower + ? prefixLower + (isSession ? "SessionApi" : "ServerApi") + : (isSession ? "sessionApi" : "serverApi"); // Emit the common service struct (unexported, shared by all API groups via type cast) lines.push(`type ${serviceName} struct {`); @@ -1273,7 +1300,7 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio // Emit API types for groups for (const [groupName, groupNode] of groups) { - const prefix = isSession ? "" : "Server"; + const prefix = classPrefix + (isSession ? "" : "Server"); const apiName = prefix + toPascalCase(groupName) + apiSuffix; const groupExperimental = isNodeFullyExperimental(groupNode as Record); const groupDeprecated = isNodeFullyDeprecated(groupNode as Record); @@ -1287,12 +1314,14 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio const pad = (name: string) => name.padEnd(maxFieldLen); // Emit wrapper struct - lines.push(`// ${wrapperName} provides typed ${isSession ? "session" : "server"}-scoped RPC methods.`); + lines.push(classPrefix === "Internal" + ? `// ${wrapperName} provides internal SDK ${isSession ? "session" : "server"}-scoped RPC methods (handshake helpers etc.). Not part of the public API.` + : `// ${wrapperName} provides typed ${isSession ? "session" : "server"}-scoped RPC methods.`); lines.push(`type ${wrapperName} struct {`); lines.push(`\t${pad("common")} ${serviceName} // Reuse a single struct instead of allocating one for each service on the heap.`); lines.push(``); for (const [groupName] of groups) { - const prefix = isSession ? "" : "Server"; + const prefix = classPrefix + (isSession ? "" : "Server"); lines.push(`\t${pad(toPascalCase(groupName))} *${prefix}${toPascalCase(groupName)}${apiSuffix}`); } lines.push(`}`); @@ -1314,7 +1343,7 @@ function emitRpcWrapper(lines: string[], node: Record, isSessio lines.push(`\tr.common = ${serviceName}{client: client}`); } for (const [groupName] of groups) { - const prefix = isSession ? "" : "Server"; + const prefix = classPrefix + (isSession ? "" : "Server"); lines.push(`\tr.${toPascalCase(groupName)} = (*${prefix}${toPascalCase(groupName)}${apiSuffix})(&r.common)`); } lines.push(`\treturn r`); @@ -1347,6 +1376,9 @@ function emitMethod(lines: string[], receiver: string, name: string, method: Rpc if (method.stability === "experimental" && !groupExperimental) { lines.push(`// Experimental: ${methodName} is an experimental API and may change or be removed in future versions.`); } + if (method.visibility === "internal") { + lines.push(`// Internal: ${methodName} is part of the SDK's internal handshake/plumbing; external callers should not use it.`); + } const sig = hasParams ? `func (a *${receiver}) ${methodName}(ctx context.Context, params *${paramsType}) (*${resultType}, error)` : `func (a *${receiver}) ${methodName}(ctx context.Context) (*${resultType}, error)`; diff --git a/scripts/codegen/python.ts b/scripts/codegen/python.ts index 6a3fe3b7d..8d59b349f 100644 --- a/scripts/codegen/python.ts +++ b/scripts/codegen/python.ts @@ -12,6 +12,7 @@ import type { JSONSchema7 } from "json-schema"; import { fileURLToPath } from "url"; import { cloneSchemaForCodegen, + filterNodeByVisibility, fixNullableRequiredRefsInApiSchema, getApiSchemaPath, getRpcSchemaTypeName, @@ -24,6 +25,7 @@ import { isNodeFullyDeprecated, isSchemaDeprecated, postProcessSchema, + stripBooleanLiterals, writeGeneratedFile, collectDefinitionCollections, hasSchemaPayload, @@ -1652,7 +1654,7 @@ async function generateRpc(schemaPath?: string): Promise { const singleSchema: Record = { $schema: "http://json-schema.org/draft-07/schema#", type: "object", - definitions: allDefinitions, + definitions: stripBooleanLiterals(allDefinitions), properties: Object.fromEntries( Object.keys(allDefinitions).map((name) => [name, { $ref: `#/definitions/${name}` }]) ), @@ -1749,6 +1751,21 @@ async function generateRpc(schemaPath?: string): Promise { ); } + // Annotate internal data types (driven by the JSON Schema definition's + // `visibility: "internal"` flag, set via `.asInternal()` on the Zod source). + const internalTypeNames = new Set(); + for (const [name, def] of Object.entries(allDefinitions)) { + if (def && typeof def === "object" && (def as Record).visibility === "internal") { + internalTypeNames.add(name); + } + } + for (const typeName of internalTypeNames) { + typesCode = typesCode.replace( + new RegExp(`^(@dataclass\\n)?class ${typeName}[:(]`, "m"), + (match) => `# Internal: this type is an internal SDK API and is not part of the public surface.\n${match}` + ); + } + // Extract actual class names generated by quicktype (may differ from toPascalCase, // e.g. quicktype produces "SessionMCPList" not "SessionMcpList") const actualTypeNames = new Map(); @@ -1816,10 +1833,16 @@ def _patch_model_capabilities(data: dict) -> dict: // Emit RPC wrapper classes if (schema.server) { - emitRpcWrapper(lines, schema.server, false, resolveType); + const publicNode = filterNodeByVisibility(schema.server, "public"); + if (publicNode) emitRpcWrapper(lines, publicNode, false, resolveType, ""); + const internalNode = filterNodeByVisibility(schema.server, "internal"); + if (internalNode) emitRpcWrapper(lines, internalNode, false, resolveType, "_Internal"); } if (schema.session) { - emitRpcWrapper(lines, schema.session, true, resolveType); + const publicNode = filterNodeByVisibility(schema.session, "public"); + if (publicNode) emitRpcWrapper(lines, publicNode, true, resolveType, ""); + const internalNode = filterNodeByVisibility(schema.session, "internal"); + if (internalNode) emitRpcWrapper(lines, internalNode, true, resolveType, "_Internal"); } if (schema.clientSession) { emitClientSessionApiRegistration(lines, schema.clientSession, resolveType); @@ -1850,7 +1873,8 @@ function emitPyApiGroup( isSession: boolean, resolveType: (name: string) => string, groupExperimental: boolean, - groupDeprecated: boolean = false + groupDeprecated: boolean = false, + classPrefix: string = "" ): void { const subGroups = Object.entries(node).filter(([, v]) => typeof v === "object" && v !== null && !isRpcMethod(v)); @@ -1859,7 +1883,7 @@ function emitPyApiGroup( const subApiName = apiName.replace(/Api$/, "") + toPascalCase(subGroupName) + "Api"; const subGroupExperimental = isNodeFullyExperimental(subGroupNode as Record); const subGroupDeprecated = isNodeFullyDeprecated(subGroupNode as Record); - emitPyApiGroup(lines, subApiName, subGroupNode as Record, isSession, resolveType, subGroupExperimental, subGroupDeprecated); + emitPyApiGroup(lines, subApiName, subGroupNode as Record, isSession, resolveType, subGroupExperimental, subGroupDeprecated, classPrefix); } // Emit this class @@ -1895,38 +1919,43 @@ function emitPyApiGroup( lines.push(``); } -function emitRpcWrapper(lines: string[], node: Record, isSession: boolean, resolveType: (name: string) => string): void { +function emitRpcWrapper(lines: string[], node: Record, isSession: boolean, resolveType: (name: string) => string, classPrefix: string = ""): void { const groups = Object.entries(node).filter(([, v]) => typeof v === "object" && v !== null && !isRpcMethod(v)); const topLevelMethods = Object.entries(node).filter(([, v]) => isRpcMethod(v)); - const wrapperName = isSession ? "SessionRpc" : "ServerRpc"; + const wrapperName = classPrefix + (isSession ? "SessionRpc" : "ServerRpc"); // Emit API classes for groups (recursively handles sub-groups) for (const [groupName, groupNode] of groups) { - const prefix = isSession ? "" : "Server"; + const prefix = classPrefix + (isSession ? "" : "Server"); const apiName = prefix + toPascalCase(groupName) + "Api"; const groupExperimental = isNodeFullyExperimental(groupNode as Record); const groupDeprecated = isNodeFullyDeprecated(groupNode as Record); - emitPyApiGroup(lines, apiName, groupNode as Record, isSession, resolveType, groupExperimental, groupDeprecated); + emitPyApiGroup(lines, apiName, groupNode as Record, isSession, resolveType, groupExperimental, groupDeprecated, classPrefix); } // Emit wrapper class if (isSession) { lines.push(`class ${wrapperName}:`); - lines.push(` """Typed session-scoped RPC methods."""`); + lines.push(classPrefix === "_Internal" + ? ` """Internal SDK session-scoped RPC methods. Not part of the public API."""` + : ` """Typed session-scoped RPC methods."""`); lines.push(` def __init__(self, client: "JsonRpcClient", session_id: str):`); lines.push(` self._client = client`); lines.push(` self._session_id = session_id`); for (const [groupName] of groups) { - lines.push(` self.${toSnakeCase(groupName)} = ${toPascalCase(groupName)}Api(client, session_id)`); + const prefix = classPrefix + (isSession ? "" : "Server"); + lines.push(` self.${toSnakeCase(groupName)} = ${prefix}${toPascalCase(groupName)}Api(client, session_id)`); } } else { lines.push(`class ${wrapperName}:`); - lines.push(` """Typed server-scoped RPC methods."""`); + lines.push(classPrefix === "_Internal" + ? ` """Internal SDK server-scoped RPC methods (handshake helpers etc.). Not part of the public API."""` + : ` """Typed server-scoped RPC methods."""`); lines.push(` def __init__(self, client: "JsonRpcClient"):`); lines.push(` self._client = client`); for (const [groupName] of groups) { - lines.push(` self.${toSnakeCase(groupName)} = Server${toPascalCase(groupName)}Api(client)`); + lines.push(` self.${toSnakeCase(groupName)} = ${classPrefix}Server${toPascalCase(groupName)}Api(client)`); } } lines.push(``); @@ -1980,6 +2009,9 @@ function emitMethod(lines: string[], name: string, method: RpcMethod, isSession: if (method.stability === "experimental" && !groupExperimental) { lines.push(` """.. warning:: This API is experimental and may change or be removed in future versions."""`); } + if (method.visibility === "internal") { + lines.push(` """:meta private: Internal SDK API; not part of the public surface."""`); + } // Deserialize helper const innerTypeName = hasNullableResult ? resolveType(pythonResultTypeName(method, nullableInner)) : resultType; diff --git a/scripts/codegen/typescript.ts b/scripts/codegen/typescript.ts index d032c34fd..5fdb829ee 100644 --- a/scripts/codegen/typescript.ts +++ b/scripts/codegen/typescript.ts @@ -338,6 +338,16 @@ import type { MessageConnection } from "vscode-jsonrpc/node.js"; const experimentalTypes = new Set(); // Track which type names come from deprecated methods for JSDoc annotations. const deprecatedTypes = new Set(); + // Types are tagged @internal directly via `visibility: "internal"` on the JSON Schema + // definition (set by `.asInternal()` on the originating Zod schema). The runtime + // schema generator enforces that no public method references an internal type, so + // there's no transitive propagation to do here. + const internalTypes = new Set(); + for (const [name, def] of Object.entries(combinedSchema.definitions ?? {})) { + if (def && typeof def === "object" && (def as Record).visibility === "internal") { + internalTypes.add(name); + } + } for (const method of [...allMethods, ...clientSessionMethods]) { const resultSchema = getMethodResultSchema(method); @@ -425,29 +435,75 @@ import type { MessageConnection } from "vscode-jsonrpc/node.js"; `$1/** @deprecated */\n$2` ); } + // Add @internal JSDoc annotations for types from internal methods + for (const intType of internalTypes) { + annotatedTs = annotatedTs.replace( + new RegExp(`(^|\\n)(export (?:interface|type) ${intType}\\b)`, "m"), + `$1/** @internal */\n$2` + ); + } lines.push(annotatedTs); lines.push(""); } // Generate factory functions +function hasInternalMethods(node: Record): boolean { + for (const value of Object.values(node)) { + if (isRpcMethod(value)) { + if ((value as RpcMethod).visibility === "internal") return true; + } else if (typeof value === "object" && value !== null) { + if (hasInternalMethods(value as Record)) return true; + } + } + return false; +} + if (schema.server) { lines.push(`/** Create typed server-scoped RPC methods (no session required). */`); lines.push(`export function createServerRpc(connection: MessageConnection) {`); lines.push(` return {`); - lines.push(...emitGroup(schema.server, " ", false)); + lines.push(...emitGroup(schema.server, " ", false, false, false, "public")); lines.push(` };`); lines.push(`}`); lines.push(""); + + if (hasInternalMethods(schema.server)) { + lines.push(`/**`); + lines.push(` * Create typed server-scoped RPC methods that are part of the SDK's internal`); + lines.push(` * surface (e.g. handshake helpers). Not exported on the public client API.`); + lines.push(` * @internal`); + lines.push(` */`); + lines.push(`export function createInternalServerRpc(connection: MessageConnection) {`); + lines.push(` return {`); + lines.push(...emitGroup(schema.server, " ", false, false, false, "internal")); + lines.push(` };`); + lines.push(`}`); + lines.push(""); + } } if (schema.session) { lines.push(`/** Create typed session-scoped RPC methods. */`); lines.push(`export function createSessionRpc(connection: MessageConnection, sessionId: string) {`); lines.push(` return {`); - lines.push(...emitGroup(schema.session, " ", true)); + lines.push(...emitGroup(schema.session, " ", true, false, false, "public")); lines.push(` };`); lines.push(`}`); lines.push(""); + + if (hasInternalMethods(schema.session)) { + lines.push(`/**`); + lines.push(` * Create typed session-scoped RPC methods that are part of the SDK's internal`); + lines.push(` * surface. Not exported on the public client API.`); + lines.push(` * @internal`); + lines.push(` */`); + lines.push(`export function createInternalSessionRpc(connection: MessageConnection, sessionId: string) {`); + lines.push(` return {`); + lines.push(...emitGroup(schema.session, " ", true, false, false, "internal")); + lines.push(` };`); + lines.push(`}`); + lines.push(""); + } } // Generate client session API handler interfaces and registration function @@ -459,10 +515,20 @@ import type { MessageConnection } from "vscode-jsonrpc/node.js"; console.log(` ✓ ${outPath}`); } -function emitGroup(node: Record, indent: string, isSession: boolean, parentExperimental = false, parentDeprecated = false): string[] { +function emitGroup( + node: Record, + indent: string, + isSession: boolean, + parentExperimental = false, + parentDeprecated = false, + visibilityFilter?: "public" | "internal", +): string[] { const lines: string[] = []; for (const [key, value] of Object.entries(node)) { if (isRpcMethod(value)) { + const isInternalMethod = (value as RpcMethod).visibility === "internal"; + if (visibilityFilter === "public" && isInternalMethod) continue; + if (visibilityFilter === "internal" && !isInternalMethod) continue; const { rpcMethod, params } = value; const resultType = tsResultType(value); const paramsType = paramsTypeName(value); @@ -508,6 +574,16 @@ function emitGroup(node: Record, indent: string, isSession: boo } else if (typeof value === "object" && value !== null) { const groupExperimental = isNodeFullyExperimental(value as Record); const groupDeprecated = isNodeFullyDeprecated(value as Record); + const childLines = emitGroup( + value as Record, + indent + " ", + isSession, + groupExperimental, + groupDeprecated, + visibilityFilter, + ); + // Skip the wrapper if the visibility filter dropped every method in this subtree. + if (childLines.length === 0) continue; if (groupDeprecated) { lines.push(`${indent}/** @deprecated */`); } @@ -515,7 +591,7 @@ function emitGroup(node: Record, indent: string, isSession: boo lines.push(`${indent}/** @experimental */`); } lines.push(`${indent}${key}: {`); - lines.push(...emitGroup(value as Record, indent + " ", isSession, groupExperimental, groupDeprecated)); + lines.push(...childLines); lines.push(`${indent}},`); } } diff --git a/scripts/codegen/utils.ts b/scripts/codegen/utils.ts index 4a4c31f3f..7279abee0 100644 --- a/scripts/codegen/utils.ts +++ b/scripts/codegen/utils.ts @@ -128,6 +128,38 @@ export function postProcessSchema(schema: JSONSchema7): JSONSchema7 { return processed; } +/** + * Strip boolean literal constraints (`const: true/false`, `enum: [true]`, `enum: [false]`) + * from a schema, recursively. quicktype's Python and Go renderers attempt to derive + * identifier names from enum values; deriving a name from a boolean throws inside + * `snakeNameStyle` (TypeError: s.codePointAt is not a function). + * + * The literal narrowing isn't expressible in Python/Go anyway, so we drop it and + * keep just `type: "boolean"`. TypeScript/C# codegen runs on the original schema. + */ +export function stripBooleanLiterals(schema: T): T { + if (typeof schema !== "object" || schema === null) return schema; + if (Array.isArray(schema)) { + return schema.map((item) => stripBooleanLiterals(item)) as unknown as T; + } + const result: Record = {}; + const src = schema as unknown as Record; + const isBooleanType = src.type === "boolean"; + for (const [key, value] of Object.entries(src)) { + if (isBooleanType && key === "const" && typeof value === "boolean") continue; + if ( + isBooleanType && + key === "enum" && + Array.isArray(value) && + value.every((v) => typeof v === "boolean") + ) { + continue; + } + result[key] = stripBooleanLiterals(value); + } + return result as T; +} + /** * Normalize schema defects where a required property with a `$ref` to an object type * has a description explicitly mentioning "null" as a valid value. @@ -216,6 +248,7 @@ export interface RpcMethod { params: JSONSchema7 | null; result: JSONSchema7 | null; stability?: string; + visibility?: string; deprecated?: boolean; } @@ -374,6 +407,33 @@ export function isNodeFullyDeprecated(node: Record): boolean { return methods.length > 0 && methods.every(m => m.deprecated === true); } +/** + * Returns a filtered copy of an API tree containing only methods whose visibility + * matches `keep`. Sub-groups that end up empty are pruned. Returns null if nothing + * survives the filter. + * + * `"public"` keeps methods without `visibility === "internal"`. + * `"internal"` keeps methods with `visibility === "internal"`. + */ +export function filterNodeByVisibility( + node: Record, + keep: "public" | "internal", +): Record | null { + const result: Record = {}; + for (const [key, value] of Object.entries(node)) { + if (isRpcMethod(value)) { + const isInternal = (value as RpcMethod).visibility === "internal"; + if (keep === "public" && isInternal) continue; + if (keep === "internal" && !isInternal) continue; + result[key] = value; + } else if (typeof value === "object" && value !== null) { + const sub = filterNodeByVisibility(value as Record, keep); + if (sub) result[key] = sub; + } + } + return Object.keys(result).length === 0 ? null : result; +} + /** Returns true when a JSON Schema node is marked as deprecated. */ export function isSchemaDeprecated(schema: JSONSchema7 | null | undefined): boolean { return typeof schema === "object" && schema !== null && (schema as Record).deprecated === true;