Skip to content

Commit

Permalink
Removed Result from model classes (langchain4j#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
langchain4j authored Jul 16, 2023
1 parent 77a767a commit 7d36b0c
Show file tree
Hide file tree
Showing 30 changed files with 265 additions and 293 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ See example [here](https://github.com/langchain4j/langchain4j-examples/blob/main
```java
OpenAiChatModel model = OpenAiChatModel.withApiKey(apiKey);
AiMessage answer = model.sendUserMessage("Hello world!").get();
AiMessage answer = model.sendUserMessage("Hello world!");
System.out.println(answer.text()); // Hello! How can I assist you today?
```
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,38 @@
package dev.langchain4j.model.chat;

import dev.langchain4j.MightChangeInTheFuture;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.output.Result;
import dev.langchain4j.data.message.UserMessage;

import java.util.List;

/**
* Represents a LLM that has a chat interface.
*/
public interface ChatLanguageModel {

Result<AiMessage> sendUserMessage(String text);
/**
* Sends a message from a user to the LLM and returns response.
*
* @param userMessage User message as a String. Will be wrapped into {@link dev.langchain4j.data.message.UserMessage UserMessage} under the hood.
* @return {@link dev.langchain4j.data.message.AiMessage AiMessage}
*/
AiMessage sendUserMessage(String userMessage);

@MightChangeInTheFuture("not sure this method is useful/needed")
Result<AiMessage> sendUserMessage(Prompt prompt);
AiMessage sendUserMessage(UserMessage userMessage);

@MightChangeInTheFuture("not sure this method is useful/needed")
Result<AiMessage> sendUserMessage(Object structuredPrompt);
/**
* Sends a structured prompt as a user message to the LLM and returns response.
*
* @param structuredPrompt object annotated with {@link dev.langchain4j.model.input.structured.StructuredPrompt @StructuredPrompt}
* @return {@link dev.langchain4j.data.message.AiMessage AiMessage}
*/
AiMessage sendUserMessage(Object structuredPrompt);

Result<AiMessage> sendMessages(ChatMessage... messages);
AiMessage sendMessages(ChatMessage... messages);

Result<AiMessage> sendMessages(List<ChatMessage> messages);
AiMessage sendMessages(List<ChatMessage> messages);

Result<AiMessage> sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications);
AiMessage sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,11 @@
public interface StreamingChatLanguageModel {

@WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API")
void sendUserMessage(String text, StreamingResultHandler handler);
void sendUserMessage(String userMessage, StreamingResultHandler handler);

@WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API")
void sendUserMessage(UserMessage userMessage, StreamingResultHandler handler);

@WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API")
void sendUserMessage(Prompt prompt, StreamingResultHandler handler);

@WillChangeSoon("Most probably StreamingResultHandler will be replaced with fluent API")
void sendUserMessage(Object structuredPrompt, StreamingResultHandler handler);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.output.Result;

import java.util.List;

public interface EmbeddingModel {

Result<Embedding> embed(String text);
Embedding embed(String text);

Result<Embedding> embed(TextSegment textSegment);
Embedding embed(TextSegment textSegment);

Result<List<Embedding>> embedAll(List<TextSegment> textSegments);
List<Embedding> embedAll(List<TextSegment> textSegments);
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package dev.langchain4j.model.language;

import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.output.Result;

public interface LanguageModel {

Result<String> process(String text);
String process(String text);

Result<String> process(Prompt prompt);
String process(Prompt prompt);

Result<String> process(Object structuredPrompt);
String process(Object structuredPrompt);
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,20 @@
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.output.Result;

import java.util.List;

public interface ModerationModel {

Result<Moderation> moderate(String text);
Moderation moderate(String text);

Result<Moderation> moderate(Prompt prompt);
Moderation moderate(Prompt prompt);

Result<Moderation> moderate(Object structuredPrompt);
Moderation moderate(Object structuredPrompt);

Result<Moderation> moderate(ChatMessage message);
Moderation moderate(ChatMessage message);

Result<Moderation> moderate(List<ChatMessage> messages);
Moderation moderate(List<ChatMessage> messages);

Result<Moderation> moderate(TextSegment textSegment);
Moderation moderate(TextSegment textSegment);
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public String execute(String userMessage) {

chatMemory.add(userMessage(ensureNotBlank(userMessage, "userMessage")));

AiMessage aiMessage = chatLanguageModel.sendMessages(chatMemory.messages()).get();
AiMessage aiMessage = chatLanguageModel.sendMessages(chatMemory.messages());

chatMemory.add(aiMessage);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public String execute(String question) {

chatMemory.add(userMessage);

AiMessage answer = chatLanguageModel.sendMessages(chatMemory.messages()).get();
AiMessage answer = chatLanguageModel.sendMessages(chatMemory.messages());

chatMemory.add(answer);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.output.Result;

import java.time.Duration;
import java.util.List;

import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.model.huggingface.HuggingFaceModelName.TII_UAE_FALCON_7B_INSTRUCT;
import static dev.langchain4j.model.input.structured.StructuredPromptProcessor.toPrompt;
import static java.util.Arrays.asList;
Expand Down Expand Up @@ -51,32 +52,33 @@ public HuggingFaceChatModel(Builder builder) {
}

@Override
public Result<AiMessage> sendUserMessage(String text) {
return sendMessages(userMessage(text));
public AiMessage sendUserMessage(String userMessage) {
return sendUserMessage(userMessage(userMessage));
}

@Override
public Result<AiMessage> sendUserMessage(Prompt userMessage) {
return sendUserMessage(userMessage.text());
public AiMessage sendUserMessage(UserMessage userMessage) {
return sendMessages(userMessage);
}

@Override
public Result<AiMessage> sendUserMessage(Object structuredPrompt) {
return sendUserMessage(toPrompt(structuredPrompt));
public AiMessage sendUserMessage(Object structuredPrompt) {
Prompt prompt = toPrompt(structuredPrompt);
return sendUserMessage(prompt.toUserMessage());
}

@Override
public Result<AiMessage> sendMessages(ChatMessage... messages) {
public AiMessage sendMessages(ChatMessage... messages) {
return sendMessages(asList(messages));
}

@Override
public Result<AiMessage> sendMessages(List<ChatMessage> messages) {
public AiMessage sendMessages(List<ChatMessage> messages) {
return sendMessages(messages, null);
}

@Override
public Result<AiMessage> sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
public AiMessage sendMessages(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {

if (toolSpecifications != null && toolSpecifications.size() > 0) {
throw new IllegalArgumentException("Tools are currently not supported for HuggingFace models");
Expand All @@ -98,9 +100,7 @@ public Result<AiMessage> sendMessages(List<ChatMessage> messages, List<ToolSpeci

TextGenerationResponse textGenerationResponse = client.chat(request);

AiMessage aiMessage = aiMessage(textGenerationResponse.generatedText());

return Result.from(aiMessage);
return aiMessage(textGenerationResponse.generatedText());
}

public static Builder builder() {
Expand Down Expand Up @@ -161,10 +161,14 @@ public Builder waitForModel(Boolean waitForModel) {
}

public HuggingFaceChatModel build() {
if (accessToken == null || accessToken.trim().isEmpty()) {
if (isNullOrBlank(accessToken)) {
throw new IllegalArgumentException("HuggingFace access token must be defined. It can be generated here: https://huggingface.co/settings/tokens");
}
return new HuggingFaceChatModel(this);
}
}

public static HuggingFaceChatModel withAccessToken(String accessToken) {
return builder().accessToken(accessToken).build();
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package dev.langchain4j.model.huggingface;

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Result;
import lombok.Builder;

import java.time.Duration;
Expand Down Expand Up @@ -34,18 +33,18 @@ public HuggingFaceEmbeddingModel(String accessToken, String modelId, Boolean wai
}

@Override
public Result<Embedding> embed(String text) {
Result<List<Embedding>> result = embedTexts(singletonList(text));
return Result.from(result.get().get(0));
public Embedding embed(String text) {
List<Embedding> embeddings = embedTexts(singletonList(text));
return embeddings.get(0);
}

@Override
public Result<Embedding> embed(TextSegment textSegment) {
public Embedding embed(TextSegment textSegment) {
return embed(textSegment.text());
}

@Override
public Result<List<Embedding>> embedAll(List<TextSegment> textSegments) {
public List<Embedding> embedAll(List<TextSegment> textSegments) {

List<String> texts = textSegments.stream()
.map(TextSegment::text)
Expand All @@ -54,17 +53,15 @@ public Result<List<Embedding>> embedAll(List<TextSegment> textSegments) {
return embedTexts(texts);
}

private Result<List<Embedding>> embedTexts(List<String> texts) {
private List<Embedding> embedTexts(List<String> texts) {

EmbeddingRequest request = new EmbeddingRequest(texts, waitForModel);

List<float[]> response = client.embed(request);

List<Embedding> embeddings = response.stream()
return response.stream()
.map(Embedding::from)
.collect(toList());

return Result.from(embeddings);
}

public static HuggingFaceEmbeddingModel withAccessToken(String accessToken) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.language.LanguageModel;
import dev.langchain4j.model.output.Result;

import java.time.Duration;

Expand Down Expand Up @@ -44,7 +43,7 @@ public HuggingFaceLanguageModel(Builder builder) {
}

@Override
public Result<String> process(String text) {
public String process(String text) {

TextGenerationRequest request = TextGenerationRequest.builder()
.inputs(text)
Expand All @@ -60,16 +59,16 @@ public Result<String> process(String text) {

TextGenerationResponse response = client.generate(request);

return Result.from(response.generatedText());
return response.generatedText();
}

@Override
public Result<String> process(Prompt prompt) {
public String process(Prompt prompt) {
return this.process(prompt.text());
}

@Override
public Result<String> process(Object structuredPrompt) {
public String process(Object structuredPrompt) {
return process(toPrompt(structuredPrompt));
}

Expand Down Expand Up @@ -138,4 +137,8 @@ public HuggingFaceLanguageModel build() {
return new HuggingFaceLanguageModel(this);
}
}

public static HuggingFaceLanguageModel withAccessToken(String accessToken) {
return builder().accessToken(accessToken).build();
}
}
Loading

0 comments on commit 7d36b0c

Please sign in to comment.