diff --git a/packages/vertex_ai/lib/src/gen_ai/models/model.dart b/packages/vertex_ai/lib/src/gen_ai/models/model.dart index de1ce1d8..77fa8d5c 100644 --- a/packages/vertex_ai/lib/src/gen_ai/models/model.dart +++ b/packages/vertex_ai/lib/src/gen_ai/models/model.dart @@ -110,7 +110,8 @@ class VertexAIPredictionSafetyAttributes { final Map safetyAttributesJson, ) { return VertexAIPredictionSafetyAttributes( - categories: (safetyAttributesJson['categories'] as List) + categories: (safetyAttributesJson['categories'] as List? ?? + const []) .map( (final category) => switch (category) { 'Derogatory' => @@ -142,7 +143,7 @@ class VertexAIPredictionSafetyAttributes { .toList(growable: false), scores: (safetyAttributesJson['scores'] as List? ?? const []).cast(), - blocked: safetyAttributesJson['blocked'] as bool, + blocked: safetyAttributesJson['blocked'] as bool? ?? false, ); } diff --git a/packages/vertex_ai/lib/src/gen_ai/models/text.dart b/packages/vertex_ai/lib/src/gen_ai/models/text.dart index d4ee80bc..e908ad06 100644 --- a/packages/vertex_ai/lib/src/gen_ai/models/text.dart +++ b/packages/vertex_ai/lib/src/gen_ai/models/text.dart @@ -52,7 +52,7 @@ class VertexAITextModelRequestParams { this.maxOutputTokens = 1024, this.topP = 0.95, this.topK = 40, - this.stopSequence = const [], + this.stopSequences = const [], this.candidateCount = 1, }); @@ -114,7 +114,7 @@ class VertexAITextModelRequestParams { /// if one of the strings is encountered in the response. If a string appears /// multiple times in the response, then the response truncates where it's /// first encountered. The strings are case-sensitive. - final List stopSequence; + final List stopSequences; /// The number of response variations to return. final int candidateCount; @@ -126,7 +126,7 @@ class VertexAITextModelRequestParams { 'maxOutputTokens': maxOutputTokens, 'topP': topP, 'topK': topK, - 'stopSequence': stopSequence, + 'stopSequences': stopSequences, 'candidateCount': candidateCount, }; } @@ -140,8 +140,8 @@ class VertexAITextModelRequestParams { topP == other.topP && topK == other.topK && const ListEquality().equals( - stopSequence, - other.stopSequence, + stopSequences, + other.stopSequences, ) && candidateCount == other.candidateCount; @@ -151,7 +151,7 @@ class VertexAITextModelRequestParams { maxOutputTokens.hashCode ^ topP.hashCode ^ topK.hashCode ^ - const ListEquality().hash(stopSequence) ^ + const ListEquality().hash(stopSequences) ^ candidateCount.hashCode; @override @@ -161,7 +161,7 @@ class VertexAITextModelRequestParams { 'maxOutputTokens: $maxOutputTokens, ' 'topP: $topP, ' 'topK: $topK, ' - 'stopSequence: $stopSequence, ' + 'stopSequence: $stopSequences, ' 'candidateCount: $candidateCount}'; } } @@ -231,7 +231,7 @@ class VertexAITextModelPrediction { final Map predictionJson, ) { final citationMetadata = - predictionJson['citationMetadata'] as Map; + predictionJson['citationMetadata'] as Map? ?? const {}; final citations = citationMetadata['citations'] as List? ?? const []; return VertexAITextModelPrediction( diff --git a/packages/vertex_ai/lib/src/gen_ai/models/text_chat.dart b/packages/vertex_ai/lib/src/gen_ai/models/text_chat.dart index ce10a345..b9790e28 100644 --- a/packages/vertex_ai/lib/src/gen_ai/models/text_chat.dart +++ b/packages/vertex_ai/lib/src/gen_ai/models/text_chat.dart @@ -98,7 +98,7 @@ class VertexAITextChatModelRequestParams { this.maxOutputTokens = 1024, this.topP = 0.95, this.topK = 40, - this.stopSequence = const [], + this.stopSequences = const [], this.candidateCount = 1, }); @@ -160,7 +160,7 @@ class VertexAITextChatModelRequestParams { /// if one of the strings is encountered in the response. If a string appears /// multiple times in the response, then the response truncates where it's /// first encountered. The strings are case-sensitive. - final List stopSequence; + final List stopSequences; /// The number of response variations to return. final int candidateCount; @@ -172,7 +172,7 @@ class VertexAITextChatModelRequestParams { 'maxOutputTokens': maxOutputTokens, 'topP': topP, 'topK': topK, - 'stopSequence': stopSequence, + 'stopSequences': stopSequences, 'candidateCount': candidateCount, }; } @@ -186,8 +186,8 @@ class VertexAITextChatModelRequestParams { topP == other.topP && topK == other.topK && const ListEquality().equals( - stopSequence, - other.stopSequence, + stopSequences, + other.stopSequences, ) && candidateCount == other.candidateCount; @@ -197,7 +197,7 @@ class VertexAITextChatModelRequestParams { maxOutputTokens.hashCode ^ topP.hashCode ^ topK.hashCode ^ - const ListEquality().hash(stopSequence) ^ + const ListEquality().hash(stopSequences) ^ candidateCount.hashCode; @override @@ -207,7 +207,7 @@ class VertexAITextChatModelRequestParams { 'maxOutputTokens: $maxOutputTokens, ' 'topP: $topP, ' 'topK: $topK, ' - 'stopSequence: $stopSequence, ' + 'stopSequence: $stopSequences, ' 'candidateCount: $candidateCount}'; } } diff --git a/packages/vertex_ai/test/gen_ai/gen_ai_client_test.dart b/packages/vertex_ai/test/gen_ai/gen_ai_client_test.dart index ae0509f4..4be6c08b 100644 --- a/packages/vertex_ai/test/gen_ai/gen_ai_client_test.dart +++ b/packages/vertex_ai/test/gen_ai/gen_ai_client_test.dart @@ -32,6 +32,28 @@ void main() async { } }); + test('Test VertexAITextModelApi stop sequence', () async { + final res = await vertexAi.text.predict( + prompt: + 'List the numbers from 1 to 9 in order without any spaces or commas.', + parameters: const VertexAITextModelRequestParams( + stopSequences: ['4'], + ), + ); + expect(res.predictions.first.content, contains('123')); + }); + + test('Test VertexAITextModelApi candidates count', () async { + final res = await vertexAi.text.predict( + prompt: 'Suggest a name for a LLM framework for Dart', + parameters: const VertexAITextModelRequestParams( + temperature: 1, + candidateCount: 3, + ), + ); + expect(res.predictions.length, 3); + }); + test('Test VertexAIChatModelApi', () async { final models = ['chat-bison', 'chat-bison-32k']; diff --git a/packages/vertex_ai/test/gen_ai/mappers/chat_test.dart b/packages/vertex_ai/test/gen_ai/mappers/chat_test.dart index d5fa6a2b..f7bf56fc 100644 --- a/packages/vertex_ai/test/gen_ai/mappers/chat_test.dart +++ b/packages/vertex_ai/test/gen_ai/mappers/chat_test.dart @@ -31,7 +31,7 @@ void main() { maxOutputTokens: 256, topP: 0.1, topK: 30, - stopSequence: ['STOP'], + stopSequences: ['STOP'], candidateCount: 10, ), ); @@ -64,7 +64,7 @@ void main() { 'maxOutputTokens': 256, 'topP': 0.1, 'topK': 30, - 'stopSequence': ['STOP'], + 'stopSequences': ['STOP'], 'candidateCount': 10, }, ); diff --git a/packages/vertex_ai/test/gen_ai/mappers/text_test.dart b/packages/vertex_ai/test/gen_ai/mappers/text_test.dart index 95ec76b5..6e45862a 100644 --- a/packages/vertex_ai/test/gen_ai/mappers/text_test.dart +++ b/packages/vertex_ai/test/gen_ai/mappers/text_test.dart @@ -13,7 +13,7 @@ void main() { maxOutputTokens: 256, topP: 0.1, topK: 30, - stopSequence: ['STOP'], + stopSequences: ['STOP'], candidateCount: 10, ), ); @@ -28,7 +28,7 @@ void main() { 'maxOutputTokens': 256, 'topP': 0.1, 'topK': 30, - 'stopSequence': ['STOP'], + 'stopSequences': ['STOP'], 'candidateCount': 10, }, );