Skip to content

Commit

Permalink
1465 : Ensuring trailing / in retrofit baseurl (langchain4j#1519)
Browse files Browse the repository at this point in the history
## Issue
Closes langchain4j#1465

## Change
According to
[retrofit](https://github.com/square/retrofit/blob/trunk/retrofit%2Fsrc%2Fmain%2Fjava%2Fretrofit2%2FRetrofit.java#L564)
base urls should always end with `/`.

Added new utility method to ensure that a provided base url always ends
with a `/` and checked existing API classes so that they all start
**without** a `/`.

### Tests
I have added unit test for the new utility method but testing the actual
invocation of the method in the different builder classes is harder. The
existing Ollama test case spins up a temporary web server and I don't
want to replicate this to al lmodules since I suspect build times will
increase a lot etc.

Thoughts?

## General checklist
- [X] There are no breaking changes
- [X] I have added unit and integration tests for my change
- [X] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [X] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
  • Loading branch information
patpe authored Aug 6, 2024
1 parent 3af3b10 commit f2bf600
Show file tree
Hide file tree
Showing 24 changed files with 70 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.anthropic.internal.api.*;
import dev.langchain4j.model.output.Response;
Expand Down Expand Up @@ -82,7 +83,7 @@ public DefaultAnthropicClient build() {


Retrofit retrofit = new Retrofit.Builder()
.baseUrl(ensureNotBlank(builder.baseUrl, "baseUrl"))
.baseUrl(Utils.ensureTrailingForwardSlash(ensureNotBlank(builder.baseUrl, "baseUrl")))
.client(okHttpClient)
.addConverterFactory(JacksonConverterFactory.create(OBJECT_MAPPER))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ interface ChatGlmApi {

int OK = 200;

@POST("/")
@POST
@Headers({"Content-Type: application/json"})
Call<ChatCompletionResponse> chatCompletion(@Body ChatCompletionRequest chatCompletionRequest);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import dev.langchain4j.internal.Utils;
import lombok.Builder;
import okhttp3.OkHttpClient;
import retrofit2.Response;
Expand Down Expand Up @@ -35,7 +36,7 @@ public ChatGlmClient(String baseUrl, Duration timeout) {
.build();

Retrofit retrofit = new Retrofit.Builder()
.baseUrl(baseUrl)
.baseUrl(Utils.ensureTrailingForwardSlash(baseUrl))
.client(okHttpClient)
.addConverterFactory(GsonConverterFactory.create(GSON))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,30 @@
import retrofit2.http.Path;

interface ChromaApi {
@GET("/api/v1/collections/{collection_name}")
@GET("api/v1/collections/{collection_name}")
@Headers({ "Content-Type: application/json" })
Call<Collection> collection(@Path("collection_name") String collectionName);

@POST("/api/v1/collections")
@POST("api/v1/collections")
@Headers({ "Content-Type: application/json" })
Call<Collection> createCollection(@Body CreateCollectionRequest createCollectionRequest);

@POST("/api/v1/collections/{collection_id}/add")
@POST("api/v1/collections/{collection_id}/add")
@Headers({ "Content-Type: application/json" })
Call<Boolean> addEmbeddings(@Path("collection_id") String collectionId, @Body AddEmbeddingsRequest embedding);

@POST("/api/v1/collections/{collection_id}/query")
@POST("api/v1/collections/{collection_id}/query")
@Headers({ "Content-Type: application/json" })
Call<QueryResponse> queryCollection(@Path("collection_id") String collectionId, @Body QueryRequest queryRequest);

@POST("/api/v1/collections/{collection_id}/delete")
@POST("api/v1/collections/{collection_id}/delete")
@Headers({ "Content-Type: application/json" })
Call<List<String>> deleteEmbeddings(
@Path("collection_id") String collectionId,
@Body DeleteEmbeddingsRequest embedding
);

@DELETE("/api/v1/collections/{collection_name}")
@DELETE("api/v1/collections/{collection_name}")
@Headers({ "Content-Type: application/json" })
Call<Collection> deleteCollection(@Path("collection_name") String collectionName);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import dev.langchain4j.internal.Utils;

import java.io.IOException;
import java.time.Duration;
import java.util.List;
Expand Down Expand Up @@ -33,7 +35,7 @@ private ChromaClient(Builder builder) {
Gson gson = new GsonBuilder().setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES).create();

Retrofit retrofit = new Retrofit.Builder()
.baseUrl(builder.baseUrl)
.baseUrl(Utils.ensureTrailingForwardSlash(builder.baseUrl))
.client(httpClientBuilder.build())
.addConverterFactory(GsonConverterFactory.create(gson))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import dev.langchain4j.internal.Utils;
import lombok.Builder;
import okhttp3.OkHttpClient;
import retrofit2.Retrofit;
Expand Down Expand Up @@ -40,7 +41,7 @@ class CohereClient {
}

Retrofit retrofit = new Retrofit.Builder()
.baseUrl(baseUrl)
.baseUrl(Utils.ensureTrailingForwardSlash(baseUrl))
.client(okHttpClientBuilder.build())
.addConverterFactory(GsonConverterFactory.create(GSON))
.build();
Expand Down
10 changes: 10 additions & 0 deletions langchain4j-core/src/main/java/dev/langchain4j/internal/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,16 @@ public static String generateUUIDFrom(String input) {
return UUID.nameUUIDFromBytes(sb.toString().getBytes(UTF_8)).toString();
}

/**
* Creates a new string with a trailing '/' if the provided path does not end with '/'
*
* @param str String to check for trailing '/'
* @return Same string if it already ends with '/' or a new string that ends with '/'
*/
public static String ensureTrailingForwardSlash(String str) {
return str.endsWith("/") ? str : str + "/";
}

/**
* Returns the given object's {@code toString()} surrounded by quotes.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,4 +221,13 @@ void test_copyIfNotNull() {
assertThat(Utils.copyIfNotNull(singletonList("one"))).containsExactly("one");
assertThat(Utils.copyIfNotNull(asList("one", "two"))).containsExactly("one", "two");
}

@Test
void test_ensureTrailingForwardSlash() {
assertThat(Utils.ensureTrailingForwardSlash("https://example.com")).isEqualTo("https://example.com/");
assertThat(Utils.ensureTrailingForwardSlash("https://example.com/")).isEqualTo("https://example.com/");
assertThat(Utils.ensureTrailingForwardSlash("https://example.com/a")).isEqualTo("https://example.com/a/");
assertThat(Utils.ensureTrailingForwardSlash("https://example.com/a/")).isEqualTo("https://example.com/a/");
assertThat(Utils.ensureTrailingForwardSlash("https://example.com/a/b")).isEqualTo("https://example.com/a/b/");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.huggingface.client.EmbeddingRequest;
import dev.langchain4j.model.huggingface.client.HuggingFaceClient;
import dev.langchain4j.model.huggingface.client.TextGenerationRequest;
Expand Down Expand Up @@ -38,7 +39,7 @@ class DefaultHuggingFaceClient implements HuggingFaceClient {
.create();

Retrofit retrofit = new Retrofit.Builder()
.baseUrl("https://api-inference.huggingface.co")
.baseUrl(Utils.ensureTrailingForwardSlash("https://api-inference.huggingface.co/"))
.client(okHttpClient)
.addConverterFactory(GsonConverterFactory.create(gson))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@

interface HuggingFaceApi {

@POST("/models/{modelId}")
@POST("models/{modelId}")
@Headers({"Content-Type: application/json"})
Call<List<TextGenerationResponse>> generate(@Body TextGenerationRequest request, @Path("modelId") String modelId);

@POST("/pipeline/feature-extraction/{modelId}")
@POST("pipeline/feature-extraction/{modelId}")
@Headers({"Content-Type: application/json"})
Call<List<float[]>> embed(@Body EmbeddingRequest request, @Path("modelId") String modelId);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dev.langchain4j.model.jina.internal.client;

import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.jina.internal.api.*;
import lombok.Builder;
import okhttp3.OkHttpClient;
Expand Down Expand Up @@ -37,7 +38,7 @@ public class JinaClient {
}

Retrofit retrofit = new Retrofit.Builder()
.baseUrl(baseUrl)
.baseUrl(Utils.ensureTrailingForwardSlash(baseUrl))
.client(okHttpClientBuilder.build())
.addConverterFactory(JacksonConverterFactory.create(OBJECT_MAPPER))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.mistralai.internal.api.*;
import dev.langchain4j.model.output.FinishReason;
Expand Down Expand Up @@ -66,18 +67,14 @@ public DefaultMistralAiClient build() {
this.okHttpClient = okHttpClientBuilder.build();

Retrofit retrofit = new Retrofit.Builder()
.baseUrl(formattedUrlForRetrofit(builder.baseUrl))
.baseUrl(Utils.ensureTrailingForwardSlash(builder.baseUrl))
.client(okHttpClient)
.addConverterFactory(JacksonConverterFactory.create(OBJECT_MAPPER))
.build();

mistralAiApi = retrofit.create(MistralAiApi.class);
}

private static String formattedUrlForRetrofit(String baseUrl) {
return baseUrl.endsWith("/") ? baseUrl : baseUrl + "/";
}

@Override
public MistralAiChatCompletionResponse chatCompletion(MistralAiChatCompletionRequest request) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import dev.langchain4j.internal.Utils;
import lombok.Builder;
import okhttp3.OkHttpClient;
import retrofit2.Retrofit;
Expand Down Expand Up @@ -40,7 +41,7 @@ class NomicClient {
}

Retrofit retrofit = new Retrofit.Builder()
.baseUrl(baseUrl)
.baseUrl(Utils.ensureTrailingForwardSlash(baseUrl))
.client(okHttpClientBuilder.build())
.addConverterFactory(GsonConverterFactory.create(GSON))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
Expand Down Expand Up @@ -63,7 +64,7 @@ public OllamaClient(String baseUrl,
OkHttpClient okHttpClient = okHttpClientBuilder.build();

Retrofit retrofit = new Retrofit.Builder()
.baseUrl(baseUrl.endsWith("/") ? baseUrl : baseUrl + "/")
.baseUrl(Utils.ensureTrailingForwardSlash(baseUrl))
.client(okHttpClient)
.addConverterFactory(JacksonConverterFactory.create(OBJECT_MAPPER))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.io.IOException;
import java.util.Arrays;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.ovhai.internal.api.EmbeddingRequest;
import dev.langchain4j.model.ovhai.internal.api.EmbeddingResponse;
import dev.langchain4j.model.ovhai.internal.api.OvhAiApi;
Expand Down Expand Up @@ -61,7 +62,7 @@ public DefaultOvhAiClient build() {
this.okHttpClient = okHttpClientBuilder.build();

Retrofit retrofit = new Retrofit.Builder()
.baseUrl(ensureNotBlank(builder.baseUrl, "baseUrl"))
.baseUrl(Utils.ensureTrailingForwardSlash(ensureNotBlank(builder.baseUrl, "baseUrl")))
.client(okHttpClient)
.addConverterFactory(JacksonConverterFactory.create(OBJECT_MAPPER))
.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dev.langchain4j.model.qianfan.client;


import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.qianfan.client.chat.ChatCompletionRequest;
import dev.langchain4j.model.qianfan.client.chat.ChatCompletionResponse;
import dev.langchain4j.model.qianfan.client.chat.ChatTokenResponse;
Expand Down Expand Up @@ -68,7 +69,7 @@ private QianfanClient(Builder serviceBuilder) {
this.apiKey = serviceBuilder.apiKey;
this.secretKey = serviceBuilder.secretKey;
this.okHttpClient = okHttpClientBuilder.build();
Retrofit retrofit = (new Retrofit.Builder()).baseUrl(serviceBuilder.baseUrl).client(this.okHttpClient)
Retrofit retrofit = (new Retrofit.Builder()).baseUrl(Utils.ensureTrailingForwardSlash(serviceBuilder.baseUrl)).client(this.okHttpClient)
.addConverterFactory(GsonConverterFactory.create(Json.GSON)).build();
this.qianfanApi = retrofit.create(QianfanApi.class);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,33 @@ public interface VearchApi {

/* Database Operation */

@GET("/list/db")
@GET("list/db")
Call<ResponseWrapper<List<ListDatabaseResponse>>> listDatabase();

@PUT("/db/_create")
@PUT("db/_create")
Call<ResponseWrapper<CreateDatabaseResponse>> createDatabase(@Body CreateDatabaseRequest request);

@GET("/list/space")
@GET("list/space")
Call<ResponseWrapper<List<ListSpaceResponse>>> listSpaceOfDatabase(@Query("db") String dbName);

/* Space (like a table in relational database) Operation */

@PUT("/space/{db}/_create")
@PUT("space/{db}/_create")
Call<ResponseWrapper<CreateSpaceResponse>> createSpace(@Path("db") String dbName,
@Body CreateSpaceRequest request);

/* Document Operation */

@POST("/{db}/{space}/_bulk")
@POST("{db}/{space}/_bulk")
Call<List<BulkResponse>> bulk(@Path("db") String db,
@Path("space") String space,
@Body RequestBody requestBody);

@POST("/{db}/{space}/_search")
@POST("{db}/{space}/_search")
Call<SearchResponse> search(@Path("db") String db,
@Path("space") String space,
@Body SearchRequest request);

@DELETE("/space/{db}/{space}")
@DELETE("space/{db}/{space}")
Call<Void> deleteSpace(@Path("db") String dbName, @Path("space") String spaceName);
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import dev.langchain4j.internal.Utils;
import lombok.Builder;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
Expand Down Expand Up @@ -37,7 +38,7 @@ public VearchClient(String baseUrl, Duration timeout) {
.build();

Retrofit retrofit = new Retrofit.Builder()
.baseUrl(baseUrl)
.baseUrl(Utils.ensureTrailingForwardSlash(baseUrl))
.client(okHttpClient)
.addConverterFactory(GsonConverterFactory.create(GSON))
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
package dev.langchain4j.store.embedding.vespa;

import com.google.gson.GsonBuilder;
import dev.langchain4j.internal.Utils;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
Expand Down Expand Up @@ -72,7 +74,7 @@ public static VespaQueryApi createInstance(String baseUrl, Path certificate, Pat
.build();

Retrofit retrofit = new Retrofit.Builder()
.baseUrl(baseUrl)
.baseUrl(Utils.ensureTrailingForwardSlash(baseUrl))
.client(client)
.addConverterFactory(GsonConverterFactory.create(new GsonBuilder().create()))
.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dev.langchain4j.model.workersai.client;

import dev.langchain4j.internal.Utils;
import okhttp3.Interceptor;
import okhttp3.OkHttpClient;
import okhttp3.Request;
Expand Down Expand Up @@ -40,7 +41,7 @@ public static WorkersAiApi createService(String apiToken) {
.build();

Retrofit retrofit = new Retrofit.Builder()
.baseUrl(BASE_URL)
.baseUrl(Utils.ensureTrailingForwardSlash(BASE_URL))
.client(okHttpClient)
.addConverterFactory(JacksonConverterFactory.create())
.build();
Expand Down
Loading

0 comments on commit f2bf600

Please sign in to comment.