diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 5dd8581f1..e75de7c93 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -7,11 +7,13 @@ ### Breaking Changes ### Bug Fixes +* Fixed Databricks CLI `--profile` fallback by detecting the CLI version at init time. The previous error-based detection was broken because `--profile` is a global Cobra flag silently accepted by old CLIs. ### Security Vulnerabilities ### Documentation ### Internal Changes +* Detect Databricks CLI version at init time via `databricks version --output json`, enabling version-gated flag support. Successful detections are cached per CLI path; subprocess failures fall back to the most conservative command and are retried on the next call. ### API Changes diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 7855b73c7..3e0dfc298 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -30,11 +30,6 @@ public class CliTokenSource implements TokenSource { private String accessTokenField; private String expiryField; private Environment env; - // fallbackCmd is tried when the primary command fails with "unknown flag: --profile", - // indicating the CLI is too old to support --profile. Can be removed once support - // for CLI versions predating --profile is dropped. - // See: https://github.com/databricks/databricks-sdk-go/pull/1497 - private List fallbackCmd; /** * Internal exception that carries the clean stderr message but exposes full output for checks. 
@@ -58,24 +53,11 @@ public CliTokenSource( String accessTokenField, String expiryField, Environment env) { - this(cmd, tokenTypeField, accessTokenField, expiryField, env, null); - } - - public CliTokenSource( - List cmd, - String tokenTypeField, - String accessTokenField, - String expiryField, - Environment env, - List fallbackCmd) { - super(); this.cmd = OSUtils.get(env).getCliExecutableCommand(cmd); this.tokenTypeField = tokenTypeField; this.accessTokenField = accessTokenField; this.expiryField = expiryField; this.env = env; - this.fallbackCmd = - fallbackCmd != null ? OSUtils.get(env).getCliExecutableCommand(fallbackCmd) : null; } /** @@ -158,22 +140,6 @@ public Token getToken() { try { return execCliCommand(this.cmd); } catch (IOException e) { - String textToCheck = - e instanceof CliCommandException - ? ((CliCommandException) e).getFullOutput() - : e.getMessage(); - if (fallbackCmd != null - && textToCheck != null - && textToCheck.contains("unknown flag: --profile")) { - LOG.warn( - "Databricks CLI does not support --profile flag. Falling back to --host. 
" - + "Please upgrade your CLI to the latest version."); - try { - return execCliCommand(this.fallbackCmd); - } catch (IOException fallbackException) { - throw new DatabricksException(fallbackException.getMessage(), fallbackException); - } - } throw new DatabricksException(e.getMessage(), e); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java index ae401280d..daf1f21e2 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java @@ -6,12 +6,19 @@ import com.databricks.sdk.core.oauth.OAuthHeaderFactory; import com.databricks.sdk.core.oauth.Token; import com.databricks.sdk.core.oauth.TokenSource; +import com.databricks.sdk.core.utils.Environment; import com.databricks.sdk.core.utils.OSUtils; import com.databricks.sdk.support.InternalApi; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import org.apache.commons.io.IOUtils; @InternalApi public class DatabricksCliCredentialsProvider implements CredentialsProvider { @@ -22,6 +29,28 @@ public class DatabricksCliCredentialsProvider implements CredentialsProvider { private static final ObjectMapper MAPPER = new ObjectMapper(); + // ---- Version detection ---- + + // --profile support added in CLI v0.207.1: https://github.com/databricks/cli/pull/855 + static final DatabricksCliVersion CLI_VERSION_FOR_PROFILE = new DatabricksCliVersion(0, 207, 1); + + // 5-second cap on `databricks version` so a hung CLI 
(slow first-run scan, antivirus, blocked + // stdin) does not wedge SDK init indefinitely. + private static final long VERSION_PROBE_TIMEOUT_SECONDS = 5; + + // Successful version probes keyed by cliPath. Subprocess failures (timeouts, non-zero exit, + // IO errors) and probes that returned UNKNOWN due to unparseable output are deliberately not + // cached, so a transient error does not pin every later token source to the conservative + // fallback for the rest of the process lifetime. + private static final Map VERSION_CACHE = new ConcurrentHashMap<>(); + + /** Test-only hook to clear the cross-test version cache. Package-private. */ + static void clearVersionCache() { + VERSION_CACHE.clear(); + } + + // ---- Scope validation ---- + /** Thrown when the cached CLI token's scopes don't match the SDK's configured scopes. */ static class ScopeMismatchException extends DatabricksException { ScopeMismatchException(String message) { @@ -36,59 +65,13 @@ static class ScopeMismatchException extends DatabricksException { private static final Set SCOPES_IGNORED_FOR_COMPARISON = Collections.singleton("offline_access"); + // ---- Public API ---- + @Override public String authType() { return DATABRICKS_CLI; } - /** - * Builds the CLI command arguments using --host (legacy path). - * - * @param cliPath Path to the databricks CLI executable - * @param config Configuration containing host, account ID, workspace ID, etc. 
- * @return List of command arguments - */ - List buildHostArgs(String cliPath, DatabricksConfig config) { - List cmd = - new ArrayList<>(Arrays.asList(cliPath, "auth", "token", "--host", config.getHost())); - if (config.getClientType() == ClientType.ACCOUNT) { - cmd.add("--account-id"); - cmd.add(config.getAccountId()); - } - return cmd; - } - - private CliTokenSource getDatabricksCliTokenSource(DatabricksConfig config) { - String cliPath = config.getDatabricksCliPath(); - if (cliPath == null) { - cliPath = OSUtils.get(config.getEnv()).getDatabricksCliPath(); - } - if (cliPath == null) { - LOG.debug("Databricks CLI could not be found"); - return null; - } - - List cmd; - List fallbackCmd = null; - - if (config.getProfile() != null) { - // When profile is set, use --profile as the primary command. - // The profile contains the full config (host, account_id, etc.). - cmd = - new ArrayList<>( - Arrays.asList(cliPath, "auth", "token", "--profile", config.getProfile())); - // Build a --host fallback for older CLIs that don't support --profile. 
- if (config.getHost() != null) { - fallbackCmd = buildHostArgs(cliPath, config); - } - } else { - cmd = buildHostArgs(cliPath, config); - } - - return new CliTokenSource( - cmd, "token_type", "access_token", "expiry", config.getEnv(), fallbackCmd); - } - @Override public OAuthHeaderFactory configure(DatabricksConfig config) { String host = config.getHost(); @@ -151,6 +134,227 @@ public Token getToken() { } } + // ---- Token source construction ---- + + private CliTokenSource getDatabricksCliTokenSource(DatabricksConfig config) { + String cliPath = config.getDatabricksCliPath(); + if (cliPath == null) { + cliPath = OSUtils.get(config.getEnv()).getDatabricksCliPath(); + } + if (cliPath == null) { + LOG.debug("Databricks CLI could not be found"); + return null; + } + + List cmd = resolveCliCommand(cliPath, config); + return new CliTokenSource(cmd, "token_type", "access_token", "expiry", config.getEnv()); + } + + /** + * Detects the installed CLI version and builds the {@code auth token} command. Falls back to the + * most conservative command when version detection fails. + */ + List resolveCliCommand(String cliPath, DatabricksConfig config) { + DatabricksCliVersion version = getCliVersion(cliPath, config.getEnv()); + return buildCliCommand(cliPath, config, version); + } + + /** + * Builds the {@code auth token} command for the given CLI version. + * + *

