Skip to content

Issues3073:support ToolProvider automatically wired into the AI Service if available in the application contex #131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.tool.ToolProvider;
import org.springframework.stereotype.Service;

import java.lang.annotation.Retention;
Expand Down Expand Up @@ -103,4 +104,10 @@
* this attribute specifies the names of beans containing methods annotated with {@link Tool} that should be used by this AI Service.
*/
String[] tools() default {};

/**
* When the {@link #wiringMode()} is set to {@link AiServiceWiringMode#EXPLICIT},
* this attribute specifies the name of a {@link ToolProvider} bean that should be used by this AI Service.
*/
String toolProvider() default "";
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.tool.DefaultToolExecutor;
import dev.langchain4j.service.tool.ToolExecutor;
import dev.langchain4j.service.tool.ToolProvider;
import org.springframework.beans.factory.FactoryBean;

import java.lang.reflect.Method;
Expand All @@ -36,6 +37,7 @@ class AiServiceFactory implements FactoryBean<Object> {
private RetrievalAugmentor retrievalAugmentor;
private ModerationModel moderationModel;
private List<Object> tools;
private ToolProvider toolProvider;

public AiServiceFactory(Class<Object> aiServiceClass) {
this.aiServiceClass = aiServiceClass;
Expand Down Expand Up @@ -73,6 +75,10 @@ public void setTools(List<Object> tools) {
this.tools = tools;
}

public void setToolProvider(ToolProvider toolProvider) {
this.toolProvider = toolProvider;
}

@Override
public Object getObject() {

Expand Down Expand Up @@ -113,6 +119,9 @@ public Object getObject() {
}
}
}
if (toolProvider != null) {
builder = builder.toolProvider(toolProvider);
}

return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.service.IllegalConfigurationException;
import dev.langchain4j.service.spring.event.AiServiceRegisteredEvent;
import dev.langchain4j.service.tool.ToolProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.MutablePropertyValues;
Expand Down Expand Up @@ -58,6 +59,7 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
String[] contentRetrievers = beanFactory.getBeanNamesForType(ContentRetriever.class);
String[] retrievalAugmentors = beanFactory.getBeanNamesForType(RetrievalAugmentor.class);
String[] moderationModels = beanFactory.getBeanNamesForType(ModerationModel.class);
String[] toolProviders = beanFactory.getBeanNamesForType(ToolProvider.class);

Set<String> toolBeanNames = new HashSet<>();
List<ToolSpecification> toolSpecifications = new ArrayList<>();
Expand Down Expand Up @@ -165,6 +167,16 @@ BeanFactoryPostProcessor aiServicesRegisteringBeanFactoryPostProcessor() {
propertyValues
);

addBeanReference(
ToolProvider.class,
aiServiceAnnotation,
aiServiceAnnotation.toolProvider(),
toolProviders,
"toolProvider",
"toolProvider",
propertyValues
);

if (aiServiceAnnotation.wiringMode() == EXPLICIT) {
propertyValues.add("tools", toManagedList(asList(aiServiceAnnotation.tools())));
} else if (aiServiceAnnotation.wiringMode() == AUTOMATIC) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package dev.langchain4j.service.spring.mode.automatic.issue3073;

import dev.langchain4j.service.spring.AiService;

@AiService
public interface AiServiceWithToolProvider {
String chat(String userMessage);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package dev.langchain4j.service.spring.mode.automatic.issue3073;

import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.service.tool.ToolProvider;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;

import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_4_O_MINI;


@SpringBootApplication
public class TestAutowireAiServiceToolProviderApplication {

@Bean
ChatModel chatModel() {
return OpenAiChatModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName(GPT_4_O_MINI)
.build();
}

@Bean
public ToolProvider toolProvider() {
return new TestMcpToolProvider();
}

public static void main(String[] args) {
SpringApplication.run(TestAutowireAiServiceToolProviderApplication.class, args);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package dev.langchain4j.service.spring.mode.automatic.issue3073;

import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.spring.AiServicesAutoConfig;
import dev.langchain4j.service.tool.ToolProvider;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;

public class TestAutowrieToolProvider {

ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(AiServicesAutoConfig.class));

@Test
void should_fail_to_create_AI_service_when_conflicting_chat_models_are_found() {
contextRunner
.withUserConfiguration(TestAutowireAiServiceToolProviderApplication.class)
.run(context -> {
ToolProvider toolProviderBean = context.getBean(ToolProvider.class);
Assertions.assertNotNull(toolProviderBean, "ToolProvider bean should be present in the context.");

AiServiceWithToolProvider aiServiceProxy = context.getBean(AiServiceWithToolProvider.class);
Assertions.assertNotNull(aiServiceProxy, "AiServiceWithToolProvider should be created successfully.");

InvocationHandler handler = Proxy.getInvocationHandler(aiServiceProxy);
Assertions.assertNotNull(handler, "InvocationHandler should be present.");
Field this$0 = handler.getClass().getDeclaredField("this$0");
this$0.setAccessible(true);
Object aiServices = this$0.get(handler);

Field contextField = aiServices.getClass().getSuperclass().getDeclaredField("context");
contextField.setAccessible(true);
AiServiceContext aiServiceContext = (AiServiceContext) contextField.get(aiServices);
Assertions.assertNotNull(aiServiceContext, "AiServiceContext should be present.");

Field aiServiceField = aiServiceContext.getClass().getDeclaredField("toolService");
aiServiceField.setAccessible(true);
Object toolService = aiServiceField.get(aiServiceContext);

Field toolProvider = toolService.getClass().getDeclaredField("toolProvider");
toolProvider.setAccessible(true);
Object toolProviderObj = toolProvider.get(toolService);
Assertions.assertInstanceOf(TestMcpToolProvider.class, toolProviderObj, "ToolProvider should be TestMcpToolProvider.");


});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package dev.langchain4j.service.spring.mode.automatic.issue3073;

import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.service.tool.ToolProviderRequest;
import dev.langchain4j.service.tool.ToolProviderResult;

public class TestMcpToolProvider implements ToolProvider {
@Override
public ToolProviderResult provideTools(ToolProviderRequest toolProviderRequest) {
return null;
}
}