Skip to content

Extend observation support to Azure OpenAI models #1491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -687,4 +687,13 @@ private ChatCompletionsResponseFormat toAzureResponseFormat(AzureOpenAiResponseF
return new ChatCompletionsTextResponseFormat();
}

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
*/
public void setObservationConvention(ChatModelObservationConvention observationConvention) {
Assert.notNull(observationConvention, "observationConvention cannot be null");
this.observationConvention = observationConvention;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import com.azure.ai.openai.models.EmbeddingItem;
import com.azure.ai.openai.models.Embeddings;
import com.azure.ai.openai.models.EmbeddingsOptions;

import io.micrometer.observation.ObservationRegistry;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.azure.openai.metadata.AzureOpenAiEmbeddingUsage;
Expand All @@ -29,7 +32,12 @@
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

Expand All @@ -54,6 +62,18 @@ public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel {

private final MetadataMode metadataMode;

private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention();

/**
* Observation registry used for instrumentation.
*/
private final ObservationRegistry observationRegistry;

/**
* Conventions to use for generating observations.
*/
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient) {
this(azureOpenAiClient, MetadataMode.EMBED);
}
Expand All @@ -65,12 +85,20 @@ public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode me

public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode,
AzureOpenAiEmbeddingOptions options) {
this(azureOpenAiClient, metadataMode, options, ObservationRegistry.NOOP);
}

public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode,
AzureOpenAiEmbeddingOptions options, ObservationRegistry observationRegistry) {

Assert.notNull(azureOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null");
Assert.notNull(metadataMode, "Metadata mode must not be null");
Assert.notNull(options, "Options must not be null");
Assert.notNull(observationRegistry, "Observation registry must not be null");
this.azureOpenAiClient = azureOpenAiClient;
this.metadataMode = metadataMode;
this.defaultOptions = options;
this.observationRegistry = observationRegistry;
}

@Override
Expand All @@ -91,11 +119,29 @@ public float[] embed(Document document) {
public EmbeddingResponse call(EmbeddingRequest embeddingRequest) {
logger.debug("Retrieving embeddings");

EmbeddingsOptions azureOptions = toEmbeddingOptions(embeddingRequest);
Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(azureOptions.getModel(), azureOptions);

logger.debug("Embeddings retrieved");
return generateEmbeddingResponse(embeddings);
AzureOpenAiEmbeddingOptions options = AzureOpenAiEmbeddingOptions.builder()
.from(this.defaultOptions)
.merge(embeddingRequest.getOptions())
.build();
EmbeddingsOptions azureOptions = options.toAzureOptions(embeddingRequest.getInstructions());

var observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(embeddingRequest)
.provider(AiProvider.AZURE_OPENAI.value())
.requestOptions(options)
.build();

return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(azureOptions.getModel(), azureOptions);

logger.debug("Embeddings retrieved");
var embeddingResponse = generateEmbeddingResponse(embeddings);
observationContext.setResponse(embeddingResponse);
return embeddingResponse;
});
}

