From cc5b1b021636379f32f215546b78547ace87d150 Mon Sep 17 00:00:00 2001
From: David Miguel Lozano
Date: Sat, 11 May 2024 19:48:12 +0200
Subject: [PATCH] feat: Add support for done reason in ollama_dart (#413)

---
 .../example/ollama_dart_example.dart          | 14 ++---
 .../lib/src/generated/schema/done_reason.dart | 19 +++++++
 .../generate_chat_completion_response.dart    |  9 ++-
 .../lib/src/generated/schema/schema.dart      |  1 +
 .../src/generated/schema/schema.freezed.dart  | 56 +++++++++++++------
 .../lib/src/generated/schema/schema.g.dart    | 12 +++-
 packages/ollama_dart/oas/ollama-curated.yaml  | 11 +++-
 .../test/ollama_dart_chat_test.dart           | 25 ++++++++-
 8 files changed, 115 insertions(+), 32 deletions(-)
 create mode 100644 packages/ollama_dart/lib/src/generated/schema/done_reason.dart

diff --git a/packages/ollama_dart/example/ollama_dart_example.dart b/packages/ollama_dart/example/ollama_dart_example.dart
index 608590fb..50e8db3f 100644
--- a/packages/ollama_dart/example/ollama_dart_example.dart
+++ b/packages/ollama_dart/example/ollama_dart_example.dart
@@ -150,7 +150,7 @@ Future<void> _generateEmbedding(final OllamaClient client) async {
 Future<void> _createModel(final OllamaClient client) async {
   final res = await client.createModel(
     request: const CreateModelRequest(
-      name: 'mario',
+      model: 'mario',
       modelfile:
           'FROM mistral:latest\nSYSTEM You are mario from Super Mario Bros.',
     ),
@@ -161,7 +161,7 @@ Future<void> _createModelStream(final OllamaClient client) async {
   final stream = client.createModelStream(
     request: const CreateModelRequest(
-      name: 'mario',
+      model: 'mario',
       modelfile:
           'FROM mistral:latest\nSYSTEM You are mario from Super Mario Bros.',
     ),
@@ -178,21 +178,21 @@ Future<void> _listModels(final OllamaClient client) async {
 
 Future<void> _showModelInfo(final OllamaClient client) async {
   final res = await client.showModelInfo(
-    request: const ModelInfoRequest(name: 'mistral:latest'),
+    request: const ModelInfoRequest(model: 'mistral:latest'),
   );
   print(res);
 }
 
 Future<void> _pullModel(final OllamaClient client) async {
   final res = await client.pullModel(
-    request: const PullModelRequest(name: 'yarn-llama2:13b-128k-q4_1'),
+    request: const PullModelRequest(model: 'yarn-llama2:13b-128k-q4_1'),
   );
   print(res.status);
 }
 
 Future<void> _pullModelStream(final OllamaClient client) async {
   final stream = client.pullModelStream(
-    request: const PullModelRequest(name: 'yarn-llama2:13b-128k-q4_1'),
+    request: const PullModelRequest(model: 'yarn-llama2:13b-128k-q4_1'),
   );
   await for (final res in stream) {
     print(res.status);
@@ -201,14 +201,14 @@ Future<void> _pullModelStream(final OllamaClient client) async {
 
 Future<void> _pushModel(final OllamaClient client) async {
   final res = await client.pushModel(
-    request: const PushModelRequest(name: 'mattw/pygmalion:latest'),
+    request: const PushModelRequest(model: 'mattw/pygmalion:latest'),
   );
   print(res.status);
 }
 
 Future<void> _pushModelStream(final OllamaClient client) async {
   final stream = client.pushModelStream(
-    request: const PushModelRequest(name: 'mattw/pygmalion:latest'),
+    request: const PushModelRequest(model: 'mattw/pygmalion:latest'),
   );
   await for (final res in stream) {
     print(res.status);
diff --git a/packages/ollama_dart/lib/src/generated/schema/done_reason.dart b/packages/ollama_dart/lib/src/generated/schema/done_reason.dart
new file mode 100644
index 00000000..6ce8e59c
--- /dev/null
+++ b/packages/ollama_dart/lib/src/generated/schema/done_reason.dart
@@ -0,0 +1,19 @@
+// coverage:ignore-file
+// GENERATED CODE - DO NOT MODIFY BY HAND
+// ignore_for_file: type=lint
+// ignore_for_file: invalid_annotation_target
+part of ollama_schema;
+
+// ==========================================
+// ENUM: DoneReason
+// ==========================================
+
+/// Reason why the model is done generating a response.
+enum DoneReason {
+  @JsonValue('stop')
+  stop,
+  @JsonValue('length')
+  length,
+  @JsonValue('load')
+  load,
+}
diff --git a/packages/ollama_dart/lib/src/generated/schema/generate_chat_completion_response.dart b/packages/ollama_dart/lib/src/generated/schema/generate_chat_completion_response.dart
index 87b296ac..e17f73c1 100644
--- a/packages/ollama_dart/lib/src/generated/schema/generate_chat_completion_response.dart
+++ b/packages/ollama_dart/lib/src/generated/schema/generate_chat_completion_response.dart
@@ -29,8 +29,13 @@ class GenerateChatCompletionResponse with _$GenerateChatCompletionResponse {
     /// Whether the response has completed.
     @JsonKey(includeIfNull: false) bool? done,
 
-    /// Reason the response is done.
-    @JsonKey(name: 'done_reason', includeIfNull: false) String? doneReason,
+    /// Reason why the model is done generating a response.
+    @JsonKey(
+      name: 'done_reason',
+      includeIfNull: false,
+      unknownEnumValue: JsonKey.nullForUndefinedEnumValue,
+    )
+    DoneReason? doneReason,
 
     /// Time spent generating the response.
     @JsonKey(name: 'total_duration', includeIfNull: false) int? totalDuration,
diff --git a/packages/ollama_dart/lib/src/generated/schema/schema.dart b/packages/ollama_dart/lib/src/generated/schema/schema.dart
index 64371b4a..5c8eb964 100644
--- a/packages/ollama_dart/lib/src/generated/schema/schema.dart
+++ b/packages/ollama_dart/lib/src/generated/schema/schema.dart
@@ -16,6 +16,7 @@ part 'response_format.dart';
 part 'generate_completion_response.dart';
 part 'generate_chat_completion_request.dart';
 part 'generate_chat_completion_response.dart';
+part 'done_reason.dart';
 part 'message.dart';
 part 'generate_embedding_request.dart';
 part 'generate_embedding_response.dart';
diff --git a/packages/ollama_dart/lib/src/generated/schema/schema.freezed.dart b/packages/ollama_dart/lib/src/generated/schema/schema.freezed.dart
index 42fd1450..c8bdc3d1 100644
--- a/packages/ollama_dart/lib/src/generated/schema/schema.freezed.dart
+++ b/packages/ollama_dart/lib/src/generated/schema/schema.freezed.dart
@@ -2496,9 +2496,12 @@ mixin _$GenerateChatCompletionResponse {
   @JsonKey(includeIfNull: false)
   bool? get done => throw _privateConstructorUsedError;
 
-  /// Reason the response is done.
-  @JsonKey(name: 'done_reason', includeIfNull: false)
-  String? get doneReason => throw _privateConstructorUsedError;
+  /// Reason why the model is done generating a response.
+  @JsonKey(
+      name: 'done_reason',
+      includeIfNull: false,
+      unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+  DoneReason? get doneReason => throw _privateConstructorUsedError;
 
   /// Time spent generating the response.
   @JsonKey(name: 'total_duration', includeIfNull: false)
@@ -2543,7 +2546,11 @@ abstract class $GenerateChatCompletionResponseCopyWith<$Res> {
       @JsonKey(includeIfNull: false) String? model,
       @JsonKey(name: 'created_at', includeIfNull: false) String? createdAt,
       @JsonKey(includeIfNull: false) bool? done,
-      @JsonKey(name: 'done_reason', includeIfNull: false) String? doneReason,
+      @JsonKey(
+          name: 'done_reason',
+          includeIfNull: false,
+          unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+      DoneReason? doneReason,
       @JsonKey(name: 'total_duration', includeIfNull: false) int? totalDuration,
       @JsonKey(name: 'load_duration', includeIfNull: false) int? loadDuration,
       @JsonKey(name: 'prompt_eval_count', includeIfNull: false)
@@ -2602,7 +2609,7 @@ class _$GenerateChatCompletionResponseCopyWithImpl<$Res,
       doneReason: freezed == doneReason
           ? _value.doneReason
           : doneReason // ignore: cast_nullable_to_non_nullable
-              as String?,
+              as DoneReason?,
       totalDuration: freezed == totalDuration
           ? _value.totalDuration
           : totalDuration // ignore: cast_nullable_to_non_nullable
@@ -2657,7 +2664,11 @@ abstract class _$$GenerateChatCompletionResponseImplCopyWith<$Res>
       @JsonKey(includeIfNull: false) String? model,
       @JsonKey(name: 'created_at', includeIfNull: false) String? createdAt,
       @JsonKey(includeIfNull: false) bool? done,
-      @JsonKey(name: 'done_reason', includeIfNull: false) String? doneReason,
+      @JsonKey(
+          name: 'done_reason',
+          includeIfNull: false,
+          unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+      DoneReason? doneReason,
       @JsonKey(name: 'total_duration', includeIfNull: false) int? totalDuration,
       @JsonKey(name: 'load_duration', includeIfNull: false) int? loadDuration,
       @JsonKey(name: 'prompt_eval_count', includeIfNull: false)
@@ -2716,7 +2727,7 @@ class __$$GenerateChatCompletionResponseImplCopyWithImpl<$Res>
       doneReason: freezed == doneReason
           ? _value.doneReason
           : doneReason // ignore: cast_nullable_to_non_nullable
-              as String?,
+              as DoneReason?,
       totalDuration: freezed == totalDuration
           ? _value.totalDuration
           : totalDuration // ignore: cast_nullable_to_non_nullable
@@ -2754,7 +2765,11 @@ class _$GenerateChatCompletionResponseImpl
       @JsonKey(includeIfNull: false) this.model,
       @JsonKey(name: 'created_at', includeIfNull: false) this.createdAt,
       @JsonKey(includeIfNull: false) this.done,
-      @JsonKey(name: 'done_reason', includeIfNull: false) this.doneReason,
+      @JsonKey(
+          name: 'done_reason',
+          includeIfNull: false,
+          unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+      this.doneReason,
       @JsonKey(name: 'total_duration', includeIfNull: false) this.totalDuration,
       @JsonKey(name: 'load_duration', includeIfNull: false) this.loadDuration,
       @JsonKey(name: 'prompt_eval_count', includeIfNull: false)
@@ -2791,10 +2806,13 @@ class _$GenerateChatCompletionResponseImpl
   @JsonKey(includeIfNull: false)
   final bool? done;
 
-  /// Reason the response is done.
+  /// Reason why the model is done generating a response.
   @override
-  @JsonKey(name: 'done_reason', includeIfNull: false)
-  final String? doneReason;
+  @JsonKey(
+      name: 'done_reason',
+      includeIfNull: false,
+      unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+  final DoneReason? doneReason;
 
   /// Time spent generating the response.
   @override
@@ -2897,8 +2915,11 @@ abstract class _GenerateChatCompletionResponse
       @JsonKey(name: 'created_at', includeIfNull: false)
       final String? createdAt,
       @JsonKey(includeIfNull: false) final bool? done,
-      @JsonKey(name: 'done_reason', includeIfNull: false)
-      final String? doneReason,
+      @JsonKey(
+          name: 'done_reason',
+          includeIfNull: false,
+          unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+      final DoneReason? doneReason,
       @JsonKey(name: 'total_duration', includeIfNull: false)
       final int? totalDuration,
       @JsonKey(name: 'load_duration', includeIfNull: false)
@@ -2939,9 +2960,12 @@ abstract class _GenerateChatCompletionResponse
   bool? get done;
   @override
 
-  /// Reason the response is done.
-  @JsonKey(name: 'done_reason', includeIfNull: false)
-  String? get doneReason;
+  /// Reason why the model is done generating a response.
+  @JsonKey(
+      name: 'done_reason',
+      includeIfNull: false,
+      unknownEnumValue: JsonKey.nullForUndefinedEnumValue)
+  DoneReason? get doneReason;
   @override
 
   /// Time spent generating the response.
diff --git a/packages/ollama_dart/lib/src/generated/schema/schema.g.dart b/packages/ollama_dart/lib/src/generated/schema/schema.g.dart
index fe12dbac..f5548646 100644
--- a/packages/ollama_dart/lib/src/generated/schema/schema.g.dart
+++ b/packages/ollama_dart/lib/src/generated/schema/schema.g.dart
@@ -220,7 +220,9 @@ _$GenerateChatCompletionResponseImpl
       model: json['model'] as String?,
       createdAt: json['created_at'] as String?,
       done: json['done'] as bool?,
-      doneReason: json['done_reason'] as String?,
+      doneReason: $enumDecodeNullable(
+          _$DoneReasonEnumMap, json['done_reason'],
+          unknownValue: JsonKey.nullForUndefinedEnumValue),
       totalDuration: json['total_duration'] as int?,
       loadDuration: json['load_duration'] as int?,
       promptEvalCount: json['prompt_eval_count'] as int?,
@@ -243,7 +245,7 @@ Map<String, dynamic> _$$GenerateChatCompletionResponseImplToJson(
   writeNotNull('model', instance.model);
   writeNotNull('created_at', instance.createdAt);
   writeNotNull('done', instance.done);
-  writeNotNull('done_reason', instance.doneReason);
+  writeNotNull('done_reason', _$DoneReasonEnumMap[instance.doneReason]);
   writeNotNull('total_duration', instance.totalDuration);
   writeNotNull('load_duration', instance.loadDuration);
   writeNotNull('prompt_eval_count', instance.promptEvalCount);
@@ -253,6 +255,12 @@ Map<String, dynamic> _$$GenerateChatCompletionResponseImplToJson(
   return val;
 }
 
+const _$DoneReasonEnumMap = {
+  DoneReason.stop: 'stop',
+  DoneReason.length: 'length',
+  DoneReason.load: 'load',
+};
+
 _$MessageImpl _$$MessageImplFromJson(Map<String, dynamic> json) =>
     _$MessageImpl(
       role: $enumDecode(_$MessageRoleEnumMap, json['role']),
diff --git a/packages/ollama_dart/oas/ollama-curated.yaml b/packages/ollama_dart/oas/ollama-curated.yaml
index efc2c6f9..609bb44b 100644
--- a/packages/ollama_dart/oas/ollama-curated.yaml
+++ b/packages/ollama_dart/oas/ollama-curated.yaml
@@ -573,9 +573,7 @@ components:
           description: Whether the response has completed.
           example: true
         done_reason:
-          type: string
-          nullable: true
-          description: Reason the response is done.
+          $ref: '#/components/schemas/DoneReason'
         total_duration:
           type: integer
           format: int64
@@ -604,6 +602,13 @@ components:
           format: int64
           description: Time in nanoseconds spent generating the response.
           example: 1325948000
+    DoneReason:
+      type: string
+      description: Reason why the model is done generating a response.
+      enum:
+        - stop # The generation hit a stop token.
+        - length # The maximum num_tokens was reached.
+        - load # The request was sent with an empty body to load the model.
     Message:
       type: object
       description: A message in the chat endpoint
diff --git a/packages/ollama_dart/test/ollama_dart_chat_test.dart b/packages/ollama_dart/test/ollama_dart_chat_test.dart
index 0e9859fc..af90c448 100644
--- a/packages/ollama_dart/test/ollama_dart_chat_test.dart
+++ b/packages/ollama_dart/test/ollama_dart_chat_test.dart
@@ -52,6 +52,7 @@ void main() {
         isNotEmpty,
       );
       expect(response.done, isTrue);
+      expect(response.doneReason, DoneReason.stop);
       expect(response.totalDuration, greaterThan(0));
       expect(response.promptEvalCount, greaterThan(0));
       expect(response.evalCount, greaterThan(0));
@@ -118,8 +119,7 @@ void main() {
           Message(
             role: MessageRole.user,
             content: 'List the numbers from 1 to 9 in order. '
-                'Output ONLY the numbers in one line without any spaces or commas. '
-                'NUMBERS: ',
+                'Output ONLY the numbers without spaces or commas.',
           ),
         ],
         options: RequestOptions(stop: ['4']),
@@ -128,6 +128,27 @@ void main() {
       final generation = res.message?.content.replaceAll(RegExp(r'[\s\n]'), '');
       expect(generation, contains('123'));
       expect(generation, isNot(contains('456789')));
+      expect(res.doneReason, DoneReason.stop);
+    });
+
+    test('Test call chat completions API with max tokens', () async {
+      final res = await client.generateChatCompletion(
+        request: const GenerateChatCompletionRequest(
+          model: defaultModel,
+          messages: [
+            Message(
+              role: MessageRole.system,
+              content: 'You are a helpful assistant.',
+            ),
+            Message(
+              role: MessageRole.user,
+              content: 'List the numbers from 1 to 9 in order.',
+            ),
+          ],
+          options: RequestOptions(numPredict: 1),
+        ),
+      );
+      expect(res.doneReason, DoneReason.length);
     });
 
     test('Test call chat completions API with image', () async {
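--
A minimal usage sketch (placed after the signature delimiter so it is not part
of the applied diff): how calling code can branch on the new DoneReason values
once this change lands. The model name, prompt, and numPredict budget are
illustrative placeholders; the client calls mirror those exercised in the
tests above.

  import 'package:ollama_dart/ollama_dart.dart';

  Future<void> main() async {
    final client = OllamaClient();
    final res = await client.generateChatCompletion(
      request: const GenerateChatCompletionRequest(
        model: 'mistral:latest', // placeholder model name
        messages: [
          Message(role: MessageRole.user, content: 'Why is the sky blue?'),
        ],
        options: RequestOptions(numPredict: 256), // placeholder token budget
      ),
    );
    // doneReason is nullable: servers that omit done_reason, or return a
    // value outside the enum, decode to null because the field is annotated
    // with unknownEnumValue: JsonKey.nullForUndefinedEnumValue.
    switch (res.doneReason) {
      case DoneReason.stop:
        print('Generation hit a stop token or finished naturally.');
        break;
      case DoneReason.length:
        print('Generation was truncated by the numPredict budget.');
        break;
      case DoneReason.load:
        print('Empty-body request: the call only loaded the model.');
        break;
      default:
        print('No (or unknown) done_reason reported by the server.');
    }
    client.endSession();
  }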