Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Release v0.104.0

### New Features and Improvements
* Add support for authentication through Azure Managed Service Identity (MSI) via the new `azure-msi` credential provider.
* Added automatic detection of AI coding agents (Antigravity, Claude Code, Cline, Codex, Copilot CLI, Cursor, Gemini CLI, OpenCode) in the user-agent string. The SDK now appends `agent/<name>` to HTTP request headers when running inside a known AI agent environment.

### Bug Fixes
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package com.databricks.sdk.core;

import com.databricks.sdk.core.oauth.AzureMsiTokenSource;
import com.databricks.sdk.core.oauth.CachedTokenSource;
import com.databricks.sdk.core.oauth.OAuthHeaderFactory;
import com.databricks.sdk.core.oauth.Token;
import com.databricks.sdk.core.utils.AzureUtils;
import com.databricks.sdk.support.InternalApi;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Adds refreshed Azure Active Directory (AAD) tokens obtained via Azure Managed Service Identity
* (MSI) to every request. This provider authenticates using the Azure Instance Metadata Service
* (IMDS) endpoint, which is available on Azure VMs and other compute resources with managed
* identities enabled.
*/
@InternalApi
public class AzureMsiCredentialsProvider implements CredentialsProvider {
private static final Logger LOG = LoggerFactory.getLogger(AzureMsiCredentialsProvider.class);
private final ObjectMapper mapper = new ObjectMapper();

@Override
public String authType() {
return "azure-msi";
}

@Override
public OAuthHeaderFactory configure(DatabricksConfig config) {
if (!config.isAzure()) {
return null;
}

if (!isAzureUseMsi(config)) {
return null;
}

if (config.getAzureWorkspaceResourceId() == null && config.getHost() == null) {
return null;
}

LOG.debug("Generating AAD token via Azure MSI");

AzureUtils.ensureHostPresent(config, mapper, this::tokenSourceFor);

CachedTokenSource inner = tokenSourceFor(config, config.getEffectiveAzureLoginAppId());
CachedTokenSource cloud =
tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint());

return OAuthHeaderFactory.fromSuppliers(
inner::getToken,
() -> {
Token token = inner.getToken();
Map<String, String> headers = new HashMap<>();
headers.put("Authorization", "Bearer " + token.getAccessToken());
AzureUtils.addSpManagementToken(cloud, headers);
AzureUtils.addWorkspaceResourceId(config, headers);
return headers;
});
}

/**
* Null-safe check for the azureUseMsi config flag. The underlying field is a boxed Boolean, but
* the getter auto-unboxes to primitive boolean, which would NPE when the field is null. This
* helper treats null as false.
*/
private static boolean isAzureUseMsi(DatabricksConfig config) {
try {
return config.getAzureUseMsi();
} catch (NullPointerException e) {
return false;
}
}

/**
* Creates a CachedTokenSource for the specified Azure resource using MSI authentication.
*
* @param config The DatabricksConfig instance containing the required authentication parameters.
* @param resource The Azure resource for which OAuth tokens need to be fetched.
* @return A CachedTokenSource instance capable of fetching OAuth tokens for the specified Azure
* resource.
*/
CachedTokenSource tokenSourceFor(DatabricksConfig config, String resource) {
AzureMsiTokenSource tokenSource =
new AzureMsiTokenSource(config.getHttpClient(), resource, config.getAzureClientId());
return new CachedTokenSource.Builder(tokenSource)
.setAsyncDisabled(config.getDisableAsyncTokenRefresh())
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ private synchronized void addDefaultCredentialsProviders(DatabricksConfig config
addOIDCCredentialsProviders(config);

providers.add(new AzureGithubOidcCredentialsProvider());
providers.add(new AzureMsiCredentialsProvider());
providers.add(new AzureServicePrincipalCredentialsProvider());
providers.add(new AzureCliCredentialsProvider());
providers.add(new ExternalBrowserCredentialsProvider());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package com.databricks.sdk.core.oauth;

import com.databricks.sdk.core.DatabricksException;
import com.databricks.sdk.core.http.HttpClient;
import com.databricks.sdk.core.http.Request;
import com.databricks.sdk.core.http.Response;
import com.databricks.sdk.support.InternalApi;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.time.Instant;

/**
* A {@link TokenSource} that fetches OAuth tokens from the Azure Instance Metadata Service (IMDS)
* endpoint for Managed Service Identity (MSI) authentication.
*
* <p>This token source makes HTTP GET requests to the well-known IMDS endpoint at {@code
* http://169.254.169.254/metadata/identity/oauth2/token} to obtain access tokens for the specified
* Azure resource.
*/
@InternalApi
public class AzureMsiTokenSource implements TokenSource {

private static final String IMDS_ENDPOINT =
"http://169.254.169.254/metadata/identity/oauth2/token";

private final HttpClient httpClient;
private final String resource;
private final String clientId;
private final ObjectMapper mapper = new ObjectMapper();

/** Response structure from the Azure IMDS token endpoint. */
@JsonIgnoreProperties(ignoreUnknown = true)
static class MsiTokenResponse {
@JsonProperty("token_type")
private String tokenType;

@JsonProperty("access_token")
private String accessToken;

@JsonProperty("expires_on")
private String expiresOn;

Token toToken() {
if (accessToken == null || accessToken.isEmpty()) {
throw new DatabricksException("MSI token response missing or empty 'access_token' field");
}
if (tokenType == null || tokenType.isEmpty()) {
throw new DatabricksException("MSI token response missing or empty 'token_type' field");
}
if (expiresOn == null || expiresOn.isEmpty()) {
throw new DatabricksException("MSI token response missing 'expires_on' field");
}
long epoch;
try {
epoch = Long.parseLong(expiresOn);
} catch (NumberFormatException e) {
throw new DatabricksException(
"Invalid 'expires_on' value in MSI token response: " + expiresOn, e);
}
return new Token(accessToken, tokenType, Instant.ofEpochSecond(epoch));
}
}

/**
* Creates a new AzureMsiTokenSource.
*
* @param httpClient The HTTP client to use for requests to the IMDS endpoint.
* @param resource The Azure resource for which to obtain an access token.
* @param clientId The client ID of the managed identity to use. May be null for system-assigned
* identities.
*/
public AzureMsiTokenSource(HttpClient httpClient, String resource, String clientId) {
this.httpClient = httpClient;
this.resource = resource;
this.clientId = clientId;
}

@Override
public Token getToken() {
Request req = new Request("GET", IMDS_ENDPOINT);
req.withQueryParam("api-version", "2018-02-01");
req.withQueryParam("resource", resource);
if (clientId != null && !clientId.isEmpty()) {
req.withQueryParam("client_id", clientId);
}
req.withHeader("Metadata", "true");

Response resp;
try {
resp = httpClient.execute(req);
} catch (IOException e) {
throw new DatabricksException(
"Failed to request MSI token from IMDS endpoint: " + e.getMessage(), e);
}

if (resp.getStatusCode() != 200) {
throw new DatabricksException(
"Failed to request MSI token: status code "
+ resp.getStatusCode()
+ ", response body: "
+ resp.getDebugBody());
}

try {
MsiTokenResponse msiToken = mapper.readValue(resp.getBody(), MsiTokenResponse.class);
return msiToken.toToken();
} catch (IOException e) {
throw new DatabricksException("Failed to parse MSI token response: " + e.getMessage(), e);
}
}
}
Loading
Loading