Skip to content

Commit

Permalink
refactor: AssertJ best practices (langchain4j#622)
Browse files Browse the repository at this point in the history
Hi! Noticed _almost all_ tests used AssertJ, but in some cases JUnit was
still used. In addition to that some tests don't use the most expressive
assertions. Figured clean that up such that you get better assertions if
any tests were to fail. Compare for instance
```diff
-        assertThat(document.metadata().asMap().size()).isEqualTo(4);
+        assertThat(document.metadata().asMap()).hasSize(4);
```
The first one will print expected 5 to be equal to 4, whereas the second
one shows the contents of the map involved.

Being consistent with your test library also stops bad patterns from
repeating accidentally through copy-and-paste. If you want to enforce
these best practices through an automated pull request check that's also
an option. Let me know if you'd want that as well. Hope that helps!
  • Loading branch information
timtebeek authored Mar 5, 2024
1 parent 677d3e0 commit 5f522e5
Show file tree
Hide file tree
Showing 16 changed files with 87 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public void should_load_single_document() {

// then
assertThat(document.text()).isEqualTo(TEST_CONTENT);
assertThat(document.metadata().asMap().size()).isEqualTo(1);
assertThat(document.metadata().asMap()).hasSize(1);
assertThat(document.metadata("source")).isEqualTo("s3://test-bucket/test-file.txt");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public void should_load_single_document() {
Document document = loader.loadDocument(TEST_CONTAINER, TEST_BLOB, parser);

assertThat(document.text()).isEqualTo(TEST_CONTENT);
assertThat(document.metadata().asMap().size()).isEqualTo(4);
assertThat(document.metadata().asMap()).hasSize(4);
assertThat(document.metadata("source")).endsWith("/test-file.txt");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public void should_load_single_document() {
Document document = loader.loadDocument(TEST_CONTAINER, TEST_BLOB, parser);

assertThat(document.text()).isEqualTo(TEST_CONTENT);
assertThat(document.metadata().asMap().size()).isEqualTo(4);
assertThat(document.metadata().asMap()).hasSize(4);
assertThat(document.metadata("source")).endsWith("/test-file.txt");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public void should_load_file() {
Document document = loader.loadDocument(TEST_OWNER, TEST_REPO, "main", "pom.xml", parser);

assertThat(document.text()).contains("<groupId>dev.langchain4j</groupId>");
assertThat(document.metadata().asMap().size()).isEqualTo(9);
assertThat(document.metadata().asMap()).hasSize(9);
assertThat(document.metadata("github_git_url")).startsWith("https://api.github.com/repos/langchain4j/langchain4j");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
import java.nio.file.Path;
import java.util.Base64;

import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.assertj.core.api.Assertions.assertThat;

public class AzureOpenAiImageModelIT {

Expand All @@ -34,13 +33,13 @@ void should_generate_image_with_url() {
logger.info(response.toString());

Image image = response.content();
assertNotNull(image);
assertNotNull(image.url());
assertThat(image).isNotNull();
assertThat(image.url()).isNotNull();
logger.info("The remote image is here: {}", image.url());

assertNull(image.base64Data());
assertThat(image.base64Data()).isNull();

assertNotNull(image.revisedPrompt());
assertThat(image.revisedPrompt()).isNotNull();
logger.info("The revised prompt is: {}", image.revisedPrompt());
}

Expand All @@ -57,9 +56,9 @@ void should_generate_image_in_base64() throws IOException {
Response<Image> response = model.generate("A croissant in Paris, France");

Image image = response.content();
assertNotNull(image);
assertNull(image.url());
assertNotNull(image.base64Data());
assertThat(image).isNotNull();
assertThat(image.url()).isNull();
assertThat(image.base64Data()).isNotNull();
logger.info("The image data is: {} characters", image.base64Data().length());

if (logger.isDebugEnabled()) {
Expand All @@ -69,7 +68,7 @@ void should_generate_image_in_base64() throws IOException {
logger.debug("The image is here: {}", temp.toAbsolutePath());
}

assertNotNull(image.revisedPrompt());
assertThat(image.revisedPrompt()).isNotNull();
logger.info("The revised prompt is: {}", image.revisedPrompt());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import java.util.Collection;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;

class InternalAzureOpenAiHelperTest {

Expand All @@ -33,7 +34,7 @@ void setupOpenAIClientShouldReturnClientWithCorrectConfiguration() {

OpenAIClient client = InternalAzureOpenAiHelper.setupOpenAIClient(endpoint, serviceVersion, apiKey, timeout, maxRetries, null, logRequestsAndResponses);

assertNotNull(client);
assertThat(client).isNotNull();
}

@Test
Expand All @@ -42,7 +43,7 @@ void getOpenAIServiceVersionShouldReturnCorrectVersion() {

OpenAIServiceVersion version = InternalAzureOpenAiHelper.getOpenAIServiceVersion(serviceVersion);

assertEquals(serviceVersion, version.getVersion());
assertThat(version.getVersion()).isEqualTo(serviceVersion);
}

@Test
Expand All @@ -51,7 +52,7 @@ void getOpenAIServiceVersionShouldReturnLatestVersionIfIncorrect() {

OpenAIServiceVersion version = InternalAzureOpenAiHelper.getOpenAIServiceVersion(serviceVersion);

assertEquals(OpenAIServiceVersion.getLatest().getVersion(), version.getVersion());
assertThat(version.getVersion()).isEqualTo(OpenAIServiceVersion.getLatest().getVersion());
}

@Test
Expand All @@ -61,7 +62,7 @@ void toOpenAiMessagesShouldReturnCorrectMessages() {

List<ChatRequestMessage> openAiMessages = InternalAzureOpenAiHelper.toOpenAiMessages(messages);

assertEquals(messages.size(), openAiMessages.size());
assertThat(openAiMessages).hasSize(messages.size());
assertInstanceOf(ChatRequestUserMessage.class, openAiMessages.get(0));
}

Expand All @@ -76,14 +77,14 @@ void toFunctionsShouldReturnCorrectFunctions() {

List<FunctionDefinition> functions = InternalAzureOpenAiHelper.toFunctions(toolSpecifications);

assertEquals(toolSpecifications.size(), functions.size());
assertEquals(toolSpecifications.iterator().next().name(), functions.get(0).getName());
assertThat(functions).hasSize(toolSpecifications.size());
assertThat(functions.get(0).getName()).isEqualTo(toolSpecifications.iterator().next().name());
}

@Test
void finishReasonFromShouldReturnCorrectFinishReason() {
CompletionsFinishReason completionsFinishReason = CompletionsFinishReason.STOPPED;
FinishReason finishReason = InternalAzureOpenAiHelper.finishReasonFrom(completionsFinishReason);
assertEquals(FinishReason.STOP, finishReason);
assertThat(finishReason).isEqualTo(FinishReason.STOP);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@

import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

@Disabled("AstraDB is not available in the CI")
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
Expand All @@ -48,12 +45,12 @@ class AstraDbEmbeddingStoreIT extends EmbeddingStoreIT {
public static void initStoreForTests() {
AstraDBAdmin astraDBAdminClient = new AstraDBAdmin(getAstraToken());
dbId = astraDBAdminClient.createDatabase(TEST_DB);
assertNotNull(dbId);
assertThat(dbId).isNotNull();
log.info("[init] - Database exists id={}", dbId);

// Select the Database as working object
db = astraDBAdminClient.database(dbId);
assertNotNull(db);
assertThat(db).isNotNull();

AstraDBCollection collection =
db.createCollection(TEST_COLLECTION, 1536, SimilarityMetric.cosine);
Expand Down Expand Up @@ -90,11 +87,11 @@ void testAddEmbeddingAndFindRelevant() {
Embedding embedding = Embedding.from(new float[]{9.9F, 4.5F, 3.5F, 1.3F, 1.7F, 5.7F, 6.4F, 5.5F, 8.2F, 9.3F, 1.5F});
TextSegment textSegment = TextSegment.from("Text", Metadata.from("Key", "Value"));
String id = embeddingStore.add(embedding, textSegment);
assertTrue(id != null && !id.isEmpty());
assertThat(id != null && !id.isEmpty()).isTrue();

Embedding refereceEmbedding = Embedding.from(new float[]{8.7F, 4.5F, 3.4F, 1.2F, 5.5F, 5.6F, 6.4F, 5.5F, 8.1F, 9.1F, 1.1F});
List<EmbeddingMatch<TextSegment>> embeddingMatches = embeddingStore.findRelevant(refereceEmbedding, 1);
assertEquals(1, embeddingMatches.size());
assertThat(embeddingMatches).hasSize(1);

EmbeddingMatch<TextSegment> embeddingMatch = embeddingMatches.get(0);
assertThat(embeddingMatch.score()).isBetween(0d, 1d);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.data.Percentage.withPercentage;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

@Slf4j
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
Expand Down Expand Up @@ -73,10 +71,10 @@ void should_retrieve_inserted_vector_by_ann() {
Embedding sourceEmbedding = embeddingModel().embed(sourceSentence).content();
TextSegment sourceTextSegment = TextSegment.from(sourceSentence);
String id = embeddingStore().add(sourceEmbedding, sourceTextSegment);
assertTrue(id != null && !id.isEmpty());
assertThat(id != null && !id.isEmpty()).isTrue();

List<EmbeddingMatch<TextSegment>> embeddingMatches = embeddingStore.findRelevant(sourceEmbedding, 10);
assertEquals(1, embeddingMatches.size());
assertThat(embeddingMatches).hasSize(1);

EmbeddingMatch<TextSegment> embeddingMatch = embeddingMatches.get(0);
assertThat(embeddingMatch.score()).isBetween(0d, 1d);
Expand All @@ -93,21 +91,21 @@ void should_retrieve_inserted_vector_by_ann_and_metadata() {
.add("user", "GOD")
.add("test", "false"));
String id = embeddingStore().add(sourceEmbedding, sourceTextSegment);
assertTrue(id != null && !id.isEmpty());
assertThat(id != null && !id.isEmpty()).isTrue();

// Should be found with no filter
List<EmbeddingMatch<TextSegment>> matchesAnnOnly = embeddingStore
.findRelevant(sourceEmbedding, 10);
assertEquals(1, matchesAnnOnly.size());
assertThat(matchesAnnOnly).hasSize(1);

// Should retrieve if user is god
List<EmbeddingMatch<TextSegment>> matchesGod = embeddingStore
.findRelevant(sourceEmbedding, 10, .5d, Metadata.from("user", "GOD"));
assertEquals(1, matchesGod.size());
assertThat(matchesGod).hasSize(1);

List<EmbeddingMatch<TextSegment>> matchesJohn = embeddingStore
.findRelevant(sourceEmbedding, 10, .5d, Metadata.from("user", "JOHN"));
assertEquals(0, matchesJohn.size());
assertThat(matchesJohn).isEmpty();
}

// metrics returned are 1.95% off we updated to "withPercentage(2)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import static com.dtsx.astra.sdk.utils.TestUtils.TEST_REGION;
import static com.dtsx.astra.sdk.utils.TestUtils.getAstraToken;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.assertj.core.api.Assertions.assertThat;

/**
* Test Cassandra Chat Memory Store with a Saas DB.
Expand All @@ -25,9 +25,9 @@ class CassandraChatMemoryStoreAstraIT extends CassandraChatMemoryStoreTestSuppor
@Override
void createDatabase() {
token = getAstraToken();
assertNotNull(token);
assertThat(token).isNotNull();
dbId = new AstraDBAdmin(token).createDatabase(DB, CloudProviderType.GCP, "us-east1");
assertNotNull(dbId);
assertThat(dbId).isNotNull();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import dev.langchain4j.memory.chat.TokenWindowChatMemory;
import dev.langchain4j.model.openai.OpenAiTokenizer;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Order;
Expand Down Expand Up @@ -42,10 +41,9 @@ void shouldConnectToDatabase() {
chatMemoryStore = createChatMemoryStore();
log.info("Chat memory store is created.");
// Connection to Cassandra is established
Assertions.assertTrue(chatMemoryStore.getCassandraSession()
assertThat(chatMemoryStore.getCassandraSession()
.getMetadata()
.getKeyspace(KEYSPACE)
.isPresent());
.getKeyspace(KEYSPACE)).isPresent();
log.info("Chat memory table is present.");
}

Expand All @@ -55,10 +53,10 @@ void shouldConnectToDatabase() {
void shouldCreateChatMemoryStore() {
chatMemoryStore.create();
// Table exists
Assertions.assertTrue(chatMemoryStore.getCassandraSession()
assertThat(chatMemoryStore.getCassandraSession()
.refreshSchema()
.getKeyspace(KEYSPACE).get()
.getTable(CassandraChatMemoryStore.DEFAULT_TABLE_NAME).isPresent());
.getTable(CassandraChatMemoryStore.DEFAULT_TABLE_NAME)).isPresent();
chatMemoryStore.clear();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002;
import static java.time.Duration.ofSeconds;
import static java.util.stream.Collectors.joining;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.assertj.core.api.Assertions.assertThat;

class DocumentLoaderAndRagWithAstraTest {

Expand All @@ -54,11 +54,11 @@ void shouldRagWithOpenAiAndAstra() {

// Database Id
UUID databaseId = new AstraDBAdmin(getAstraToken()).createDatabase(DB_NAME);
assertNotNull(databaseId);
assertThat(databaseId).isNotNull();

// OpenAI Key
String openAIKey = System.getenv("OPENAI_API_KEY");
assertNotNull(openAIKey);
assertThat(openAIKey).isNotNull();

// --- Documents Ingestion ---

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import static dev.langchain4j.model.openai.OpenAiModelName.TEXT_EMBEDDING_ADA_002;
import static java.time.Duration.ofSeconds;
import static java.util.stream.Collectors.joining;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.assertj.core.api.Assertions.assertThat;

public class WebPageLoaderAndRagWIthAstraTest {

Expand All @@ -57,11 +57,11 @@ void shouldRagWithOpenAiAndAstra() throws IOException {

// Database Id
UUID databaseId = new AstraDBAdmin(getAstraToken()).createDatabase(DB_NAME);
assertNotNull(databaseId);
assertThat(databaseId).isNotNull();

// OpenAI Key
String openAIKey = System.getenv("OPENAI_API_KEY");
assertNotNull(openAIKey);
assertThat(openAIKey).isNotNull();

// --- Documents Ingestion ---

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import java.util.List;
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.fail;

public class QwenTestHelper {
public static Stream<Arguments> languageModelNameProvider() {
Expand Down Expand Up @@ -85,14 +85,14 @@ public static List<ChatMessage> multimodalChatMessagesWithImageData() {
public static String multimodalImageData() {
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
try (InputStream in = QwenTestHelper.class.getResourceAsStream("/parrot.jpg")) {
assertNotNull(in);
assertThat(in).isNotNull();
byte[] data = new byte[512];
int n;
while ((n = in.read(data)) != -1) {
buffer.write(data, 0, n);
}
} catch (IOException e) {
fail(e.getMessage());
fail("", e.getMessage());
}

return Base64.getEncoder().encodeToString(buffer.toByteArray());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void should_embed_and_return_token_usage_with_multiple_inputs() {
Response<List<Embedding>> response = model.embedAll(asList(textSegment1, textSegment2));

// then
assertThat(response.content().size()).isEqualTo(2);
assertThat(response.content()).hasSize(2);
assertThat(response.content().get(0).vector()).hasSize(1024);
assertThat(response.content().get(1).vector()).hasSize(1024);

Expand Down
Loading

0 comments on commit 5f522e5

Please sign in to comment.