Skip to content

Commit

Permalink
feat: completion and infill (#164)
Browse files Browse the repository at this point in the history
* feat: add `LlamaCompletion` that provides the ability to complete or infill text
* feat: support configuring more options for `getLlama` when using `"lastBuild"`
* fix: various bug fixes
  • Loading branch information
giladgd authored Feb 18, 2024
1 parent 47b476f commit ede69c1
Show file tree
Hide file tree
Showing 26 changed files with 2,054 additions and 124 deletions.
6 changes: 6 additions & 0 deletions docs/guide/cli/cli.data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import {CommandModule} from "yargs";
import {getCommandHtmlDoc} from "../../../.vitepress/utils/getCommandHtmlDoc.js";
import {BuildCommand} from "../../../src/cli/commands/BuildCommand.js";
import {ChatCommand} from "../../../src/cli/commands/ChatCommand.js";
import {CompleteCommand} from "../../../src/cli/commands/CompleteCommand.js";
import {InfillCommand} from "../../../src/cli/commands/InfillCommand.js";
import {DownloadCommand} from "../../../src/cli/commands/DownloadCommand.js";
import {ClearCommand} from "../../../src/cli/commands/ClearCommand.js";
import {htmlEscape} from "../../../.vitepress/utils/htmlEscape.js";
Expand All @@ -17,12 +19,16 @@ export default {
return {
index: buildIndexTable([
["chat", ChatCommand],
["complete", CompleteCommand],
["infill", InfillCommand],
["download", DownloadCommand],
["build", BuildCommand],
["clear", ClearCommand]
]),

chat: await getCommandHtmlDoc(ChatCommand),
complete: await getCommandHtmlDoc(CompleteCommand),
infill: await getCommandHtmlDoc(InfillCommand),
download: await getCommandHtmlDoc(DownloadCommand),
build: await getCommandHtmlDoc(BuildCommand),
clear: await getCommandHtmlDoc(ClearCommand)
Expand Down
17 changes: 17 additions & 0 deletions docs/guide/cli/complete.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
---
outline: deep
---
# `complete` command

<script setup lang="ts">
import {data as docs} from "./cli.data.js";
const commandDoc = docs.complete;
</script>

{{commandDoc.description}}

## Usage
```shell-vue
{{commandDoc.usage}}
```
<div v-html="commandDoc.options"></div>
17 changes: 17 additions & 0 deletions docs/guide/cli/infill.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
---
outline: deep
---
# `infill` command

<script setup lang="ts">
import {data as docs} from "./cli.data.js";
const commandDoc = docs.infill;
</script>

{{commandDoc.description}}

## Usage
```shell-vue
{{commandDoc.usage}}
```
<div v-html="commandDoc.options"></div>
72 changes: 63 additions & 9 deletions llama/addon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,26 @@ Napi::Value getGpuVramInfo(const Napi::CallbackInfo& info) {
return result;
}

// Converts a llama token id into a JS number, mapping tokens whose type is
// undefined or unknown to -1 so callers can treat them as "no such token".
static Napi::Value getNapiToken(const Napi::CallbackInfo& info, llama_model* model, llama_token token) {
    const auto type = llama_token_get_type(model, token);
    const bool usable = type != LLAMA_TOKEN_TYPE_UNDEFINED && type != LLAMA_TOKEN_TYPE_UNKNOWN;

    return Napi::Number::From(info.Env(), usable ? token : -1);
}

// Converts a llama token id into a JS number, but only if it is a control
// token; any other token type is reported as -1 ("not available").
static Napi::Value getNapiControlToken(const Napi::CallbackInfo& info, llama_model* model, llama_token token) {
    if (llama_token_get_type(model, token) == LLAMA_TOKEN_TYPE_CONTROL)
        return Napi::Number::From(info.Env(), token);

    return Napi::Number::From(info.Env(), -1);
}

class AddonModel : public Napi::ObjectWrap<AddonModel> {
public:
llama_model_params model_params;
Expand Down Expand Up @@ -119,7 +139,6 @@ class AddonModel : public Napi::ObjectWrap<AddonModel> {
}
}

llama_backend_init(false);
model = llama_load_model_from_file(modelPath.c_str(), model_params);

if (model == NULL) {
Expand Down Expand Up @@ -203,6 +222,15 @@ class AddonModel : public Napi::ObjectWrap<AddonModel> {
return Napi::Number::From(info.Env(), llama_n_ctx_train(model));
}

// Returns the model's embedding vector size (llama_n_embd) as a JS number.
// Throws a JS error if the wrapped model has already been disposed.
Napi::Value GetEmbeddingVectorSize(const Napi::CallbackInfo& info) {
if (disposed) {
// NOTE(review): message says "Context is disposed" but this guards the model object,
// matching the wording used by the sibling methods — confirm intended wording
Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
return info.Env().Undefined();
}

return Napi::Number::From(info.Env(), llama_n_embd(model));
}

Napi::Value GetTotalSize(const Napi::CallbackInfo& info) {
if (disposed) {
Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
Expand Down Expand Up @@ -239,55 +267,55 @@ class AddonModel : public Napi::ObjectWrap<AddonModel> {
return info.Env().Undefined();
}

return Napi::Number::From(info.Env(), llama_token_bos(model));
return getNapiControlToken(info, model, llama_token_bos(model));
}
Napi::Value TokenEos(const Napi::CallbackInfo& info) {
if (disposed) {
Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
return info.Env().Undefined();
}

return Napi::Number::From(info.Env(), llama_token_eos(model));
return getNapiControlToken(info, model, llama_token_eos(model));
}
Napi::Value TokenNl(const Napi::CallbackInfo& info) {
if (disposed) {
Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
return info.Env().Undefined();
}

return Napi::Number::From(info.Env(), llama_token_nl(model));
return getNapiToken(info, model, llama_token_nl(model));
}
Napi::Value PrefixToken(const Napi::CallbackInfo& info) {
if (disposed) {
Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
return info.Env().Undefined();
}

return Napi::Number::From(info.Env(), llama_token_prefix(model));
return getNapiControlToken(info, model, llama_token_prefix(model));
}
Napi::Value MiddleToken(const Napi::CallbackInfo& info) {
if (disposed) {
Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
return info.Env().Undefined();
}

return Napi::Number::From(info.Env(), llama_token_middle(model));
return getNapiControlToken(info, model, llama_token_middle(model));
}
Napi::Value SuffixToken(const Napi::CallbackInfo& info) {
if (disposed) {
Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
return info.Env().Undefined();
}

return Napi::Number::From(info.Env(), llama_token_suffix(model));
return getNapiControlToken(info, model, llama_token_suffix(model));
}
Napi::Value EotToken(const Napi::CallbackInfo& info) {
if (disposed) {
Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
return info.Env().Undefined();
}

return Napi::Number::From(info.Env(), llama_token_eot(model));
return getNapiControlToken(info, model, llama_token_eot(model));
}
Napi::Value GetTokenString(const Napi::CallbackInfo& info) {
if (disposed) {
Expand All @@ -308,6 +336,29 @@ class AddonModel : public Napi::ObjectWrap<AddonModel> {
return Napi::String::New(info.Env(), ss.str());
}

// Returns the llama token type of the token id passed as the first argument.
// Non-numeric input yields LLAMA_TOKEN_TYPE_UNDEFINED; throws a JS error if
// the wrapped model has already been disposed.
Napi::Value GetTokenType(const Napi::CallbackInfo& info) {
    if (disposed) {
        Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
        return info.Env().Undefined();
    }

    // Anything that is not a number cannot name a token; report "undefined".
    if (!info[0].IsNumber())
        return Napi::Number::From(info.Env(), int32_t(LLAMA_TOKEN_TYPE_UNDEFINED));

    const int token = info[0].As<Napi::Number>().Int32Value();
    return Napi::Number::From(info.Env(), int32_t(llama_token_get_type(model, token)));
}
// Returns whether a BOS token should be prepended when tokenizing input for
// this model: follows the model metadata when it specifies a preference,
// otherwise falls back to the vocabulary-type heuristic used below.
// Throws a JS error if the wrapped model has already been disposed.
Napi::Value ShouldPrependBosToken(const Napi::CallbackInfo& info) {
    if (disposed) {
        // Guard against use-after-dispose; every sibling instance method performs
        // this check, and without it we would dereference a freed model.
        Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
        return info.Env().Undefined();
    }

    const int addBos = llama_add_bos_token(model);

    // -1 means the model metadata does not specify; SPM vocabularies prepend BOS by default.
    bool shouldPrependBos = addBos != -1 ? bool(addBos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);

    return Napi::Boolean::New(info.Env(), shouldPrependBos);
}

static void init(Napi::Object exports) {
exports.Set(
"AddonModel",
Expand All @@ -318,6 +369,7 @@ class AddonModel : public Napi::ObjectWrap<AddonModel> {
InstanceMethod("tokenize", &AddonModel::Tokenize),
InstanceMethod("detokenize", &AddonModel::Detokenize),
InstanceMethod("getTrainContextSize", &AddonModel::GetTrainContextSize),
InstanceMethod("getEmbeddingVectorSize", &AddonModel::GetEmbeddingVectorSize),
InstanceMethod("getTotalSize", &AddonModel::GetTotalSize),
InstanceMethod("getTotalParameters", &AddonModel::GetTotalParameters),
InstanceMethod("getModelDescription", &AddonModel::GetModelDescription),
Expand All @@ -329,6 +381,8 @@ class AddonModel : public Napi::ObjectWrap<AddonModel> {
InstanceMethod("suffixToken", &AddonModel::SuffixToken),
InstanceMethod("eotToken", &AddonModel::EotToken),
InstanceMethod("getTokenString", &AddonModel::GetTokenString),
InstanceMethod("getTokenType", &AddonModel::GetTokenType),
InstanceMethod("shouldPrependBosToken", &AddonModel::ShouldPrependBosToken),
InstanceMethod("dispose", &AddonModel::Dispose),
}
)
Expand Down Expand Up @@ -993,7 +1047,7 @@ Napi::Value setLoggerLogLevel(const Napi::CallbackInfo& info) {
}

Napi::Object registerCallback(Napi::Env env, Napi::Object exports) {
llama_backend_init(false);
llama_backend_init();
exports.DefineProperties({
Napi::PropertyDescriptor::Function("systemInfo", systemInfo),
Napi::PropertyDescriptor::Function("setLogger", setLogger),
Expand Down
8 changes: 4 additions & 4 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
"cross-spawn": "^7.0.3",
"env-var": "^7.3.1",
"fs-extra": "^11.2.0",
"lifecycle-utils": "^1.2.1",
"lifecycle-utils": "^1.2.2",
"log-symbols": "^5.1.0",
"node-addon-api": "^7.0.0",
"octokit": "^3.1.0",
Expand Down
6 changes: 0 additions & 6 deletions src/AbortError.ts

This file was deleted.

5 changes: 4 additions & 1 deletion src/bindings/AddonTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export type AddonModel = {
tokenize(text: string, specialTokens: boolean): Uint32Array,
detokenize(tokens: Uint32Array): string,
getTrainContextSize(): number,
getEmbeddingVectorSize(): number,
getTotalSize(): number,
getTotalParameters(): number,
getModelDescription(): ModelTypeDescription,
Expand All @@ -52,7 +53,9 @@ export type AddonModel = {
middleToken(): Token,
suffixToken(): Token,
eotToken(): Token,
getTokenString(token: number): string
getTokenString(token: number): string,
getTokenType(token: Token): number,
shouldPrependBosToken(): boolean
};

export type AddonContext = {
Expand Down
27 changes: 25 additions & 2 deletions src/bindings/getLlama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,27 @@ export type LastBuildOptions = {
/**
* Set a custom logger for llama.cpp logs.
*/
logger?: (level: LlamaLogLevel, message: string) => void
logger?: (level: LlamaLogLevel, message: string) => void,

/**
* If a local build is not found, use prebuilt binaries.
* Enabled by default.
*/
usePrebuiltBinaries?: boolean,

/**
* If a local build is not found, and prebuilt binaries are not found, when building from source,
* print binary compilation progress logs.
* Enabled by default.
*/
progressLogs?: boolean,

/**
* If a local build is not found, and prebuilt binaries are not found, don't download llama.cpp source if it's not found.
* When set to `true`, and llama.cpp source is needed but is not found, a `NoBinaryFoundError` error will be thrown.
* Disabled by default.
*/
skipDownload?: boolean
};

export const getLlamaFunctionName = "getLlama";
Expand All @@ -124,7 +144,10 @@ export async function getLlama(options?: LlamaOptions | "lastBuild", lastBuildOp
const lastBuildInfo = await getLastBuildInfo();
const getLlamaOptions: LlamaOptions = {
logLevel: lastBuildOptions?.logLevel ?? defaultLlamaCppDebugLogs,
logger: lastBuildOptions?.logger ?? Llama.defaultConsoleLogger
logger: lastBuildOptions?.logger ?? Llama.defaultConsoleLogger,
usePrebuiltBinaries: lastBuildOptions?.usePrebuiltBinaries ?? true,
progressLogs: lastBuildOptions?.progressLogs ?? true,
skipDownload: lastBuildOptions?.skipDownload ?? defaultSkipDownload
};

if (lastBuildInfo == null)
Expand Down
4 changes: 4 additions & 0 deletions src/cli/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import {BuildCommand} from "./commands/BuildCommand.js";
import {OnPostInstallCommand} from "./commands/OnPostInstallCommand.js";
import {ClearCommand} from "./commands/ClearCommand.js";
import {ChatCommand} from "./commands/ChatCommand.js";
import {CompleteCommand} from "./commands/CompleteCommand.js";
import {InfillCommand} from "./commands/InfillCommand.js";
import {DebugCommand} from "./commands/DebugCommand.js";

const __dirname = path.dirname(fileURLToPath(import.meta.url));
Expand All @@ -30,6 +32,8 @@ yarg
.command(BuildCommand)
.command(ClearCommand)
.command(ChatCommand)
.command(CompleteCommand)
.command(InfillCommand)
.command(OnPostInstallCommand)
.command(DebugCommand)
.recommendCommands()
Expand Down
Loading

0 comments on commit ede69c1

Please sign in to comment.