/**
Expand Down Expand Up @@ -132,4 +178,13 @@ public AzureOpenAiEmbeddingOptions getDefaultOptions() {
return this.defaultOptions;
}

/**
* Use the provided convention for reporting observation data
* @param observationConvention The provided convention
*/
public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) {
Assert.notNull(observationConvention, "observationConvention cannot be null");
this.observationConvention = observationConvention;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.azure.openai;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;

import io.micrometer.observation.tck.TestObservationRegistry;
import io.micrometer.observation.tck.TestObservationRegistryAssert;

/**
* Integration tests for observation instrumentation in {@link AzureOpenAiEmbeddingModel}.
*
* @author Christian Tzolov
*/
@SpringBootTest(classes = AzureOpenAiEmbeddingModelObservationIT.Config.class)
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+")
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
public class AzureOpenAiEmbeddingModelObservationIT {

@Autowired
TestObservationRegistry observationRegistry;

@Autowired
AzureOpenAiEmbeddingModel embeddingModel;

@Test
void observationForEmbeddingOperation() {
var options = AzureOpenAiEmbeddingOptions.builder()
.withDeploymentName("text-embedding-ada-002")
.withDimensions(1536)
.build();

EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options);

EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest);
assertThat(embeddingResponse.getResults()).isNotEmpty();

EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata();
assertThat(responseMetadata).isNotNull();

TestObservationRegistryAssert.assertThat(observationRegistry)
.doesNotHaveAnyRemainingCurrentObservation()
.hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME)
.that()
.hasContextualNameEqualTo("embedding " + "text-embedding-ada-002")
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(),
AiOperationType.EMBEDDING.value())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.AZURE_OPENAI.value())
.hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), "text-embedding-ada-002")
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_EMBEDDING_DIMENSIONS.asString(), "1536")
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(),
String.valueOf(responseMetadata.getUsage().getPromptTokens()))
.hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(),
String.valueOf(responseMetadata.getUsage().getTotalTokens()))
.hasBeenStarted()
.hasBeenStopped();
}

@SpringBootConfiguration
static class Config {

@Bean
public TestObservationRegistry observationRegistry() {
return TestObservationRegistry.create();
}

@Bean
public OpenAIClient openAIClient() {
return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
.endpoint(System.getenv("AZURE_OPENAI_ENDPOINT"))
.buildClient();
}

@Bean
public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIClient openAIClient,
TestObservationRegistry observationRegistry) {
return new AzureOpenAiEmbeddingModel(openAIClient, MetadataMode.EMBED,
AzureOpenAiEmbeddingOptions.builder().withDeploymentName("text-embedding-ada-002").build(),
observationRegistry);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel;
import org.springframework.ai.azure.openai.AzureOpenAiImageModel;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
Expand All @@ -36,14 +39,15 @@
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.credential.KeyCredential;
import com.azure.core.credential.TokenCredential;
import com.azure.core.util.ClientOptions;
import com.azure.core.util.Header;

import io.micrometer.observation.ObservationRegistry;

/**
* @author Piotr Olaszewski
* @author Soby Chacko
Expand Down Expand Up @@ -106,20 +110,33 @@ public OpenAIClientBuilder openAIClientWithTokenCredential(AzureOpenAiConnection
matchIfMissing = true)
public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder,
AzureOpenAiChatProperties chatProperties, List<FunctionCallback> toolFunctionCallbacks,
FunctionCallbackContext functionCallbackContext) {
FunctionCallbackContext functionCallbackContext, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ChatModelObservationConvention> observationConvention) {

var chatModel = new AzureOpenAiChatModel(openAIClientBuilder, chatProperties.getOptions(),
functionCallbackContext, toolFunctionCallbacks,
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP));
observationConvention.ifAvailable(chatModel::setObservationConvention);

return new AzureOpenAiChatModel(openAIClientBuilder, chatProperties.getOptions(), functionCallbackContext,
toolFunctionCallbacks);
return chatModel;
}

@Bean
@ConditionalOnMissingBean
@ConditionalOnProperty(prefix = AzureOpenAiEmbeddingProperties.CONFIG_PREFIX, name = "enabled",
havingValue = "true", matchIfMissing = true)
public AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel(OpenAIClientBuilder openAIClient,
AzureOpenAiEmbeddingProperties embeddingProperties) {
return new AzureOpenAiEmbeddingModel(openAIClient.buildClient(), embeddingProperties.getMetadataMode(),
embeddingProperties.getOptions());
AzureOpenAiEmbeddingProperties embeddingProperties, ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<EmbeddingModelObservationConvention> observationConvention) {

var embeddingModel = new AzureOpenAiEmbeddingModel(openAIClient.buildClient(),
embeddingProperties.getMetadataMode(), embeddingProperties.getOptions(),
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP));

observationConvention.ifAvailable(embeddingModel::setObservationConvention);

return embeddingModel;

}

@Bean
Expand Down