Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,10 @@ private <T> T execute(Request in, Class<T> target) throws IOException {
}

private Response getResponse(Request in) {
in.withUrl(config.getHost() + in.getUrl());
return executeInner(in);
return executeInner(in, in.getUrl());
}

private Response executeInner(Request in) {
private Response executeInner(Request in, String path) {
RetryStrategy retryStrategy = retryStrategyPicker.getRetryStrategy(in);
int attemptNumber = 0;
while (true) {
Expand All @@ -247,6 +246,10 @@ private Response executeInner(Request in) {
// Authenticate the request. Failures should not be retried.
in.withHeaders(config.authenticate());

// Prepend host to URL only after config.authenticate().
// This call may configure the host (e.g. in case of notebook native auth).
in.withUrl(config.getHost() + path);

// Set User-Agent with auth type info, which is available only
// after the first invocation to config.authenticate()
String userAgent = String.format("%s auth/%s", UserAgent.asString(), config.getAuthType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,26 +42,24 @@ public boolean equals(Object o) {
}
}

private ApiClient getApiClient(Request request, List<ResponseProvider> responses) {
private ApiClient getApiClient(
DatabricksConfig config, Request request, List<ResponseProvider> responses) {
DummyHttpClient hc = new DummyHttpClient();
for (ResponseProvider response : responses) {
hc.with(request, response);
}
return new ApiClient(config.setHttpClient(hc), new FakeTimer());
}

private ApiClient getApiClient(Request request, List<ResponseProvider> responses) {
String host = request.getUri().getScheme() + "://" + request.getUri().getHost();
DatabricksConfig config =
new DatabricksConfig()
.setHttpClient(hc)
.setHost(host)
.setCredentialsProvider(new DummyCredentialsProvider());
return new ApiClient(config, new FakeTimer());
new DatabricksConfig().setHost(host).setCredentialsProvider(new DummyCredentialsProvider());
return getApiClient(config, request, responses);
}

private <T> void runApiClientTest(
Request request,
List<ResponseProvider> responses,
Class<? extends T> clazz,
T expectedResponse) {
ApiClient client = getApiClient(request, responses);
ApiClient client, Request request, Class<? extends T> clazz, T expectedResponse) {
T response;
if (request.getMethod().equals(Request.GET)) {
response = client.GET(request.getUri().getPath(), clazz, Collections.emptyMap());
Expand All @@ -73,6 +71,15 @@ private <T> void runApiClientTest(
assertEquals(response, expectedResponse);
}

private <T> void runApiClientTest(
Request request,
List<ResponseProvider> responses,
Class<? extends T> clazz,
T expectedResponse) {
ApiClient client = getApiClient(request, responses);
runApiClientTest(client, request, clazz, expectedResponse);
}

private void runFailingApiClientTest(
Request request, List<ResponseProvider> responses, Class<?> clazz, String expectedMessage) {
DatabricksException exception =
Expand Down Expand Up @@ -347,6 +354,39 @@ void retryUnknownHostException() {
new MyEndpointResponse().setKey("value"));
}

class HostPopulatingCredentialsProvider implements CredentialsProvider {
private final String host;
private final CredentialsProvider parent;

public HostPopulatingCredentialsProvider(String host) {
this.host = host;
this.parent = new DummyCredentialsProvider();
}

@Override
public String authType() {
return parent.authType();
}

@Override
public HeaderFactory configure(DatabricksConfig config) {
config.setHost(this.host);
return parent.configure(config);
}
}

@Test
void populateHostFromCredentialProvider() {
Request req = getBasicRequest();
DatabricksConfig config =
new DatabricksConfig()
.setCredentialsProvider(new HostPopulatingCredentialsProvider("http://my.host"));
ApiClient client =
getApiClient(config, req, Collections.singletonList(getSuccessResponse(req)));
runApiClientTest(
client, req, MyEndpointResponse.class, new MyEndpointResponse().setKey("value"));
}

@Test
void testGetBackoffFromRetryAfterHeader() {
Request req = getBasicRequest();
Expand Down