🚧 wip: add lm studio
arvinxx committed Oct 25, 2024
1 parent 8f83863 commit 8840a62
Showing 6 changed files with 309 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/config/modelProviders/index.ts
@@ -13,6 +13,7 @@ import GoogleProvider from './google';
import GroqProvider from './groq';
import HuggingFaceProvider from './huggingface';
import HunyuanProvider from './hunyuan';
import LMStudioProvider from './lmstudio';
import MinimaxProvider from './minimax';
import MistralProvider from './mistral';
import MoonshotProvider from './moonshot';
@@ -65,6 +66,7 @@ export const LOBE_DEFAULT_MODEL_LIST: ChatModelCard[] = [
HunyuanProvider.chatModels,
WenxinProvider.chatModels,
SenseNovaProvider.chatModels,
LMStudioProvider.chatModels,
].flat();

export const DEFAULT_MODEL_PROVIDER_LIST = [
@@ -100,6 +102,7 @@ export const DEFAULT_MODEL_PROVIDER_LIST = [
Ai360Provider,
TaichuProvider,
SiliconCloudProvider,
LMStudioProvider,
];

export const filterEnabledModels = (provider: ModelProviderCard) => {
@@ -124,6 +127,7 @@ export { default as GoogleProviderCard } from './google';
export { default as GroqProviderCard } from './groq';
export { default as HuggingFaceProviderCard } from './huggingface';
export { default as HunyuanProviderCard } from './hunyuan';
export { default as LMStudioProviderCard } from './lmstudio';
export { default as MinimaxProviderCard } from './minimax';
export { default as MistralProviderCard } from './mistral';
export { default as MoonshotProviderCard } from './moonshot';
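For illustration only (not part of the commit), a minimal sketch of how the newly registered LM Studio card could be looked up from the provider registry; the import path and the find-by-id pattern are assumptions about how callers consume `DEFAULT_MODEL_PROVIDER_LIST`:

```ts
import { DEFAULT_MODEL_PROVIDER_LIST } from '@/config/modelProviders';

// Look up the LM Studio card by its id ('lmstudio', as defined in the card below).
// This is a usage sketch only, not code from the commit.
const lmstudioCard = DEFAULT_MODEL_PROVIDER_LIST.find((card) => card.id === 'lmstudio');

console.log(lmstudioCard?.name); // expected: 'LM Studio'
```
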
31 changes: 31 additions & 0 deletions src/config/modelProviders/lmstudio.ts
@@ -0,0 +1,31 @@
import { ModelProviderCard } from '@/types/llm';

// ref: https://ollama.com/library
const LMStudio: ModelProviderCard = {
chatModels: [
{
description:
'Llama 3.1 is the leading model released by Meta, supporting up to 405B parameters; it can be applied to complex conversation, multilingual translation, and data analysis.',
displayName: 'Llama 3.1 8B',
enabled: true,
id: 'llama3.1',
tokens: 128_000,
},
{
description: 'Qwen2.5 is the new generation of large-scale language models from Alibaba, supporting diverse application needs with outstanding performance.',
displayName: 'Qwen2.5 7B',
enabled: true,
id: 'qwen2.5',
tokens: 128_000,
},
],
defaultShowBrowserRequest: true,
id: 'lmstudio',
modelList: { showModelFetcher: true },
modelsUrl: 'https://lmstudio.ai/models',
name: 'LM Studio',
showApiKey: false,
url: 'https://lmstudio.ai',
};

export default LMStudio;
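As a hedged example of how this card might be consumed, the sketch below runs it through `filterEnabledModels` from the index file shown above; the assumption is that the helper returns the ids of the chat models flagged `enabled: true`, which for this card would be the two presets:

```ts
import { LMStudioProviderCard, filterEnabledModels } from '@/config/modelProviders';

// Assuming filterEnabledModels returns the ids of models marked `enabled: true`,
// the LM Studio card should yield its two preset local models.
const enabledIds = filterEnabledModels(LMStudioProviderCard);

console.log(enabledIds); // expected under that assumption: ['llama3.1', 'qwen2.5']
```
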
7 changes: 7 additions & 0 deletions src/libs/agent-runtime/AgentRuntime.ts
@@ -16,6 +16,7 @@ import { LobeGoogleAI } from './google';
import { LobeGroq } from './groq';
import { LobeHuggingFaceAI } from './huggingface';
import { LobeHunyuanAI } from './hunyuan';
import { LobeLMStudioAI } from './lmstudio';
import { LobeMinimaxAI } from './minimax';
import { LobeMistralAI } from './mistral';
import { LobeMoonshotAI } from './moonshot';
@@ -138,6 +139,7 @@ class AgentRuntime {
groq: Partial<ClientOptions>;
huggingface: { apiKey?: string; baseURL?: string };
hunyuan: Partial<ClientOptions>;
lmstudio: Partial<ClientOptions>;
minimax: Partial<ClientOptions>;
mistral: Partial<ClientOptions>;
moonshot: Partial<ClientOptions>;
@@ -197,6 +199,11 @@
break;
}

case ModelProvider.LMStudio: {
runtimeModel = new LobeLMStudioAI(params.lmstudio);
break;
}

case ModelProvider.Ollama: {
runtimeModel = new LobeOllamaAI(params.ollama);
break;
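Outside the switch, the same construction can be written directly. The sketch below mirrors the new `ModelProvider.LMStudio` case with the options spelled out; the import path, the key-less construction, and the endpoint are assumptions (the `baseURL` is simply the default declared in `src/libs/agent-runtime/lmstudio/index.ts` further down, and LM Studio's local server is expected to ignore API keys given `showApiKey: false` in the provider card):

```ts
import { LobeLMStudioAI } from '@/libs/agent-runtime/lmstudio';

// Equivalent to `runtimeModel = new LobeLMStudioAI(params.lmstudio)` in the switch above,
// with the client options written out explicitly.
const runtimeModel = new LobeLMStudioAI({
  // Assumed local LM Studio endpoint; matches the factory default below.
  baseURL: 'http://localhost:1234/v1',
});
```
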
255 changes: 255 additions & 0 deletions src/libs/agent-runtime/lmstudio/index.test.ts
@@ -0,0 +1,255 @@
// @vitest-environment node
import OpenAI from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import {
ChatStreamCallbacks,
LobeOpenAICompatibleRuntime,
ModelProvider,
} from '@/libs/agent-runtime';

import * as debugStreamModule from '../utils/debugStream';
import { LobeDeepSeekAI } from './index';

const provider = ModelProvider.DeepSeek;
const defaultBaseURL = 'https://api.deepseek.com/v1';

const bizErrorType = 'ProviderBizError';
const invalidErrorType = 'InvalidProviderAPIKey';

// Mock the console.error to avoid polluting test output
vi.spyOn(console, 'error').mockImplementation(() => {});

let instance: LobeOpenAICompatibleRuntime;

beforeEach(() => {
instance = new LobeDeepSeekAI({ apiKey: 'test' });

// Use vi.spyOn to mock the chat.completions.create method
vi.spyOn(instance['client'].chat.completions, 'create').mockResolvedValue(
new ReadableStream() as any,
);
});

afterEach(() => {
vi.clearAllMocks();
});

describe('LobeDeepSeekAI', () => {
describe('init', () => {
it('should correctly initialize with an API key', async () => {
const instance = new LobeDeepSeekAI({ apiKey: 'test_api_key' });
expect(instance).toBeInstanceOf(LobeDeepSeekAI);
expect(instance.baseURL).toEqual(defaultBaseURL);
});
});

describe('chat', () => {
describe('Error', () => {
it('should return OpenAIBizError with an openai error response when OpenAI.APIError is thrown', async () => {
// Arrange
const apiError = new OpenAI.APIError(
400,
{
status: 400,
error: {
message: 'Bad Request',
},
},
'Error message',
{},
);

vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);

// Act
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'deepseek-chat',
temperature: 0,
});
} catch (e) {
expect(e).toEqual({
endpoint: defaultBaseURL,
error: {
error: { message: 'Bad Request' },
status: 400,
},
errorType: bizErrorType,
provider,
});
}
});

it('should throw AgentRuntimeError with NoOpenAIAPIKey if no apiKey is provided', async () => {
try {
new LobeDeepSeekAI({});
} catch (e) {
expect(e).toEqual({ errorType: invalidErrorType });
}
});

it('should return OpenAIBizError with the cause when OpenAI.APIError is thrown with cause', async () => {
// Arrange
const errorInfo = {
stack: 'abc',
cause: {
message: 'api is undefined',
},
};
const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});

vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);

// Act
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'deepseek-chat',
temperature: 0,
});
} catch (e) {
expect(e).toEqual({
endpoint: defaultBaseURL,
error: {
cause: { message: 'api is undefined' },
stack: 'abc',
},
errorType: bizErrorType,
provider,
});
}
});

it('should return OpenAIBizError with a cause response with a desensitized URL', async () => {
// Arrange
const errorInfo = {
stack: 'abc',
cause: { message: 'api is undefined' },
};
const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});

instance = new LobeDeepSeekAI({
apiKey: 'test',

baseURL: 'https://api.abc.com/v1',
});

vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(apiError);

// Act
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'deepseek-chat',
temperature: 0,
});
} catch (e) {
expect(e).toEqual({
endpoint: 'https://api.***.com/v1',
error: {
cause: { message: 'api is undefined' },
stack: 'abc',
},
errorType: bizErrorType,
provider,
});
}
});

it('should throw an InvalidDeepSeekAPIKey error type on 401 status code', async () => {
// Mock the API call to simulate a 401 error
const error = new Error('Unauthorized') as any;
error.status = 401;
vi.mocked(instance['client'].chat.completions.create).mockRejectedValue(error);

try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'deepseek-chat',
temperature: 0,
});
} catch (e) {
// Expect the chat method to throw an error with InvalidDeepSeekAPIKey
expect(e).toEqual({
endpoint: defaultBaseURL,
error: new Error('Unauthorized'),
errorType: invalidErrorType,
provider,
});
}
});

it('should return AgentRuntimeError for non-OpenAI errors', async () => {
// Arrange
const genericError = new Error('Generic Error');

vi.spyOn(instance['client'].chat.completions, 'create').mockRejectedValue(genericError);

// Act
try {
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'deepseek-chat',
temperature: 0,
});
} catch (e) {
expect(e).toEqual({
endpoint: defaultBaseURL,
errorType: 'AgentRuntimeError',
provider,
error: {
name: genericError.name,
cause: genericError.cause,
message: genericError.message,
stack: genericError.stack,
},
});
}
});
});

describe('DEBUG', () => {
it('should call debugStream and return StreamingTextResponse when DEBUG_DEEPSEEK_CHAT_COMPLETION is 1', async () => {
// Arrange
const mockProdStream = new ReadableStream() as any; // mocked production stream
const mockDebugStream = new ReadableStream({
start(controller) {
controller.enqueue('Debug stream content');
controller.close();
},
}) as any;
mockDebugStream.toReadableStream = () => mockDebugStream; // add a toReadableStream method

// Mock the return value of chat.completions.create, including a mocked tee method
(instance['client'].chat.completions.create as Mock).mockResolvedValue({
tee: () => [mockProdStream, { toReadableStream: () => mockDebugStream }],
});

// Save the original environment variable value
const originalDebugValue = process.env.DEBUG_DEEPSEEK_CHAT_COMPLETION;

// Mock the environment variable
process.env.DEBUG_DEEPSEEK_CHAT_COMPLETION = '1';
vi.spyOn(debugStreamModule, 'debugStream').mockImplementation(() => Promise.resolve());

// Run the test
// Run the function under test and make sure it calls debugStream when the condition is met
// Hypothetical test invocation; adjust it to match your actual setup
await instance.chat({
messages: [{ content: 'Hello', role: 'user' }],
model: 'deepseek-chat',
stream: true,
temperature: 0,
});

// Verify that debugStream was called
expect(debugStreamModule.debugStream).toHaveBeenCalled();

// Restore the original environment variable value
process.env.DEBUG_DEEPSEEK_CHAT_COMPLETION = originalDebugValue;
});
});
});
});
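Note that the suite above still exercises `LobeDeepSeekAI` against DeepSeek endpoints even though it lives under `lmstudio/`, consistent with the commit being marked wip. Purely as an illustration of where it would likely head, here is a minimal, assumed adaptation of the init test to the LM Studio runtime; the expected `baseURL` comes from `src/libs/agent-runtime/lmstudio/index.ts` below, and the placeholder key reflects the assumption that the OpenAI-compatible factory still wants some value even though the provider card sets `showApiKey: false`:

```ts
// @vitest-environment node
import { describe, expect, it } from 'vitest';

import { LobeLMStudioAI } from './index';

describe('LobeLMStudioAI', () => {
  describe('init', () => {
    it('should initialize against the local LM Studio endpoint', () => {
      // 'placeholder' is a dummy key: LM Studio itself ignores it, but the
      // shared factory may still require a non-empty value (assumption).
      const instance = new LobeLMStudioAI({ apiKey: 'placeholder' });

      expect(instance.baseURL).toEqual('http://localhost:1234/v1');
    });
  });
});
```
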
10 changes: 10 additions & 0 deletions src/libs/agent-runtime/lmstudio/index.ts
@@ -0,0 +1,10 @@
import { ModelProvider } from '../types';
import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';

export const LobeLMStudioAI = LobeOpenAICompatibleFactory({
baseURL: 'http://localhost:1234/v1',
debug: {
chatCompletion: () => process.env.DEBUG_LMSTUDIO_CHAT_COMPLETION === '1',
},
provider: ModelProvider.LMStudio,
});
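A hedged usage sketch for the new runtime: the `chat` call shape (messages, model, temperature) follows the test suite above, the model id `llama3.1` comes from the provider card, and whether a running LM Studio instance actually serves that id depends on what the user has loaded locally; the import path and placeholder key are likewise assumptions:

```ts
import { LobeLMStudioAI } from '@/libs/agent-runtime/lmstudio';

// Points at the factory's default baseURL (http://localhost:1234/v1), i.e. a
// locally running LM Studio server; the key is a placeholder LM Studio ignores.
const lmstudio = new LobeLMStudioAI({ apiKey: 'placeholder' });

// Request a single chat completion. Setting DEBUG_LMSTUDIO_CHAT_COMPLETION=1
// additionally routes the stream through debugStream, as wired up above.
const response = await lmstudio.chat({
  messages: [{ content: 'Hello', role: 'user' }],
  model: 'llama3.1',
  temperature: 0,
});
```
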
2 changes: 2 additions & 0 deletions src/libs/agent-runtime/types/type.ts
@@ -1,5 +1,6 @@
import OpenAI from 'openai';


import { ILobeAgentRuntimeErrorType } from '../error';
import { ChatStreamPayload } from './chat';

@@ -35,6 +36,7 @@ export enum ModelProvider {
Groq = 'groq',
HuggingFace = 'huggingface',
Hunyuan = 'hunyuan',
LMStudio = 'lmstudio',
Minimax = 'minimax',
Mistral = 'mistral',
Moonshot = 'moonshot',
