Skip to content

Commit

Permalink
fix(vertex_ai): Fix typo in stop sequences field deserialization (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz committed Sep 5, 2023
1 parent 8a2199e commit 4f7161d
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 21 deletions.
5 changes: 3 additions & 2 deletions packages/vertex_ai/lib/src/gen_ai/models/model.dart
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ class VertexAIPredictionSafetyAttributes {
final Map<String, dynamic> safetyAttributesJson,
) {
return VertexAIPredictionSafetyAttributes(
categories: (safetyAttributesJson['categories'] as List<dynamic>)
categories: (safetyAttributesJson['categories'] as List<dynamic>? ??
const [])
.map(
(final category) => switch (category) {
'Derogatory' =>
Expand Down Expand Up @@ -142,7 +143,7 @@ class VertexAIPredictionSafetyAttributes {
.toList(growable: false),
scores:
(safetyAttributesJson['scores'] as List<dynamic>? ?? const []).cast(),
blocked: safetyAttributesJson['blocked'] as bool,
blocked: safetyAttributesJson['blocked'] as bool? ?? false,
);
}

Expand Down
16 changes: 8 additions & 8 deletions packages/vertex_ai/lib/src/gen_ai/models/text.dart
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class VertexAITextModelRequestParams {
this.maxOutputTokens = 1024,
this.topP = 0.95,
this.topK = 40,
this.stopSequence = const [],
this.stopSequences = const [],
this.candidateCount = 1,
});

Expand Down Expand Up @@ -114,7 +114,7 @@ class VertexAITextModelRequestParams {
/// if one of the strings is encountered in the response. If a string appears
/// multiple times in the response, then the response truncates where it's
/// first encountered. The strings are case-sensitive.
final List<String> stopSequence;
final List<String> stopSequences;

/// The number of response variations to return.
final int candidateCount;
Expand All @@ -126,7 +126,7 @@ class VertexAITextModelRequestParams {
'maxOutputTokens': maxOutputTokens,
'topP': topP,
'topK': topK,
'stopSequence': stopSequence,
'stopSequences': stopSequences,
'candidateCount': candidateCount,
};
}
Expand All @@ -140,8 +140,8 @@ class VertexAITextModelRequestParams {
topP == other.topP &&
topK == other.topK &&
const ListEquality<String>().equals(
stopSequence,
other.stopSequence,
stopSequences,
other.stopSequences,
) &&
candidateCount == other.candidateCount;

Expand All @@ -151,7 +151,7 @@ class VertexAITextModelRequestParams {
maxOutputTokens.hashCode ^
topP.hashCode ^
topK.hashCode ^
const ListEquality<String>().hash(stopSequence) ^
const ListEquality<String>().hash(stopSequences) ^
candidateCount.hashCode;

@override
Expand All @@ -161,7 +161,7 @@ class VertexAITextModelRequestParams {
'maxOutputTokens: $maxOutputTokens, '
'topP: $topP, '
'topK: $topK, '
'stopSequence: $stopSequence, '
'stopSequence: $stopSequences, '
'candidateCount: $candidateCount}';
}
}
Expand Down Expand Up @@ -231,7 +231,7 @@ class VertexAITextModelPrediction {
final Map<String, dynamic> predictionJson,
) {
final citationMetadata =
predictionJson['citationMetadata'] as Map<String, dynamic>;
predictionJson['citationMetadata'] as Map<String, dynamic>? ?? const {};
final citations =
citationMetadata['citations'] as List<dynamic>? ?? const [];
return VertexAITextModelPrediction(
Expand Down
14 changes: 7 additions & 7 deletions packages/vertex_ai/lib/src/gen_ai/models/text_chat.dart
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class VertexAITextChatModelRequestParams {
this.maxOutputTokens = 1024,
this.topP = 0.95,
this.topK = 40,
this.stopSequence = const [],
this.stopSequences = const [],
this.candidateCount = 1,
});

Expand Down Expand Up @@ -160,7 +160,7 @@ class VertexAITextChatModelRequestParams {
/// if one of the strings is encountered in the response. If a string appears
/// multiple times in the response, then the response truncates where it's
/// first encountered. The strings are case-sensitive.
final List<String> stopSequence;
final List<String> stopSequences;

/// The number of response variations to return.
final int candidateCount;
Expand All @@ -172,7 +172,7 @@ class VertexAITextChatModelRequestParams {
'maxOutputTokens': maxOutputTokens,
'topP': topP,
'topK': topK,
'stopSequence': stopSequence,
'stopSequences': stopSequences,
'candidateCount': candidateCount,
};
}
Expand All @@ -186,8 +186,8 @@ class VertexAITextChatModelRequestParams {
topP == other.topP &&
topK == other.topK &&
const ListEquality<String>().equals(
stopSequence,
other.stopSequence,
stopSequences,
other.stopSequences,
) &&
candidateCount == other.candidateCount;

Expand All @@ -197,7 +197,7 @@ class VertexAITextChatModelRequestParams {
maxOutputTokens.hashCode ^
topP.hashCode ^
topK.hashCode ^
const ListEquality<String>().hash(stopSequence) ^
const ListEquality<String>().hash(stopSequences) ^
candidateCount.hashCode;

@override
Expand All @@ -207,7 +207,7 @@ class VertexAITextChatModelRequestParams {
'maxOutputTokens: $maxOutputTokens, '
'topP: $topP, '
'topK: $topK, '
'stopSequence: $stopSequence, '
'stopSequence: $stopSequences, '
'candidateCount: $candidateCount}';
}
}
Expand Down
22 changes: 22 additions & 0 deletions packages/vertex_ai/test/gen_ai/gen_ai_client_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,28 @@ void main() async {
}
});

test('Test VertexAITextModelApi stop sequence', () async {
final res = await vertexAi.text.predict(
prompt:
'List the numbers from 1 to 9 in order without any spaces or commas.',
parameters: const VertexAITextModelRequestParams(
stopSequences: ['4'],
),
);
expect(res.predictions.first.content, contains('123'));
});

test('Test VertexAITextModelApi candidates count', () async {
final res = await vertexAi.text.predict(
prompt: 'Suggest a name for a LLM framework for Dart',
parameters: const VertexAITextModelRequestParams(
temperature: 1,
candidateCount: 3,
),
);
expect(res.predictions.length, 3);
});

test('Test VertexAIChatModelApi', () async {
final models = ['chat-bison', 'chat-bison-32k'];

Expand Down
4 changes: 2 additions & 2 deletions packages/vertex_ai/test/gen_ai/mappers/chat_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ void main() {
maxOutputTokens: 256,
topP: 0.1,
topK: 30,
stopSequence: ['STOP'],
stopSequences: ['STOP'],
candidateCount: 10,
),
);
Expand Down Expand Up @@ -64,7 +64,7 @@ void main() {
'maxOutputTokens': 256,
'topP': 0.1,
'topK': 30,
'stopSequence': ['STOP'],
'stopSequences': ['STOP'],
'candidateCount': 10,
},
);
Expand Down
4 changes: 2 additions & 2 deletions packages/vertex_ai/test/gen_ai/mappers/text_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ void main() {
maxOutputTokens: 256,
topP: 0.1,
topK: 30,
stopSequence: ['STOP'],
stopSequences: ['STOP'],
candidateCount: 10,
),
);
Expand All @@ -28,7 +28,7 @@ void main() {
'maxOutputTokens': 256,
'topP': 0.1,
'topK': 30,
'stopSequence': ['STOP'],
'stopSequences': ['STOP'],
'candidateCount': 10,
},
);
Expand Down

0 comments on commit 4f7161d

Please sign in to comment.