feat: ability to pick custom model
ex3ndr committed Dec 22, 2023
1 parent f342aee commit 10085e4
Showing 6 changed files with 129 additions and 51 deletions.
62 changes: 42 additions & 20 deletions package.json
@@ -38,23 +38,9 @@
"properties": {
"inference.endpoint": {
"type": "string",
"default": "http://127.0.0.1:11434/",
"description": "Ollama Server Endpoint"
},
"inference.maxLines": {
"type": "number",
"default": 16,
"description": "Max number of lines to be keep."
},
"inference.maxTokens": {
"type": "number",
"default": 256,
"description": "Max number of new tokens to be generated."
},
"inference.temperature": {
"type": "number",
"default": 0.2,
"description": "Temperature of the model. Increasing the temperature will make the model answer more creatively."
"default": "",
"description": "Ollama Server Endpoint. Empty for local instance.",
"order": 1
},
"inference.model": {
"type": "string",
@@ -81,10 +67,46 @@
"deepseek-coder:6.7b-base-fp16",
"deepseek-coder:33b-base-q4_K_S",
"deepseek-coder:33b-base-q4_K_M",
"deepseek-coder:33b-base-fp16"
"deepseek-coder:33b-base-fp16",
"custom"
],
"default": "deepseek-coder:1.3b-base-q4_1",
"description": "Inference model to use"
"description": "Inference model to use",
"order": 2
},
"inference.temperature": {
"type": "number",
"default": 0.2,
"description": "Temperature of the model. Increasing the temperature will make the model answer more creatively.",
"order": 3
},
"inference.custom.model": {
"type": "string",
"default": "",
"description": "Custom model name",
"order": 4
},
"inference.custom.format": {
"type": "string",
"enum": [
"codellama",
"deepseek"
],
"default": "codellama",
"description": "Custom model prompt format",
"order": 5
},
"inference.maxLines": {
"type": "number",
"default": 16,
"description": "Max number of lines to be keep.",
"order": 6
},
"inference.maxTokens": {
"type": "number",
"default": 256,
"description": "Max number of new tokens to be generated.",
"order": 7
}
}
}
@@ -111,4 +133,4 @@
"ts-jest": "^29.1.1",
"typescript": "^5.2.2"
}
}
}
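With these settings in place, a model outside the built-in list is selected by setting inference.model to "custom" and filling in the two new options. A minimal settings.json sketch, assuming a hypothetical model tag my-coder:7b-base (use any tag your Ollama server can serve, and the prompt format that matches the model family):

{
  "inference.endpoint": "",
  "inference.model": "custom",
  "inference.custom.model": "my-coder:7b-base",
  "inference.custom.format": "deepseek"
}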
50 changes: 50 additions & 0 deletions src/config.ts
@@ -0,0 +1,50 @@
import vscode from 'vscode';

class Config {

// Inference
get inference() {
let config = this.#config;

// Load endpoint
let endpoint = (config.get('endpoint') as string).trim();
if (endpoint.endsWith('/')) {
endpoint = endpoint.slice(0, endpoint.length - 1).trim();
}
if (endpoint === '') {
endpoint = 'http://127.0.0.1:11434';
}

// Load general parameters
let maxLines = config.get('maxLines') as number;
let maxTokens = config.get('maxTokens') as number;
let temperature = config.get('temperature') as number;

// Load model
let modelName = config.get('model') as string;
let modelFormat: 'codellama' | 'deepseek' = 'codellama';
if (modelName === 'custom') {
modelName = config.get('custom.model') as string;
modelFormat = config.get('custom.format') as 'codellama' | 'deepseek';
} else {
if (modelName.startsWith('deepseek-coder')) {
modelFormat = 'deepseek';
}
}

return {
endpoint,
maxLines,
maxTokens,
temperature,
modelName,
modelFormat
};
}

get #config() {
return vscode.workspace.getConfiguration('inference');
};
}

export const config = new Config();
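The exported singleton is meant to be consumed by the rest of the extension (the real call site is in src/prompts/provider.ts below). Because the inference getter re-reads the workspace configuration on every access, changed settings apply to the next completion without reloading the extension. An illustrative usage sketch:

import { config } from './config';

// Resolve the effective settings, including the custom-model fallback
const { endpoint, modelName, modelFormat } = config.inference;
console.log(`Completing with ${modelName} (${modelFormat}) at ${endpoint}`);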
7 changes: 4 additions & 3 deletions src/prompts/autocomplete.ts
@@ -1,20 +1,21 @@
import { ollamaTokenGenerator } from '../modules/ollamaTokenGenerator';
import { countSymbol } from '../modules/text';
import { info } from '../modules/log';
import { adaptPrompt } from './adaptors/adaptPrompt';
import { ModelFormat, adaptPrompt } from './processors/models';

export async function autocomplete(args: {
endpoint: string,
model: string,
format: ModelFormat,
prefix: string,
suffix: string | null,
suffix: string,
maxLines: number,
maxTokens: number,
temperature: number,
canceled?: () => boolean,
}): Promise<string> {

let prompt = adaptPrompt({ prefix: args.prefix, suffix: args.suffix, model: args.model });
let prompt = adaptPrompt({ prefix: args.prefix, suffix: args.suffix, format: args.format });

// Calculate arguments
let data = {
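The hunk is cut off before the request body is assembled. As a rough sketch of the kind of payload Ollama's /api/generate endpoint accepts (field names follow Ollama's API; exactly which options the extension sends is not visible in this diff):

const data = {
    model: args.model,               // resolved model name from config
    prompt: prompt.prompt,           // FIM-formatted prompt built by adaptPrompt
    raw: true,                       // assumption: bypass Ollama's own prompt template
    options: {
        temperature: args.temperature,
        num_predict: args.maxTokens, // Ollama's name for the new-token limit
        stop: prompt.stop            // stop sequences returned by adaptPrompt
    }
};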
8 changes: 4 additions & 4 deletions src/prompts/preparePrompt.ts
@@ -9,14 +9,14 @@ export async function preparePrompt(document: vscode.TextDocument, position: vsc
let text = document.getText();
let offset = document.offsetAt(position);
let prefix = text.slice(0, offset);
let suffix: string | null = text.slice(offset);
let suffix: string = text.slice(offset);

// Trim suffix
// If suffix is too small it is safe to assume that it could be ignored which would allow us to use
// more powerful completion instead of an in-the-middle one
if (suffix.length < 256) {
suffix = null;
}
// if (suffix.length < 256) {
// suffix = null;
// }

// Add filename and language to prefix
// NOTE: Most networks don't have a concept of filenames and expected language, but we expect that some files in the training set have something in the title that
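The rest of the file is truncated; the comment above suggests the prefix gets a small header carrying the file name and language. A hypothetical sketch of that idea (the header format is an assumption, not taken from the diff):

// Hypothetical: prepend the file name and language so the model sees some context
const header = `// ${document.fileName} (${document.languageId})\n`;
prefix = header + prefix;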
src/prompts/{adaptors/adaptPrompt.ts → processors/models.ts}
@@ -1,15 +1,17 @@
export function adaptPrompt(args: { model: string, prefix: string, suffix: string | null }): { prompt: string, stop: string[] } {
export type ModelFormat = 'codellama' | 'deepseek';

export function adaptPrompt(args: { format: ModelFormat, prefix: string, suffix: string }): { prompt: string, stop: string[] } {

// Common non FIM mode
if (!args.suffix) {
return {
prompt: args.prefix,
stop: [`<END>`]
};
}
// if (!args.suffix) {
// return {
// prompt: args.prefix,
// stop: [`<END>`]
// };
// }

// DeepSeek FIM
if (args.model.startsWith('deepseek-coder')) {
if (args.format === 'deepseek') {
return {
prompt: `<|fim▁begin|>${args.prefix}<|fim▁hole|>${args.suffix}<|fim▁end|>`,
stop: [`<|fim▁begin|>`, `<|fim▁hole|>`, `<|fim▁end|>`, `<END>`]
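The hunk ends before the codellama branch. Presumably it builds the prompt from Code Llama's documented infill tokens; a sketch under that assumption (the exact stop list used here is cut off in the diff):

// Code Llama FIM (sketch, not the verbatim implementation)
return {
    prompt: `<PRE> ${args.prefix} <SUF>${args.suffix} <MID>`,
    stop: [`<PRE>`, `<SUF>`, `<MID>`, `<EOT>`, `<END>`]
};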
35 changes: 19 additions & 16 deletions src/prompts/provider.ts
@@ -7,6 +7,7 @@ import { getFromPromptCache, setPromptToCache } from './promptCache';
import { isNotNeeded, isSupported } from './filter';
import { ollamaCheckModel } from '../modules/ollamaCheckModel';
import { ollamaDownloadModel } from '../modules/ollamaDownloadModel';
import { config } from '../config';

export class PromptProvider implements vscode.InlineCompletionItemProvider {

@@ -62,22 +63,23 @@ export class PromptProvider implements vscode.InlineCompletionItemProvider {
if (cached === undefined) {

// Config
let config = vscode.workspace.getConfiguration('inference');
let endpoint = config.get('endpoint') as string;
let model = config.get('model') as string;
let maxLines = config.get('maxLines') as number;
let maxTokens = config.get('maxTokens') as number;
let temperature = config.get('temperature') as number;
if (endpoint.endsWith('/')) {
endpoint = endpoint.slice(0, endpoint.length - 1);
}
let inferenceConfig = config.inference;
// let config = vscode.workspace.getConfiguration('inference');
// let endpoint = config.get('endpoint') as string;
// let model = config.get('model') as string;
// let maxLines = config.get('maxLines') as number;
// let maxTokens = config.get('maxTokens') as number;
// let temperature = config.get('temperature') as number;
// if (endpoint.endsWith('/')) {
// endpoint = endpoint.slice(0, endpoint.length - 1);
// }

// Update status
this.statusbar.text = `$(sync~spin) Llama Coder`;
try {

// Check model exists
let modelExists = await ollamaCheckModel(endpoint, model);
let modelExists = await ollamaCheckModel(inferenceConfig.endpoint, inferenceConfig.modelName);
if (token.isCancellationRequested) {
info(`Canceled after AI completion.`);
return;
@@ -86,7 +88,7 @@
// Download model if not exists
if (!modelExists) {
this.statusbar.text = `$(sync~spin) Downloading`;
await ollamaDownloadModel(endpoint, model);
await ollamaDownloadModel(inferenceConfig.endpoint, inferenceConfig.modelName);
this.statusbar.text = `$(sync~spin) Llama Coder`;
}
if (token.isCancellationRequested) {
@@ -99,11 +101,12 @@
res = await autocomplete({
prefix: prepared.prefix,
suffix: prepared.suffix,
endpoint: endpoint,
model: model,
maxLines: maxLines,
maxTokens: maxTokens,
temperature,
endpoint: inferenceConfig.endpoint,
model: inferenceConfig.modelName,
format: inferenceConfig.modelFormat,
maxLines: inferenceConfig.maxLines,
maxTokens: inferenceConfig.maxTokens,
temperature: inferenceConfig.temperature,
canceled: () => token.isCancellationRequested,
});
info(`AI completion completed: ${res}`);
