Skip to content

Commit

Permalink
VertexAI: run ITs if environment variable is available
Browse files Browse the repository at this point in the history
  • Loading branch information
langchain4j committed Jan 26, 2024
1 parent 9ba17b5 commit 96c5705
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import static org.assertj.core.api.Assertions.assertThat;

@Disabled("To run this test, you must provide your own endpoint, project and location")
@EnabledIfEnvironmentVariable(named = "VERTEXAI_ENDPOINT", matches = ".+")
class VertexAiChatModelIT {

@Test
void testChatModel() {

VertexAiChatModel vertexAiChatModel = VertexAiChatModel.builder()
.endpoint("us-central1-aiplatform.googleapis.com:443")
.endpoint(System.getenv("VERTEXAI_ENDPOINT"))
.project("langchain4j")
.location("us-central1")
.publisher("google")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,23 @@
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import java.util.Arrays;
import java.util.List;

import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;

@Disabled("To run this test, you must provide your own endpoint, project and location")
@EnabledIfEnvironmentVariable(named = "VERTEXAI_ENDPOINT", matches = ".+")
class VertexAiEmbeddingModelIT {

@Test
void testEmbeddingModel() {

EmbeddingModel embeddingModel = VertexAiEmbeddingModel.builder()
.endpoint("us-central1-aiplatform.googleapis.com:443")
.endpoint(System.getenv("VERTEXAI_ENDPOINT"))
.project("langchain4j")
.location("us-central1")
.publisher("google")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import dev.langchain4j.data.image.Image;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import java.io.File;
import java.io.IOException;
Expand All @@ -20,10 +20,10 @@
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;

@Disabled("To run this test, you must provide your own endpoint, project and location")
@EnabledIfEnvironmentVariable(named = "VERTEXAI_ENDPOINT", matches = ".+")
public class VertexAiImageModelIT {

private static final String ENDPOINT = "us-central1-aiplatform.googleapis.com:443";
private static final String ENDPOINT = System.getenv("VERTEXAI_ENDPOINT");
private static final String LOCATION = "us-central1";
private static final String PROJECT = "langchain4j";
private static final String PUBLISHER = "google";
Expand All @@ -33,9 +33,9 @@ private static Image fromPath(Path path) {
byte[] allBytes = Files.readAllBytes(path);
String base64 = Base64.getEncoder().encodeToString(allBytes);
return Image.builder()
.url(path.toUri())
.base64Data(base64)
.build();
.url(path.toUri())
.base64Data(base64)
.build();
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand All @@ -44,14 +44,14 @@ private static Image fromPath(Path path) {
@Test
public void should_generate_one_image_with_persistence() {
VertexAiImageModel imagenModel = VertexAiImageModel.builder()
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@005")
.maxRetries(2)
.withPersisting()
.build();
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@005")
.maxRetries(2)
.withPersisting()
.build();

Response<Image> imageResponse = imagenModel.generate("watercolor of a colorful parrot drinking a cup of coffee");
System.out.println(imageResponse.content().url());
Expand All @@ -68,13 +68,13 @@ public void should_generate_one_image_with_persistence() {
@Test
public void should_generate_three_images_with_persistence() {
VertexAiImageModel imagenModel = VertexAiImageModel.builder()
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@005")
.withPersisting()
.build();
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@005")
.withPersisting()
.build();

Response<List<Image>> imageListResponse = imagenModel.generate("photo of a sunset over Malibu beach", 3);

Expand All @@ -89,17 +89,17 @@ public void should_generate_three_images_with_persistence() {
@Test
public void should_use_image_style_seed_image_source_and_mask_for_editing() throws URISyntaxException {
VertexAiImageModel model = VertexAiImageModel.builder()
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@002")
.seed(19707L)
.sampleImageStyle(VertexAiImageModel.ImageStyle.photograph)
.guidanceScale(100)
.maxRetries(4)
.withPersisting()
.build();
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@002")
.seed(19707L)
.sampleImageStyle(VertexAiImageModel.ImageStyle.photograph)
.guidanceScale(100)
.maxRetries(4)
.withPersisting()
.build();

Response<Image> forestResp = model.generate("lush forest");
System.out.println(forestResp.content().url());
Expand All @@ -109,7 +109,7 @@ public void should_use_image_style_seed_image_source_and_mask_for_editing() thro
URI maskFileUri = Objects.requireNonNull(getClass().getClassLoader().getResource("mask.png")).toURI();

Response<Image> compositeResp = model.edit(
forestResp.content(), fromPath(Paths.get(maskFileUri)), "red trees"
forestResp.content(), fromPath(Paths.get(maskFileUri)), "red trees"
);
System.out.println(compositeResp.content().url());

Expand All @@ -121,39 +121,39 @@ public void should_use_persistTo_and_image_upscaling() {
Path defaultTempDirPath = Paths.get(System.getProperty("java.io.tmpdir"));

VertexAiImageModel imagenModel = VertexAiImageModel.builder()
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@002")
.sampleImageSize(1024)
.withPersisting()
.persistTo(defaultTempDirPath)
.maxRetries(3)
.build();
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@002")
.sampleImageSize(1024)
.withPersisting()
.persistTo(defaultTempDirPath)
.maxRetries(3)
.build();

Response<Image> imageResponse =
imagenModel.generate("A black bird looking itself in an antique mirror");
imagenModel.generate("A black bird looking itself in an antique mirror");
System.out.println(imageResponse.content().url());

assertThat(imageResponse.content().url()).isNotNull();
assertThat(new File(imageResponse.content().url())).exists();
assertThat(imageResponse.content().base64Data()).isNotNull();

VertexAiImageModel imagenModelForUpscaling = VertexAiImageModel.builder()
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@002")
.sampleImageSize(4096)
.withPersisting()
.persistTo(defaultTempDirPath)
.maxRetries(3)
.build();
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@002")
.sampleImageSize(4096)
.withPersisting()
.persistTo(defaultTempDirPath)
.maxRetries(3)
.build();

Response<Image> upscaledImageResponse =
imagenModelForUpscaling.edit(imageResponse.content(), "");
imagenModelForUpscaling.edit(imageResponse.content(), "");
System.out.println(upscaledImageResponse.content().url());

assertThat(upscaledImageResponse.content().url()).isNotNull();
Expand All @@ -164,16 +164,16 @@ public void should_use_persistTo_and_image_upscaling() {
@Test
public void should_use_negative_prompt_and_different_prompt_language() {
VertexAiImageModel imagenModel = VertexAiImageModel.builder()
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@005")
.language("ja")
.negativePrompt("pepperoni, pineapple")
.maxRetries(2)
.withPersisting()
.build();
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@005")
.language("ja")
.negativePrompt("pepperoni, pineapple")
.maxRetries(2)
.withPersisting()
.build();

Response<Image> imageResponse = imagenModel.generate("ピザ"); // pizza
System.out.println(imageResponse.content().url());
Expand All @@ -185,13 +185,13 @@ public void should_use_negative_prompt_and_different_prompt_language() {
@Test
public void should_raise_error_on_problematic_prompt_or_content_generation() {
VertexAiImageModel imagenModel = VertexAiImageModel.builder()
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@005")
.withPersisting()
.build();
.endpoint(ENDPOINT)
.location(LOCATION)
.project(PROJECT)
.publisher(PUBLISHER)
.modelName("imagegeneration@005")
.withPersisting()
.build();

assertThrows(Throwable.class, () -> imagenModel.generate("a nude woman"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@

import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;

import static org.assertj.core.api.Assertions.assertThat;

@Disabled("To run this test, you must provide your own endpoint, project and location")
@EnabledIfEnvironmentVariable(named = "VERTEXAI_ENDPOINT", matches = ".+")
class VertexAiLanguageModelIT {

@Test
void testLanguageModel() {
VertexAiLanguageModel vertexAiLanguageModel = VertexAiLanguageModel.builder()
.endpoint("us-central1-aiplatform.googleapis.com:443")
.endpoint(System.getenv("VERTEXAI_ENDPOINT"))
.project("langchain4j")
.location("us-central1")
.publisher("google")
Expand Down

0 comments on commit 96c5705

Please sign in to comment.