Falls back to {@code --host} when {@code --profile} is either not configured or not + * supported by the installed CLI. + */ + List buildCliCommand( + String cliPath, DatabricksConfig config, DatabricksCliVersion version) { + if (config.getProfile() == null) { + return buildHostArgs(cliPath, config); + } + + // Flag --profile is a global CLI flag and is recognized for all commands even the ones that + // do not support it. Only use --profile in CLI versions known to support it in `auth token`. + if (!version.atLeast(CLI_VERSION_FOR_PROFILE)) { + if (version.isDefaultDevBuild()) { + // A default-marker dev build has no injected version, so every feature gate fails. + // Surface an informational hint so users know why their feature flags aren't taking + // effect. + LOG.info( + "Databricks CLI {} is a development build; feature detection will use conservative " + + "fallbacks. Rebuild the CLI with an explicit version to enable capability-based " + + "flag selection.", + version); + } else if (version.equals(DatabricksCliVersion.UNKNOWN)) { + LOG.warn( + "Could not confirm --profile support for Databricks CLI {} (requires >= {}). " + + "Falling back to --host.", + version, + CLI_VERSION_FOR_PROFILE); + } else { + LOG.warn( + "Databricks CLI {} does not support --profile (requires >= {}). " + + "Falling back to --host.", + version, + CLI_VERSION_FOR_PROFILE); + } + return buildHostArgs(cliPath, config); + } + + return new ArrayList<>( + Arrays.asList(cliPath, "auth", "token", "--profile", config.getProfile())); + } + + /** + * Builds the CLI command arguments using --host (legacy path). + * + * @param cliPath Path to the databricks CLI executable + * @param config Configuration containing host, account ID, workspace ID, etc. 
+ * @return List of command arguments + */ + List buildHostArgs(String cliPath, DatabricksConfig config) { + if (config.getHost() == null) { + // Without this guard a null host would silently produce ["--host", null] and surface as + // an obscure NPE deep inside ProcessBuilder.start(). The production path is gated by + // configure()'s early return, but a future caller (or a direct unit test) could bypass it. + throw new DatabricksException( + "Cannot build Databricks CLI auth command: config.host is required"); + } + List cmd = + new ArrayList<>(Arrays.asList(cliPath, "auth", "token", "--host", config.getHost())); + if (config.getClientType() == ClientType.ACCOUNT) { + cmd.add("--account-id"); + cmd.add(config.getAccountId()); + } + return cmd; + } + + // ---- Version detection ---- + + /** + * Returns the CLI version, catching subprocess failures so the caller can proceed with the + * conservative fallback. Successful results are cached per {@code cliPath} for the process + * lifetime; failures are not cached and will be retried on the next call. + */ + DatabricksCliVersion getCliVersion(String cliPath, Environment env) { + DatabricksCliVersion cached = VERSION_CACHE.get(cliPath); + if (cached != null) { + return cached; + } + + try { + DatabricksCliVersion version = probeCliVersion(cliPath, env); + // Don't cache UNKNOWN: a parseable-but-malformed payload (e.g. a transiently corrupt + // CLI response) would otherwise pin every later token source to the conservative + // fallback for the rest of the process lifetime. + if (!version.equals(DatabricksCliVersion.UNKNOWN)) { + VERSION_CACHE.put(cliPath, version); + } + return version; + } catch (IOException e) { + LOG.warn( + "Failed to detect Databricks CLI version: {}. Falling back to conservative flag set.", + e.getMessage(), + e); + return DatabricksCliVersion.UNKNOWN; + } + } + + /** + * Runs {@code databricks version --output json} and returns the parsed {@link + * DatabricksCliVersion}. + * + *

Reads stdout after {@code waitFor} returns. This is safe only because the {@code version} + * subcommand emits a fixed-shape JSON blob (~200 bytes) that fits well within any platform's pipe + * buffer (>= 4 KB on Windows, ~64 KB on Linux/macOS). Do not reuse this method for + * unbounded-output subcommands — the pattern would deadlock if the child ever filled the pipe. + */ + DatabricksCliVersion probeCliVersion(String cliPath, Environment env) throws IOException { + List versionArgs = Arrays.asList(cliPath, "version", "--output", "json"); + List cmd = OSUtils.get(env).getCliExecutableCommand(versionArgs); + + ProcessBuilder pb = new ProcessBuilder(cmd); + pb.environment().putAll(env.getEnv()); + // Merge stderr into stdout so we drain a single stream and surface any stderr diagnostics + // in the same message on non-zero exit. + pb.redirectErrorStream(true); + Process process = pb.start(); + + try { + if (!process.waitFor(VERSION_PROBE_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { + throw new IOException( + "timed out after " + + VERSION_PROBE_TIMEOUT_SECONDS + + "s waiting for `databricks version`"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("interrupted waiting for `databricks version`", e); + } finally { + // No-op when the process exited cleanly; cleans up timeout/interrupt paths. + // destroyForcibly() is asynchronous on some platforms; await briefly so the OS releases + // the process handle and FDs before we return. + try { + process.destroyForcibly().waitFor(1, TimeUnit.SECONDS); + } catch (InterruptedException ignored) { + Thread.currentThread().interrupt(); + } + } + + String output = readStream(process.getInputStream()); + if (process.exitValue() != 0) { + throw new IOException( + "`databricks version` exited with code " + process.exitValue() + ": " + output.trim()); + } + return parseCliVersion(output); + } + + /** + * Parses the JSON output of {@code databricks version --output json}. + * + *

Takes Major/Minor/Patch from the JSON's pre-parsed numeric fields. The Prerelease field and + * the Version string are intentionally ignored: for our feature-gate purposes the base triple is + * sufficient, and the (0, 0, 0) case already identifies the default dev build (a CLI built + * without version metadata leaves these fields at their zero defaults). + * + *

Returns {@link DatabricksCliVersion#UNKNOWN} on failure so that an unparseable version + * disables every feature gate. + */ + DatabricksCliVersion parseCliVersion(String output) { + try { + JsonNode node = MAPPER.readTree(output); + JsonNode major = node.get("Major"); + JsonNode minor = node.get("Minor"); + JsonNode patch = node.get("Patch"); + if (major == null || minor == null || patch == null) { + LOG.debug( + "Failed to parse Databricks CLI version: missing Major/Minor/Patch in {}", output); + return DatabricksCliVersion.UNKNOWN; + } + // JsonNode.asInt() silently coerces strings, JSON null, arrays, and objects to 0, which + // would collide with the dev-build sentinel (0,0,0). Only accept genuine integers so a + // garbage payload returns UNKNOWN instead of "valid dev build". + if (!major.isIntegralNumber() || !minor.isIntegralNumber() || !patch.isIntegralNumber()) { + LOG.debug( + "Failed to parse Databricks CLI version: non-integer Major/Minor/Patch in {}", output); + return DatabricksCliVersion.UNKNOWN; + } + return new DatabricksCliVersion(major.asInt(), minor.asInt(), patch.asInt()); + } catch (JsonProcessingException e) { + LOG.debug( + "Failed to parse Databricks CLI version from output: {} ({})", output, e.getMessage()); + return DatabricksCliVersion.UNKNOWN; + } + } + + private static String readStream(InputStream stream) throws IOException { + try { + return new String(IOUtils.toByteArray(stream), StandardCharsets.UTF_8); + } finally { + stream.close(); + } + } + + // ---- Scope validation ---- + /** * Validate that the token's scopes match the requested scopes from the config. 
* diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliVersion.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliVersion.java new file mode 100644 index 000000000..6d1bdc03d --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliVersion.java @@ -0,0 +1,99 @@ +package com.databricks.sdk.core; + +import com.databricks.sdk.support.InternalApi; +import java.util.Objects; + +/** + * Semver version triple of the Databricks CLI used for capability gating. + * + *

Three cases for the (major, minor, patch) tuple, two of which are sentinels: + * + *

    + *
  • {@code (-1, -1, -1)} — the {@link #UNKNOWN} sentinel, meaning version detection failed. It + * compares less than every real release so every feature gate fails. + *
  • {@code (0, 0, 0)} — the CLI's default dev build, emitted when the binary was built without + * version metadata. See {@link #isDefaultDevBuild()}. + *
  • anything else — a real CLI version. + *
+ * + *

Prerelease tags are deliberately ignored: feature gates are release-based, so a prerelease of + * a version with a flag is assumed to have the flag too. + */ +@InternalApi +public final class DatabricksCliVersion implements Comparable { + public static final DatabricksCliVersion UNKNOWN = new DatabricksCliVersion(-1, -1, -1); + + /** Default dev build sentinel — emitted when the CLI was built without version metadata. */ + public static final DatabricksCliVersion DEFAULT_DEV_BUILD = new DatabricksCliVersion(0, 0, 0); + + private final int major; + private final int minor; + private final int patch; + + // Package-private so tests and the credentials provider can construct concrete versions, but + // external callers cannot manufacture another instance that .equals(UNKNOWN) but isn't the + // singleton. + DatabricksCliVersion(int major, int minor, int patch) { + this.major = major; + this.minor = minor; + this.patch = patch; + } + + public int getMajor() { + return major; + } + + public int getMinor() { + return minor; + } + + public int getPatch() { + return patch; + } + + /** Returns true when {@code this} is greater than or equal to {@code other}. */ + public boolean atLeast(DatabricksCliVersion other) { + return compareTo(other) >= 0; + } + + /** + * Returns true when the version is the CLI's {@link #DEFAULT_DEV_BUILD} sentinel. A CLI built + * without version metadata leaves these fields at their zero defaults. 
+ */ + public boolean isDefaultDevBuild() { + return equals(DEFAULT_DEV_BUILD); + } + + @Override + public int compareTo(DatabricksCliVersion o) { + int c = Integer.compare(major, o.major); + if (c != 0) return c; + c = Integer.compare(minor, o.minor); + if (c != 0) return c; + return Integer.compare(patch, o.patch); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof DatabricksCliVersion)) return false; + DatabricksCliVersion that = (DatabricksCliVersion) o; + return major == that.major && minor == that.minor && patch == that.patch; + } + + @Override + public int hashCode() { + return Objects.hash(major, minor, patch); + } + + @Override + public String toString() { + if (equals(UNKNOWN)) { + return "unknown"; + } + if (isDefaultDevBuild()) { + return "v0.0.0-dev"; + } + return "v" + major + "." + minor + "." + patch; + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 8476c6de5..28d3deaf6 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -28,7 +28,6 @@ import java.util.List; import java.util.Map; import java.util.TimeZone; -import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -217,128 +216,24 @@ public void testParseExpiry(String input, Instant expectedInstant, String descri } } - // ---- Fallback tests for --profile flag handling ---- + // ---- Error propagation ---- - private CliTokenSource makeTokenSource( - Environment env, List primaryCmd, List fallbackCmd) { + private CliTokenSource makeTokenSource(Environment env, List cmd) { OSUtilities osUtils = mock(OSUtilities.class); when(osUtils.getCliExecutableCommand(any())).thenAnswer(inv -> 
inv.getArgument(0)); try (MockedStatic mockedOSUtils = mockStatic(OSUtils.class)) { mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils); - return new CliTokenSource( - primaryCmd, "token_type", "access_token", "expiry", env, fallbackCmd); - } - } - - private String validTokenJson(String accessToken) { - String expiry = - ZonedDateTime.now() - .plusHours(1) - .format(DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSSXXX")); - return String.format( - "{\"token_type\":\"Bearer\",\"access_token\":\"%s\",\"expiry\":\"%s\"}", - accessToken, expiry); - } - - @Test - public void testFallbackOnUnknownProfileFlagInStderr() { - Environment env = mock(Environment.class); - when(env.getEnv()).thenReturn(new HashMap<>()); - - List primaryCmd = - Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); - List fallbackCmdList = - Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); - - CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList); - - AtomicInteger callCount = new AtomicInteger(0); - try (MockedConstruction mocked = - mockConstruction( - ProcessBuilder.class, - (pb, context) -> { - if (callCount.getAndIncrement() == 0) { - Process failProcess = mock(Process.class); - when(failProcess.getInputStream()) - .thenReturn(new ByteArrayInputStream(new byte[0])); - when(failProcess.getErrorStream()) - .thenReturn( - new ByteArrayInputStream("Error: unknown flag: --profile".getBytes())); - when(failProcess.waitFor()).thenReturn(1); - when(pb.start()).thenReturn(failProcess); - } else { - Process successProcess = mock(Process.class); - when(successProcess.getInputStream()) - .thenReturn( - new ByteArrayInputStream(validTokenJson("fallback-token").getBytes())); - when(successProcess.getErrorStream()) - .thenReturn(new ByteArrayInputStream(new byte[0])); - when(successProcess.waitFor()).thenReturn(0); - when(pb.start()).thenReturn(successProcess); - } - })) { - Token token = 
tokenSource.getToken(); - assertEquals("fallback-token", token.getAccessToken()); - assertEquals(2, mocked.constructed().size()); + return new CliTokenSource(cmd, "token_type", "access_token", "expiry", env); } } @Test - public void testFallbackTriggeredWhenUnknownFlagInStdout() { - // Fallback triggers even when "unknown flag" appears in stdout rather than stderr. + public void testCliErrorPropagates() { Environment env = mock(Environment.class); when(env.getEnv()).thenReturn(new HashMap<>()); - List primaryCmd = - Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); - List fallbackCmdList = - Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); - - CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList); - - AtomicInteger callCount = new AtomicInteger(0); - try (MockedConstruction mocked = - mockConstruction( - ProcessBuilder.class, - (pb, context) -> { - if (callCount.getAndIncrement() == 0) { - Process failProcess = mock(Process.class); - when(failProcess.getInputStream()) - .thenReturn( - new ByteArrayInputStream("Error: unknown flag: --profile".getBytes())); - when(failProcess.getErrorStream()) - .thenReturn(new ByteArrayInputStream(new byte[0])); - when(failProcess.waitFor()).thenReturn(1); - when(pb.start()).thenReturn(failProcess); - } else { - Process successProcess = mock(Process.class); - when(successProcess.getInputStream()) - .thenReturn( - new ByteArrayInputStream(validTokenJson("fallback-token").getBytes())); - when(successProcess.getErrorStream()) - .thenReturn(new ByteArrayInputStream(new byte[0])); - when(successProcess.waitFor()).thenReturn(0); - when(pb.start()).thenReturn(successProcess); - } - })) { - Token token = tokenSource.getToken(); - assertEquals("fallback-token", token.getAccessToken()); - assertEquals(2, mocked.constructed().size()); - } - } - - @Test - public void testNoFallbackOnRealAuthError() { - // When the primary fails with a real error (not 
unknown flag), no fallback is attempted. - Environment env = mock(Environment.class); - when(env.getEnv()).thenReturn(new HashMap<>()); - - List primaryCmd = - Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); - List fallbackCmdList = - Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); - - CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList); + CliTokenSource tokenSource = + makeTokenSource(env, Arrays.asList("databricks", "auth", "token", "--host", "https://x")); try (MockedConstruction mocked = mockConstruction( @@ -358,33 +253,4 @@ public void testNoFallbackOnRealAuthError() { assertEquals(1, mocked.constructed().size()); } } - - @Test - public void testNoFallbackWhenFallbackCmdNotSet() { - // When fallbackCmd is null and the primary fails with unknown flag, original error propagates. - Environment env = mock(Environment.class); - when(env.getEnv()).thenReturn(new HashMap<>()); - - List primaryCmd = - Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); - - CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, null); - - try (MockedConstruction mocked = - mockConstruction( - ProcessBuilder.class, - (pb, context) -> { - Process failProcess = mock(Process.class); - when(failProcess.getInputStream()).thenReturn(new ByteArrayInputStream(new byte[0])); - when(failProcess.getErrorStream()) - .thenReturn( - new ByteArrayInputStream("Error: unknown flag: --profile".getBytes())); - when(failProcess.waitFor()).thenReturn(1); - when(pb.start()).thenReturn(failProcess); - })) { - DatabricksException ex = assertThrows(DatabricksException.class, tokenSource::getToken); - assertTrue(ex.getMessage().contains("unknown flag: --profile")); - assertEquals(1, mocked.constructed().size()); - } - } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java 
b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java index 0f1ca5059..817027b85 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java @@ -1,10 +1,37 @@ package com.databricks.sdk.core; import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockConstruction; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import com.databricks.sdk.core.utils.Environment; +import com.databricks.sdk.core.utils.OSUtilities; +import com.databricks.sdk.core.utils.OSUtils; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; class DatabricksCliCredentialsProviderTest { @@ -12,11 +39,19 @@ class DatabricksCliCredentialsProviderTest { private static final String HOST = "https://my-workspace.cloud.databricks.com"; private static final String ACCOUNT_HOST = 
"https://accounts.cloud.databricks.com"; private static final String ACCOUNT_ID = "test-account-123"; + private static final String PROFILE = "my-profile"; private final DatabricksCliCredentialsProvider provider = new DatabricksCliCredentialsProvider(); + @BeforeEach + void resetVersionCache() { + DatabricksCliCredentialsProvider.clearVersionCache(); + } + + // ---- buildHostArgs tests ---- + @Test - void testBuildHostArgs_WorkspaceHost() { + void testBuildHostCommand_WorkspaceHost() { DatabricksConfig config = new DatabricksConfig().setHost(HOST); List cmd = provider.buildHostArgs(CLI_PATH, config); @@ -25,7 +60,7 @@ void testBuildHostArgs_WorkspaceHost() { } @Test - void testBuildHostArgs_AccountHost() { + void testBuildHostCommand_AccountHost() { DatabricksConfig config = new DatabricksConfig().setHost(ACCOUNT_HOST).setAccountId(ACCOUNT_ID); List cmd = provider.buildHostArgs(CLI_PATH, config); @@ -37,7 +72,7 @@ void testBuildHostArgs_AccountHost() { } @Test - void testBuildHostArgs_NonAccountsHostWithAccountId() { + void testBuildHostCommand_NonAccountsHostWithAccountId() { // Non-accounts hosts should not pass --account-id even if accountId is set DatabricksConfig config = new DatabricksConfig().setHost(HOST).setAccountId(ACCOUNT_ID); @@ -45,4 +80,311 @@ void testBuildHostArgs_NonAccountsHostWithAccountId() { assertEquals(Arrays.asList(CLI_PATH, "auth", "token", "--host", HOST), cmd); } + + @Test + void testBuildHostCommand_NullHost_ThrowsClearError() { + DatabricksConfig config = new DatabricksConfig(); // no host + + DatabricksException ex = + assertThrows(DatabricksException.class, () -> provider.buildHostArgs(CLI_PATH, config)); + assertTrue(ex.getMessage().contains("host is required"), ex.getMessage()); + } + + @Test + void testBuildCliCommand_ProfileWithNullHost_ThrowsClearError() { + // profile + null host + old CLI → would fall through to buildHostArgs and emit ["--host", + // null]. 
The buildHostArgs guard turns that into a clear DatabricksException instead. + DatabricksConfig config = new DatabricksConfig().setProfile(PROFILE); + + assertThrows( + DatabricksException.class, + () -> provider.buildCliCommand(CLI_PATH, config, new DatabricksCliVersion(0, 207, 0))); + } + + // ---- buildCliCommand tests ---- + + private static Stream buildCliCommandCases() { + return Stream.of( + Arguments.of( + "host only — old CLI", + new DatabricksConfig().setHost(HOST), + new DatabricksCliVersion(0, 200, 0), + Arrays.asList(CLI_PATH, "auth", "token", "--host", HOST)), + Arguments.of( + "account host — old CLI", + new DatabricksConfig().setHost(ACCOUNT_HOST).setAccountId(ACCOUNT_ID), + new DatabricksCliVersion(0, 200, 0), + Arrays.asList( + CLI_PATH, "auth", "token", "--host", ACCOUNT_HOST, "--account-id", ACCOUNT_ID)), + Arguments.of( + "profile with new CLI — uses --profile", + new DatabricksConfig().setProfile(PROFILE).setHost(HOST), + DatabricksCliCredentialsProvider.CLI_VERSION_FOR_PROFILE, + Arrays.asList(CLI_PATH, "auth", "token", "--profile", PROFILE)), + Arguments.of( + "profile with old CLI — falls back to --host", + new DatabricksConfig().setProfile(PROFILE).setHost(HOST), + new DatabricksCliVersion(0, 207, 0), + Arrays.asList(CLI_PATH, "auth", "token", "--host", HOST)), + Arguments.of( + "unknown version — falls back to --host", + new DatabricksConfig().setProfile(PROFILE).setHost(HOST), + DatabricksCliVersion.UNKNOWN, + Arrays.asList(CLI_PATH, "auth", "token", "--host", HOST)), + Arguments.of( + "dev build — falls back to --host", + new DatabricksConfig().setProfile(PROFILE).setHost(HOST), + new DatabricksCliVersion(0, 0, 0), + Arrays.asList(CLI_PATH, "auth", "token", "--host", HOST))); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("buildCliCommandCases") + void testBuildCliCommand( + String name, DatabricksConfig config, DatabricksCliVersion version, List expected) { + assertEquals(expected, provider.buildCliCommand(CLI_PATH, config, 
version)); + } + + // ---- parseCliVersion tests ---- + + @Test + void testParseCliVersion_StandardOutput() { + String json = + "{\"Version\":\"v0.295.0\",\"Major\":0,\"Minor\":295,\"Patch\":0,\"Prerelease\":\"\",\"BuildMetadata\":\"\"}"; + assertEquals(new DatabricksCliVersion(0, 295, 0), provider.parseCliVersion(json)); + } + + @Test + void testParseCliVersion_ProfileVersion() { + String json = "{\"Version\":\"v0.207.1\",\"Major\":0,\"Minor\":207,\"Patch\":1}"; + assertEquals(new DatabricksCliVersion(0, 207, 1), provider.parseCliVersion(json)); + } + + @Test + void testParseCliVersion_DevBuild() { + String json = + "{\"Version\":\"v0.0.0-dev+abc123\",\"Major\":0,\"Minor\":0,\"Patch\":0,\"Prerelease\":\"dev\"}"; + assertEquals(new DatabricksCliVersion(0, 0, 0), provider.parseCliVersion(json)); + } + + @Test + void testParseCliVersion_MissingFields() { + String json = "{\"Version\":\"v0.295.0\"}"; + assertEquals(DatabricksCliVersion.UNKNOWN, provider.parseCliVersion(json)); + } + + @Test + void testParseCliVersion_MalformedJson() { + assertEquals(DatabricksCliVersion.UNKNOWN, provider.parseCliVersion("not json")); + } + + @Test + void testParseCliVersion_EmptyString() { + assertEquals(DatabricksCliVersion.UNKNOWN, provider.parseCliVersion("")); + } + + @ParameterizedTest(name = "non-integer field: {0}") + @ValueSource( + strings = { + // Major as string + "{\"Major\":\"v0\",\"Minor\":295,\"Patch\":0}", + // Minor as string + "{\"Major\":0,\"Minor\":\"bad\",\"Patch\":0}", + // Patch as string + "{\"Major\":0,\"Minor\":295,\"Patch\":\"x\"}", + // Major as JSON null + "{\"Major\":null,\"Minor\":295,\"Patch\":0}", + // Major as array + "{\"Major\":[0],\"Minor\":295,\"Patch\":0}", + // Major as object + "{\"Major\":{\"v\":0},\"Minor\":295,\"Patch\":0}", + // Major as boolean + "{\"Major\":true,\"Minor\":295,\"Patch\":0}", + // Major as floating-point (not integral) + "{\"Major\":0.5,\"Minor\":295,\"Patch\":0}" + }) + void testParseCliVersion_NonIntegerFields(String 
json) { + assertEquals(DatabricksCliVersion.UNKNOWN, provider.parseCliVersion(json)); + } + + // ---- getCliVersion cache tests ---- + + /** + * Subclassable provider whose {@code probeCliVersion} returns canned values and counts + * invocations. Lets cache tests verify cache hit/miss without spawning real subprocesses. + */ + private static class FakeProvider extends DatabricksCliCredentialsProvider { + final AtomicInteger probeCount = new AtomicInteger(); + DatabricksCliVersion[] sequence; + IOException throwOnFirstCall; + + FakeProvider(DatabricksCliVersion... sequence) { + this.sequence = sequence; + } + + @Override + DatabricksCliVersion probeCliVersion(String cliPath, Environment env) throws IOException { + int call = probeCount.getAndIncrement(); + if (call == 0 && throwOnFirstCall != null) { + throw throwOnFirstCall; + } + return sequence[Math.min(call, sequence.length - 1)]; + } + } + + @Test + void testGetCliVersion_SuccessIsCached() { + FakeProvider p = new FakeProvider(new DatabricksCliVersion(0, 295, 0)); + Environment env = mock(Environment.class); + + DatabricksCliVersion first = p.getCliVersion(CLI_PATH, env); + DatabricksCliVersion second = p.getCliVersion(CLI_PATH, env); + + assertEquals(new DatabricksCliVersion(0, 295, 0), first); + assertEquals(first, second); + assertEquals(1, p.probeCount.get(), "Successful probe should be cached and reused"); + } + + @Test + void testGetCliVersion_ThrownProbeIsNotCached() { + FakeProvider p = new FakeProvider(new DatabricksCliVersion(0, 295, 0)); + p.throwOnFirstCall = new IOException("transient failure"); + Environment env = mock(Environment.class); + + assertEquals(DatabricksCliVersion.UNKNOWN, p.getCliVersion(CLI_PATH, env)); + assertEquals(new DatabricksCliVersion(0, 295, 0), p.getCliVersion(CLI_PATH, env)); + assertEquals(2, p.probeCount.get(), "Failed probe should be retried, not cached"); + } + + @Test + void testGetCliVersion_UnknownReturnIsNotCached() { + // probe returns UNKNOWN 
(parseable-but-malformed JSON) on the first call, real version after. + FakeProvider p = + new FakeProvider(DatabricksCliVersion.UNKNOWN, new DatabricksCliVersion(0, 295, 0)); + Environment env = mock(Environment.class); + + assertEquals(DatabricksCliVersion.UNKNOWN, p.getCliVersion(CLI_PATH, env)); + assertEquals(new DatabricksCliVersion(0, 295, 0), p.getCliVersion(CLI_PATH, env)); + assertEquals(2, p.probeCount.get(), "UNKNOWN result should not pin the cache"); + } + + @Test + void testGetCliVersion_DistinctCliPathsKeptSeparate() { + DatabricksCliCredentialsProvider p = + new DatabricksCliCredentialsProvider() { + @Override + DatabricksCliVersion probeCliVersion(String cliPath, Environment env) { + return cliPath.equals("/cli-a") + ? new DatabricksCliVersion(0, 200, 0) + : new DatabricksCliVersion(0, 300, 0); + } + }; + Environment env = mock(Environment.class); + + assertEquals(new DatabricksCliVersion(0, 200, 0), p.getCliVersion("/cli-a", env)); + assertEquals(new DatabricksCliVersion(0, 300, 0), p.getCliVersion("/cli-b", env)); + // Both paths should now be cached and consistent across re-reads. + assertEquals(new DatabricksCliVersion(0, 200, 0), p.getCliVersion("/cli-a", env)); + assertEquals(new DatabricksCliVersion(0, 300, 0), p.getCliVersion("/cli-b", env)); + } + + // ---- probeCliVersion subprocess tests ---- + + /** Builds a mock Process whose merged-output stream returns {@code stdout}. */ + private static Process mockProcess(String stdout, int exitCode, boolean exited) throws Exception { + Process process = mock(Process.class); + when(process.getInputStream()) + .thenReturn(new ByteArrayInputStream(stdout.getBytes(StandardCharsets.UTF_8))); + when(process.getOutputStream()).thenReturn(new ByteArrayOutputStream()); + when(process.waitFor(anyLong(), any(TimeUnit.class))).thenReturn(exited); + when(process.exitValue()).thenReturn(exitCode); + // destroyForcibly() returns the Process so callers can chain .waitFor(...) on it. 
+ when(process.destroyForcibly()).thenReturn(process); + return process; + } + + private static Environment mockEnv() { + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + return env; + } + + private static OSUtilities passthroughOsUtils() { + OSUtilities osUtils = mock(OSUtilities.class); + when(osUtils.getCliExecutableCommand(any())).thenAnswer(inv -> inv.getArgument(0)); + return osUtils; + } + + @Test + void testProbeCliVersion_SuccessReturnsParsedVersion() throws Exception { + Environment env = mockEnv(); + OSUtilities osUtils = passthroughOsUtils(); + Process process = + mockProcess("{\"Version\":\"v0.295.0\",\"Major\":0,\"Minor\":295,\"Patch\":0}", 0, true); + + try (MockedStatic mockedOSUtils = mockStatic(OSUtils.class); + MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, ctx) -> { + when(pb.redirectErrorStream(anyBoolean())).thenReturn(pb); + when(pb.start()).thenReturn(process); + })) { + mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils); + + DatabricksCliVersion version = provider.probeCliVersion(CLI_PATH, env); + + assertEquals(new DatabricksCliVersion(0, 295, 0), version); + // The argument passed to the OS wrapper should be the un-wrapped command. 
+ verify(osUtils) + .getCliExecutableCommand(Arrays.asList(CLI_PATH, "version", "--output", "json")); + } + } + + @Test + void testProbeCliVersion_TimeoutThrowsAndDestroys() throws Exception { + Environment env = mockEnv(); + OSUtilities osUtils = passthroughOsUtils(); + Process process = mockProcess("", 0, /* exited= */ false); + + try (MockedStatic mockedOSUtils = mockStatic(OSUtils.class); + MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, ctx) -> { + when(pb.redirectErrorStream(anyBoolean())).thenReturn(pb); + when(pb.start()).thenReturn(process); + })) { + mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils); + + IOException ex = + assertThrows(IOException.class, () -> provider.probeCliVersion(CLI_PATH, env)); + assertTrue(ex.getMessage().contains("timed out"), ex.getMessage()); + verify(process, atLeastOnce()).destroyForcibly(); + } + } + + @Test + void testProbeCliVersion_NonZeroExitSurfacesOutput() throws Exception { + Environment env = mockEnv(); + OSUtilities osUtils = passthroughOsUtils(); + String stderr = "command not found"; + Process process = mockProcess(stderr, 1, true); + + try (MockedStatic mockedOSUtils = mockStatic(OSUtils.class); + MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, ctx) -> { + when(pb.redirectErrorStream(anyBoolean())).thenReturn(pb); + when(pb.start()).thenReturn(process); + })) { + mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils); + + IOException ex = + assertThrows(IOException.class, () -> provider.probeCliVersion(CLI_PATH, env)); + assertTrue(ex.getMessage().contains("exited with code 1"), ex.getMessage()); + assertTrue(ex.getMessage().contains(stderr), ex.getMessage()); + } + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliVersionTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliVersionTest.java new file mode 100644 index 000000000..ecec8a010 --- /dev/null +++ 
b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliVersionTest.java @@ -0,0 +1,109 @@ +package com.databricks.sdk.core; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.stream.Stream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +class DatabricksCliVersionTest { + + @Test + void testAtLeast_equal() { + assertTrue(new DatabricksCliVersion(0, 207, 1).atLeast(new DatabricksCliVersion(0, 207, 1))); + } + + @Test + void testAtLeast_higherPatch() { + assertTrue(new DatabricksCliVersion(0, 207, 2).atLeast(new DatabricksCliVersion(0, 207, 1))); + assertFalse(new DatabricksCliVersion(0, 207, 0).atLeast(new DatabricksCliVersion(0, 207, 1))); + } + + @Test + void testAtLeast_higherMinor() { + assertTrue(new DatabricksCliVersion(0, 296, 0).atLeast(new DatabricksCliVersion(0, 207, 1))); + assertFalse(new DatabricksCliVersion(0, 100, 99).atLeast(new DatabricksCliVersion(0, 207, 1))); + } + + @Test + void testAtLeast_higherMajor() { + assertTrue(new DatabricksCliVersion(1, 0, 0).atLeast(new DatabricksCliVersion(0, 999, 999))); + assertFalse(new DatabricksCliVersion(0, 999, 999).atLeast(new DatabricksCliVersion(1, 0, 0))); + } + + @Test + void testAtLeast_unknownIsLessThanEverything() { + assertFalse(DatabricksCliVersion.UNKNOWN.atLeast(new DatabricksCliVersion(0, 0, 0))); + assertFalse(DatabricksCliVersion.UNKNOWN.atLeast(new DatabricksCliVersion(0, 207, 1))); + } + + @Test + void testIsDefaultDevBuild() { + assertTrue(DatabricksCliVersion.DEFAULT_DEV_BUILD.isDefaultDevBuild()); + assertTrue(new DatabricksCliVersion(0, 0, 0).isDefaultDevBuild()); + assertFalse(new DatabricksCliVersion(0, 0, 
1).isDefaultDevBuild()); + assertFalse(DatabricksCliVersion.UNKNOWN.isDefaultDevBuild()); + } + + @Test + void testToString() { + assertEquals("v0.207.1", new DatabricksCliVersion(0, 207, 1).toString()); + assertEquals("v1.0.0", new DatabricksCliVersion(1, 0, 0).toString()); + assertEquals("v0.0.0-dev", DatabricksCliVersion.DEFAULT_DEV_BUILD.toString()); + assertEquals("unknown", DatabricksCliVersion.UNKNOWN.toString()); + } + + // ---- compareTo sign tests ---- + + private static Stream compareToCases() { + DatabricksCliVersion v207_1 = new DatabricksCliVersion(0, 207, 1); + DatabricksCliVersion v207_0 = new DatabricksCliVersion(0, 207, 0); + DatabricksCliVersion v100_99 = new DatabricksCliVersion(0, 100, 99); + DatabricksCliVersion v999_999 = new DatabricksCliVersion(0, 999, 999); + DatabricksCliVersion v1_0_0 = new DatabricksCliVersion(1, 0, 0); + return Stream.of( + Arguments.of("equal", v207_1, v207_1, 0), + Arguments.of("lesser by patch", v207_0, v207_1, -1), + Arguments.of("greater by patch", v207_1, v207_0, 1), + Arguments.of("lesser by minor", v100_99, v207_1, -1), + Arguments.of("greater by minor", v207_1, v100_99, 1), + Arguments.of("lesser by major", v999_999, v1_0_0, -1), + Arguments.of("greater by major", v1_0_0, v999_999, 1), + Arguments.of( + "UNKNOWN < dev", + DatabricksCliVersion.UNKNOWN, + DatabricksCliVersion.DEFAULT_DEV_BUILD, + -1), + Arguments.of("UNKNOWN < real", DatabricksCliVersion.UNKNOWN, v207_1, -1), + Arguments.of("dev < real", DatabricksCliVersion.DEFAULT_DEV_BUILD, v207_1, -1), + Arguments.of("real > dev", v207_1, DatabricksCliVersion.DEFAULT_DEV_BUILD, 1), + Arguments.of( + "UNKNOWN equals UNKNOWN", + DatabricksCliVersion.UNKNOWN, + DatabricksCliVersion.UNKNOWN, + 0)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("compareToCases") + void testCompareTo_Sign( + String name, DatabricksCliVersion a, DatabricksCliVersion b, int expectedSign) { + assertEquals(expectedSign, Integer.signum(a.compareTo(b))); + } + + @Test + void 
testEqualsAndHashCode() { + assertEquals(new DatabricksCliVersion(0, 207, 1), new DatabricksCliVersion(0, 207, 1)); + assertEquals( + new DatabricksCliVersion(0, 207, 1).hashCode(), + new DatabricksCliVersion(0, 207, 1).hashCode()); + assertNotEquals(new DatabricksCliVersion(0, 207, 1), new DatabricksCliVersion(0, 207, 2)); + assertNotEquals(new DatabricksCliVersion(0, 207, 1), null); + assertNotEquals(new DatabricksCliVersion(0, 207, 1), "v0.207.1"); + } +}