Skip to content

Commit 471ee80

Browse files
committed
fix: bedrock titan embeddings should return usage
Signed-off-by: Gareth Evans <gareth@bryncynfelin.co.uk>
1 parent 694bb50 commit 471ee80

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.List;
21+
import java.util.Optional;
2122
import java.util.concurrent.atomic.AtomicInteger;
2223

2324
import io.micrometer.observation.Observation;
@@ -28,12 +29,9 @@
2829
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi;
2930
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingRequest;
3031
import org.springframework.ai.bedrock.titan.api.TitanEmbeddingBedrockApi.TitanEmbeddingResponse;
32+
import org.springframework.ai.chat.metadata.DefaultUsage;
3133
import org.springframework.ai.document.Document;
32-
import org.springframework.ai.embedding.AbstractEmbeddingModel;
33-
import org.springframework.ai.embedding.Embedding;
34-
import org.springframework.ai.embedding.EmbeddingOptions;
35-
import org.springframework.ai.embedding.EmbeddingRequest;
36-
import org.springframework.ai.embedding.EmbeddingResponse;
34+
import org.springframework.ai.embedding.*;
3735
import org.springframework.util.Assert;
3836

3937
/**
@@ -89,6 +87,7 @@ public EmbeddingResponse call(EmbeddingRequest request) {
8987

9088
List<Embedding> embeddings = new ArrayList<>();
9189
var indexCounter = new AtomicInteger(0);
90+
int tokenUsage = 0;
9291

9392
for (String inputContent : request.getInstructions()) {
9493
var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions());
@@ -111,6 +110,10 @@ public EmbeddingResponse call(EmbeddingRequest request) {
111110
}
112111

113112
embeddings.add(new Embedding(response.embedding(), indexCounter.getAndIncrement()));
113+
114+
if (response.inputTextTokenCount() != null) {
115+
tokenUsage += response.inputTextTokenCount();
116+
}
114117
}
115118
catch (Exception ex) {
116119
logger.error("Titan API embedding failed for input at index {}: {}", indexCounter.get(),
@@ -120,7 +123,10 @@ public EmbeddingResponse call(EmbeddingRequest request) {
120123
}
121124
}
122125

123-
return new EmbeddingResponse(embeddings);
126+
EmbeddingResponseMetadata embeddingResponseMetadata = new EmbeddingResponseMetadata("",
127+
getDefaultUsage(tokenUsage));
128+
129+
return new EmbeddingResponse(embeddings, embeddingResponseMetadata);
124130
}
125131

126132
private TitanEmbeddingRequest createTitanEmbeddingRequest(String inputContent, EmbeddingOptions requestOptions) {
@@ -155,6 +161,10 @@ private String summarizeInput(String input) {
155161
return input.length() > 100 ? input.substring(0, 100) + "..." : input;
156162
}
157163

164+
private DefaultUsage getDefaultUsage(int tokens) {
165+
return new DefaultUsage(tokens, 0);
166+
}
167+
158168
public enum InputType {
159169

160170
TEXT, IMAGE

0 commit comments

Comments
 (0)