Removed duplicated logic between model classes (langchain4j#47)
langchain4j authored Jul 24, 2023
1 parent 540741c commit 8d6f9f6
Showing 16 changed files with 85 additions and 223 deletions.
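One pattern repeats across all 16 files: the convenience overloads that every model class used to implement by hand (String, UserMessage, structured-prompt and varargs variants) become default methods on the interfaces, each delegating to a single remaining abstract method, and the implementations drop their copies of that forwarding logic. That is where most of the 223 deleted lines come from. Below is a minimal sketch of the pattern, not code from the repository; GreetingModel, EchoGreetingModel and DefaultMethodSketch are made-up names for illustration.

import java.util.Arrays;
import java.util.List;

// Illustrative interface in the spirit of this commit: one abstract method,
// with the convenience overloads provided once as default methods.
interface GreetingModel {

    String sendMessages(List<String> messages);

    default String sendMessages(String... messages) {
        return sendMessages(Arrays.asList(messages));
    }

    default String sendUserMessage(String userMessage) {
        return sendMessages(userMessage);
    }
}

// Implementations now override only the primitive method.
class EchoGreetingModel implements GreetingModel {
    @Override
    public String sendMessages(List<String> messages) {
        return String.join(" | ", messages);
    }
}

public class DefaultMethodSketch {
    public static void main(String[] args) {
        GreetingModel model = new EchoGreetingModel();
        System.out.println(model.sendUserMessage("hello"));  // hello
        System.out.println(model.sendMessages("a", "b"));    // a | b
    }
}

Because the defaults are inherited, implementations that previously overrode these overloads keep compiling after deleting them, and callers see no API change.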
@@ -4,9 +4,13 @@
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.input.Prompt;

import java.util.List;

import static dev.langchain4j.model.input.structured.StructuredPromptProcessor.toPrompt;
import static java.util.Arrays.asList;

/**
* Represents a LLM that has a chat interface.
*/
@@ -18,19 +22,28 @@ public interface ChatLanguageModel {
* @param userMessage User message as a String. Will be wrapped into {@link dev.langchain4j.data.message.UserMessage UserMessage} under the hood.
* @return {@link dev.langchain4j.data.message.AiMessage AiMessage}
*/
AiMessage sendUserMessage(String userMessage);
default AiMessage sendUserMessage(String userMessage) {
return sendUserMessage(UserMessage.from(userMessage));
}

AiMessage sendUserMessage(UserMessage userMessage);
default AiMessage sendUserMessage(UserMessage userMessage) {
return sendMessages(userMessage);
}

/**
* Sends a structured prompt as a user message to the LLM and returns response.
*
* @param structuredPrompt object annotated with {@link dev.langchain4j.model.input.structured.StructuredPrompt @StructuredPrompt}
* @return {@link dev.langchain4j.data.message.AiMessage AiMessage}
*/
AiMessage sendUserMessage(Object structuredPrompt);

AiMessage sendMessages(ChatMessage... messages);
default AiMessage sendUserMessage(Object structuredPrompt) {
Prompt prompt = toPrompt(structuredPrompt);
return sendUserMessage(prompt.toUserMessage());
}

default AiMessage sendMessages(ChatMessage... messages) {
return sendMessages(asList(messages));
}

AiMessage sendMessages(List<ChatMessage> messages);
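Caller-side code is unaffected by this hunk: the String and structured-prompt overloads still exist, they just resolve to default methods that funnel into sendMessages(List<ChatMessage>). A usage sketch, assuming the OpenAI module is on the classpath and an API key is available (withApiKey appears further down in this same commit):

ChatLanguageModel model = OpenAiChatModel.withApiKey(System.getenv("OPENAI_API_KEY"));
AiMessage answer = model.sendUserMessage("Hello");  // routed through the default methods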

@@ -4,16 +4,28 @@
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.input.Prompt;

import java.util.List;

import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.input.structured.StructuredPromptProcessor.toPrompt;
import static java.util.Collections.singletonList;

public interface StreamingChatLanguageModel {

void sendUserMessage(String userMessage, StreamingResponseHandler handler);
default void sendUserMessage(String userMessage, StreamingResponseHandler handler) {
sendUserMessage(userMessage(userMessage), handler);
}

void sendUserMessage(UserMessage userMessage, StreamingResponseHandler handler);
default void sendUserMessage(UserMessage userMessage, StreamingResponseHandler handler) {
sendMessages(singletonList(userMessage), handler);
}

void sendUserMessage(Object structuredPrompt, StreamingResponseHandler handler);
default void sendUserMessage(Object structuredPrompt, StreamingResponseHandler handler) {
Prompt prompt = toPrompt(structuredPrompt);
sendUserMessage(prompt.toUserMessage(), handler);
}

void sendMessages(List<ChatMessage> messages, StreamingResponseHandler handler);

@@ -8,19 +8,33 @@

import java.util.List;

import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.model.input.structured.StructuredPromptProcessor.toPrompt;
import static java.util.Collections.singletonList;

public interface TokenCountEstimator {

int estimateTokenCount(String text);
default int estimateTokenCount(String text) {
return estimateTokenCount(userMessage(text));
}

int estimateTokenCount(UserMessage userMessage);
default int estimateTokenCount(UserMessage userMessage) {
return estimateTokenCount(singletonList(userMessage));
}

@MightChangeInTheFuture("not sure this method is useful/needed")
int estimateTokenCount(Prompt prompt);
default int estimateTokenCount(Prompt prompt) {
return estimateTokenCount(prompt.text());
}

@MightChangeInTheFuture("not sure this method is useful/needed")
int estimateTokenCount(Object structuredPrompt);
default int estimateTokenCount(Object structuredPrompt) {
return estimateTokenCount(toPrompt(structuredPrompt));
}

int estimateTokenCount(List<ChatMessage> messages);

int estimateTokenCount(TextSegment textSegment);
default int estimateTokenCount(TextSegment textSegment) {
return estimateTokenCount(textSegment.text());
}
}
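Read together, the default methods in this chat-side TokenCountEstimator form one delegation chain, and only the List<ChatMessage> variant stays abstract:

estimateTokenCount(String)        -> estimateTokenCount(userMessage(text))
estimateTokenCount(UserMessage)   -> estimateTokenCount(singletonList(userMessage))
estimateTokenCount(Prompt)        -> estimateTokenCount(prompt.text())
estimateTokenCount(Object)        -> estimateTokenCount(toPrompt(structuredPrompt))
estimateTokenCount(TextSegment)   -> estimateTokenCount(textSegment.text())
estimateTokenCount(List<ChatMessage>)  (abstract; implemented per provider)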
@@ -8,7 +8,9 @@ public interface TokenCountEstimator {

int estimateTokenCount(String text);

int estimateTokenCount(TextSegment textSegment);
default int estimateTokenCount(TextSegment textSegment) {
return estimateTokenCount(textSegment.text());
}

int estimateTokenCount(List<TextSegment> textSegments);
}
@@ -2,6 +2,8 @@

import dev.langchain4j.model.input.Prompt;

import static dev.langchain4j.model.input.structured.StructuredPromptProcessor.toPrompt;

/**
* Represents a LLM with a simple text interface.
* It is recommended to use the ChatLanguageModel instead, as it offers greater capabilities.
@@ -11,7 +13,11 @@ public interface LanguageModel {

String process(String text);

String process(Prompt prompt);
default String process(Prompt prompt) {
return process(prompt.text());
}

String process(Object structuredPrompt);
default String process(Object structuredPrompt) {
return process(toPrompt(structuredPrompt));
}
}
@@ -3,11 +3,17 @@
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.input.Prompt;

import static dev.langchain4j.model.input.structured.StructuredPromptProcessor.toPrompt;

public interface StreamingLanguageModel {

void process(String text, StreamingResponseHandler handler);

void process(Prompt prompt, StreamingResponseHandler handler);
default void process(Prompt prompt, StreamingResponseHandler handler) {
process(prompt.text(), handler);
}

void process(Object structuredPrompt, StreamingResponseHandler handler);
default void process(Object structuredPrompt, StreamingResponseHandler handler) {
process(toPrompt(structuredPrompt), handler);
}
}
@@ -3,13 +3,21 @@
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.input.Prompt;

import static dev.langchain4j.model.input.structured.StructuredPromptProcessor.toPrompt;

public interface TokenCountEstimator {

int estimateTokenCount(String text);

int estimateTokenCount(Prompt prompt);
default int estimateTokenCount(Prompt prompt) {
return estimateTokenCount(prompt.text());
}

int estimateTokenCount(Object structuredPrompt);
default int estimateTokenCount(Object structuredPrompt) {
return estimateTokenCount(toPrompt(structuredPrompt));
}

int estimateTokenCount(TextSegment textSegment);
default int estimateTokenCount(TextSegment textSegment) {
return estimateTokenCount(textSegment.text());
}
}
@@ -1,6 +1,6 @@
package dev.langchain4j.model.input;
package dev.langchain4j.model.input.structured;

import dev.langchain4j.model.input.structured.StructuredPrompt;
import dev.langchain4j.model.input.Prompt;
import org.junit.jupiter.api.Test;

import java.util.List;
@@ -3,19 +3,14 @@
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;

import java.time.Duration;
import java.util.List;

import static dev.langchain4j.data.message.AiMessage.aiMessage;
import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.model.huggingface.HuggingFaceModelName.TII_UAE_FALCON_7B_INSTRUCT;
import static dev.langchain4j.model.input.structured.StructuredPromptProcessor.toPrompt;
import static java.util.Arrays.asList;
import static java.util.stream.Collectors.joining;

public class HuggingFaceChatModel implements ChatLanguageModel {
@@ -51,27 +46,6 @@ public HuggingFaceChatModel(Builder builder) {
this.waitForModel = builder.waitForModel;
}

@Override
public AiMessage sendUserMessage(String userMessage) {
return sendUserMessage(userMessage(userMessage));
}

@Override
public AiMessage sendUserMessage(UserMessage userMessage) {
return sendMessages(userMessage);
}

@Override
public AiMessage sendUserMessage(Object structuredPrompt) {
Prompt prompt = toPrompt(structuredPrompt);
return sendUserMessage(prompt.toUserMessage());
}

@Override
public AiMessage sendMessages(ChatMessage... messages) {
return sendMessages(asList(messages));
}

@Override
public AiMessage sendMessages(List<ChatMessage> messages) {
return sendMessages(messages, null);
@@ -1,12 +1,10 @@
package dev.langchain4j.model.huggingface;

import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.language.LanguageModel;

import java.time.Duration;

import static dev.langchain4j.model.huggingface.HuggingFaceModelName.TII_UAE_FALCON_7B_INSTRUCT;
import static dev.langchain4j.model.input.structured.StructuredPromptProcessor.toPrompt;

public class HuggingFaceLanguageModel implements LanguageModel {

@@ -62,17 +60,6 @@ public String process(String text) {
return response.generatedText();
}

@Override
public String process(Prompt prompt) {
return this.process(prompt.text());
}

@Override
public String process(Object structuredPrompt) {
return process(toPrompt(structuredPrompt));
}


public static Builder builder() {
return new Builder();
}
@@ -6,23 +6,16 @@
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.input.Prompt;
import lombok.Builder;

import java.time.Duration;
import java.util.List;

import static dev.langchain4j.data.message.UserMessage.userMessage;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.model.input.structured.StructuredPromptProcessor.toPrompt;
import static dev.langchain4j.model.openai.OpenAiHelper.*;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;

public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {

@@ -79,27 +72,6 @@ public OpenAiChatModel(String apiKey,
this.tokenizer = new OpenAiTokenizer(this.modelName);
}

@Override
public AiMessage sendUserMessage(String userMessage) {
return sendUserMessage(userMessage(userMessage));
}

@Override
public AiMessage sendUserMessage(UserMessage userMessage) {
return sendMessages(userMessage);
}

@Override
public AiMessage sendUserMessage(Object structuredPrompt) {
Prompt prompt = toPrompt(structuredPrompt);
return sendUserMessage(prompt.toUserMessage());
}

@Override
public AiMessage sendMessages(ChatMessage... messages) {
return sendMessages(asList(messages));
}

@Override
public AiMessage sendMessages(List<ChatMessage> messages) {
return sendMessages(messages, null);
@@ -124,36 +96,11 @@ public AiMessage sendMessages(List<ChatMessage> messages, List<ToolSpecification
return aiMessageFrom(response);
}

@Override
public int estimateTokenCount(String text) {
return estimateTokenCount(userMessage(text));
}

@Override
public int estimateTokenCount(UserMessage userMessage) {
return estimateTokenCount(singletonList(userMessage));
}

@Override
public int estimateTokenCount(Prompt prompt) {
return estimateTokenCount(prompt.text());
}

@Override
public int estimateTokenCount(Object structuredPrompt) {
return estimateTokenCount(toPrompt(structuredPrompt));
}

@Override
public int estimateTokenCount(List<ChatMessage> messages) {
return tokenizer.countTokens(messages);
}

@Override
public int estimateTokenCount(TextSegment textSegment) {
return estimateTokenCount(textSegment.text());
}

public static OpenAiChatModel withApiKey(String apiKey) {
return builder().apiKey(apiKey).build();
}
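After this hunk, OpenAiChatModel keeps only the provider-specific pieces: sendMessages(List<ChatMessage>) and its tool-specification overload for the actual API call, plus estimateTokenCount(List<ChatMessage>) backed by tokenizer.countTokens(messages); every other overload now comes from the interface defaults. A custom ChatLanguageModel shrinks accordingly. A hedged sketch follows: the class name is made up, and it assumes the hunks above show all remaining abstract methods, though the full interface may declare further overloads not visible here.

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;

import java.util.List;

import static dev.langchain4j.data.message.AiMessage.aiMessage;

// Test double that always answers the same thing; sendUserMessage(String),
// sendUserMessage(Object) and sendMessages(ChatMessage...) all come from the
// interface defaults introduced in this commit.
class CannedChatModel implements ChatLanguageModel {

    @Override
    public AiMessage sendMessages(List<ChatMessage> messages) {
        return aiMessage("pong");
    }
}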
@@ -86,11 +86,6 @@ public int estimateTokenCount(String text) {
return tokenizer.countTokens(text);
}

@Override
public int estimateTokenCount(TextSegment textSegment) {
return estimateTokenCount(textSegment.text());
}

@Override
public int estimateTokenCount(List<TextSegment> textSegments) {
int tokenCount = 0;