Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ feat: Support Cloudflare Workers AI #2966

Closed
wants to merge 48 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
f2996eb
Delete .nvmrc
sxjeru May 31, 2024
6de6000
Merge branch 'lobehub:main' into cf
sxjeru Jun 21, 2024
4b1d4c6
feat: Add Cloudflare as a model provider
sxjeru Jun 21, 2024
b774549
fix
sxjeru Jun 21, 2024
1336aae
fix
sxjeru Jun 21, 2024
ff5361e
fix
sxjeru Jun 21, 2024
6d658bd
fix
sxjeru Jun 21, 2024
5a0a4da
fix
sxjeru Jun 21, 2024
3608659
fix
sxjeru Jun 21, 2024
161af46
fix
sxjeru Jun 21, 2024
aa609af
fix icon
sxjeru Jun 21, 2024
0972e11
fix
sxjeru Jun 21, 2024
8ad1100
Create .nvmrc
sxjeru Jun 21, 2024
ed2f3c0
Delete src/config/modelProviders/.nvmrc
sxjeru Jun 21, 2024
e47aee5
CF -> CLOUDFLARE
sxjeru Jun 21, 2024
1909a89
Merge branch 'cf' of https://github.com/sxjeru/lobe-chat into cf
sxjeru Jun 21, 2024
5a1180c
revert
sxjeru Jun 21, 2024
7648bde
chore: Update agentRuntime.ts and auth.ts to support Cloudflare accou…
sxjeru Jun 21, 2024
9d036ee
Add provider setting
sxjeru Jun 21, 2024
7fe9401
fix
sxjeru Jun 21, 2024
fa23ba4
Update cloudflare.ts
sxjeru Jun 21, 2024
4414320
fix
sxjeru Jun 24, 2024
8d1f973
Update cloudflare.ts
sxjeru Jun 24, 2024
3b57709
Merge branch 'main' into cf
sxjeru Jun 24, 2024
7efaab9
accountID
sxjeru Jul 1, 2024
87f0721
fix
sxjeru Jul 1, 2024
7844a5b
Merge branch 'main' into cf
sxjeru Jul 1, 2024
26de0f1
i18n
sxjeru Jul 1, 2024
65463e0
Merge branch 'main' into cf
sxjeru Jul 10, 2024
7fe207a
Merge branch 'main' into cf
sxjeru Jul 25, 2024
e0f541a
Update index.ts
sxjeru Jul 27, 2024
bc26fd8
Update baichuan.ts
sxjeru Jul 27, 2024
0f5462f
Merge branch 'main' into cf
sxjeru Jul 27, 2024
bb02954
Update cloudflare.ts
sxjeru Jul 27, 2024
85021aa
save changes
BrandonStudio Jul 31, 2024
cb7dd1c
commit check
BrandonStudio Jul 31, 2024
ac8d4f2
disable function calling for now
BrandonStudio Jul 31, 2024
eefacf5
does not catch errors when fetching models
BrandonStudio Jul 31, 2024
5fc4c81
ready to add base url
BrandonStudio Jul 31, 2024
52ff9d1
commit check
BrandonStudio Jul 31, 2024
b8492e2
revert change
BrandonStudio Aug 1, 2024
b452d30
revert string boolean check
BrandonStudio Aug 1, 2024
b46c642
fix type error on Vercel.
BrandonStudio Aug 1, 2024
2dca07d
i18n by groq/llama-3.1-8b-instant
BrandonStudio Aug 1, 2024
0f40d15
rename env var
BrandonStudio Aug 1, 2024
8469931
Merge branch 'cf' into pr/BrandonStudio/38
sxjeru Aug 1, 2024
b3351d8
Merge branch 'main' into cf
sxjeru Aug 1, 2024
65c0bd2
Merge branch 'main' into cf
sxjeru Aug 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
save changes
  • Loading branch information
BrandonStudio committed Jul 31, 2024
commit 85021aa9eabfb59316c0c59bcb749f0266502def
13 changes: 12 additions & 1 deletion src/config/modelProviders/cloudflare.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,20 @@ const Cloudflare: ModelProviderCard = {
id: '@hf/thebloke/zephyr-7b-beta-awq',
tokens: 32_768,
},
{
description:
'Generation over generation, Meta Llama 3 demonstrates state-of-the-art performance on a wide range of industry benchmarks and offers new capabilities, including improved reasoning.\t',
displayName: 'meta-llama-3-8b-instruct',
enabled: true,
functionCall: false,
id: '@hf/meta-llama/meta-llama-3-8b-instruct',
},
],
checkModel: '@hf/thebloke/deepseek-coder-6.7b-instruct-awq',
checkModel: '@hf/meta-llama/meta-llama-3-8b-instruct',
id: 'cloudflare',
modelList: {
showModelFetcher: true,
},
name: 'Cloudflare Workers AI',
};

Expand Down
218 changes: 209 additions & 9 deletions src/libs/agent-runtime/cloudflare/index.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,216 @@
import { ChatModelCard } from '@/types/llm';
import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { desensitizeUrl } from '../utils/desensitizeUrl';
import { StreamingResponse } from '../utils/response';

import { ModelProvider } from '../types';
import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';
// Scheme and host of the Cloudflare REST API. The run and model-search URLs in this file are
// built on it, and URLs that start with it keep their host unmasked when desensitized.
const DEFAULT_BASE_URL_PREFIX = 'https://api.cloudflare.com';

/**
 * Constructor options accepted by the Cloudflare runtime.
 */
export interface LobeCloudflareParams {
// Account identifier that is interpolated into the `/client/v4/accounts/{accountID}/...` request paths.
accountID?: string;
// Token sent as `Authorization: Bearer <apiKey>` on outgoing requests.
apiKey?: string;
}

// NOTE(review): this factory-built export has the same name as the `LobeCloudflareAI` class
// declared further down, so both cannot live in one module. It looks like the removed side
// of the diff — confirm before relying on it.
// The base URL reads CLOUDFLARE_ACCOUNT_ID from the environment when this statement is evaluated.
export const LobeCloudflareAI = LobeOpenAICompatibleFactory({
baseURL: `https://api.cloudflare.com/client/v4/accounts/${process.env.CLOUDFLARE_ACCOUNT_ID}/ai/v1`,
debug: {
// Debug output is switched on only when DEBUG_CLOUDFLARE_CHAT_COMPLETION is exactly '1'.
chatCompletion: () => process.env.DEBUG_CLOUDFLARE_CHAT_COMPLETION === '1',
},
provider: ModelProvider.Cloudflare,
});
/**
 * Builds the base URL (with trailing slash) under which models are run for one account.
 *
 * @param accountID - Account identifier placed in the `accounts/{id}` path segment.
 * @returns `{DEFAULT_BASE_URL_PREFIX}/client/v4/accounts/{accountID}/ai/run/`
 */
function fillUrl(accountID: string): string {
  const runPath = ['client', 'v4', 'accounts', accountID, 'ai', 'run'].join('/');
  return DEFAULT_BASE_URL_PREFIX + '/' + runPath + '/';
}

/**
 * Masks the first path segment that consists of exactly 32 hexadecimal characters.
 *
 * The segment must be enclosed by slashes on both sides; only the first match is replaced.
 *
 * @param path - URL path that may contain an account identifier.
 * @returns The path with that segment replaced by `****`, or the path unchanged if none matches.
 */
function desensitizeAccountId(path: string): string {
  const accountIdSegment = /\/[0-9A-Fa-f]{32}\//;
  return path.replace(accountIdSegment, () => '/****/');
}

/**
 * Produces a log-safe rendering of a Cloudflare endpoint URL.
 *
 * The account identifier in the path is always masked. The scheme/host/port part is kept
 * verbatim when the URL starts with `DEFAULT_BASE_URL_PREFIX`; otherwise it is passed through
 * `desensitizeUrl` first. The query string is appended unchanged.
 *
 * @param url - Absolute URL to sanitize; `new URL(url)` throws if it cannot be parsed.
 * @returns The sanitized URL string.
 */
function desensitizeCloudflareUrl(url: string): string {
  const parsed = new URL(url);
  const portSuffix = parsed.port ? `:${parsed.port}` : '';
  const schemeAndHost = `${parsed.protocol}//${parsed.hostname}${portSuffix}`;
  const visibleHost = url.startsWith(DEFAULT_BASE_URL_PREFIX)
    ? schemeAndHost
    : desensitizeUrl(schemeAndHost);
  return `${visibleHost}${desensitizeAccountId(parsed.pathname)}${parsed.search}`;
}

/** Key under which every entry of a model's `properties` array stores its identifier. */
const CF_PROPERTY_NAME = 'property_id';

/**
 * Looks up one property on a model record returned by the model-search endpoint.
 *
 * @param model - Raw model record; expected to carry a `properties` array of
 *   `{ property_id, value }` entries, but any shape is tolerated.
 * @param propertyId - Identifier to look for.
 * @returns The entry's `value`, or `undefined` when `properties` is missing or the
 *   identifier does not match exactly one entry.
 */
function findModelPropertyValue(model: any, propertyId: string): unknown {
  const properties: unknown = model?.properties;
  if (!Array.isArray(properties)) return undefined;
  const matches = properties.filter((property) => property?.[CF_PROPERTY_NAME] === propertyId);
  return matches.length === 1 ? matches[0]?.value : undefined;
}

/**
 * Interprets a property value as a boolean flag.
 * Accepts a real boolean `true` as well as the string "true" in any letter case.
 */
function isEnabledFlag(value: unknown): boolean {
  return value === true || (typeof value === 'string' && value.toLowerCase() === 'true');
}

/**
 * Tells whether the model is flagged as beta.
 *
 * @param model - Raw model record.
 * @returns `true` only when the single `beta` property is set to true; `false` otherwise.
 */
function getModelBeta(model: any): boolean {
  return isEnabledFlag(findModelPropertyValue(model, 'beta'));
}

/**
 * Derives a short display name from the model identifier in `model.name`.
 *
 * @param model - Raw model record; `name` must be a string.
 * @param beta - When true, the suffix " (Beta)" is appended.
 * @returns The part of the name after its last `/` (the whole name if there is none).
 */
function getModelDisplayName(model: any, beta: boolean): string {
  const modelId: string = model['name'];
  // lastIndexOf yields -1 when there is no slash, which makes slice(0) return the full id.
  const shortName = modelId.slice(modelId.lastIndexOf('/') + 1);
  return beta ? `${shortName} (Beta)` : shortName;
}

/**
 * Tells whether the model advertises function calling.
 *
 * @param model - Raw model record.
 * @returns `true` only when the single `function_calling` property is set to true.
 */
function getModelFunctionCalling(model: any): boolean {
  return isEnabledFlag(findModelPropertyValue(model, 'function_calling'));
}

/**
 * Reads the model's `max_total_tokens` property.
 *
 * @param model - Raw model record.
 * @returns The value parsed as a base-10 integer, or `undefined` when the property is
 *   absent, ambiguous, or not numeric (never `NaN`).
 */
function getModelTokens(model: any): number | undefined {
  const raw = findModelPropertyValue(model, 'max_total_tokens');
  if (raw == null) return undefined;
  const tokens = Number.parseInt(String(raw), 10);
  return Number.isNaN(tokens) ? undefined : tokens;
}

/**
 * Transformer for `TransformStream` that re-frames the Cloudflare event stream.
 *
 * Input is raw bytes containing `data: <json>` events separated by a blank line. For every
 * event whose JSON carries a `response` field, two strings are enqueued:
 * `event: text\n` and `data: <JSON-encoded response>\n\n`.
 * Text that has not yet been terminated by a blank line is buffered until the next chunk.
 */
class CloudflareStreamTransformer {
  private static readonly DATA_PREFIX = /^data: /;
  private static readonly DONE_MARKER = '[DONE]';

  private textDecoder = new TextDecoder();
  private buffer: string = '';

  /**
   * Parses one complete event and forwards its `response` field.
   * Throws if the payload is not valid JSON, which errors the surrounding stream.
   */
  private parseChunk(chunk: string, controller: TransformStreamDefaultController) {
    // Trim first so the prefix is still recognised when the event starts with a stray newline.
    const json = chunk.trim().replace(CloudflareStreamTransformer.DATA_PREFIX, '');
    const parsedChunk = JSON.parse(json);
    // Events without a `response` field would otherwise be emitted as the text "undefined".
    if (parsedChunk?.response == null) return;
    controller.enqueue(`event: text\n`);
    controller.enqueue(`data: ${JSON.stringify(parsedChunk.response)}\n\n`);
  }

  /**
   * Consumes one network chunk and emits every event completed by it.
   *
   * @param chunk - Raw bytes from the upstream response body.
   * @param controller - Controller of the enclosing TransformStream.
   */
  public async transform(chunk: Uint8Array, controller: TransformStreamDefaultController) {
    // `stream: true` keeps an incomplete multi-byte UTF-8 sequence inside the decoder until
    // the next call instead of turning it into replacement characters.
    const text = this.buffer + this.textDecoder.decode(chunk, { stream: true });
    const events = text.split('\n\n');

    // The final piece is not terminated by a blank line yet; keep it unless it is only whitespace.
    const tail = events.pop() ?? '';
    this.buffer = tail.trim() === '' ? '' : tail;

    for (const event of events) {
      const trimmed = event.trim();
      if (trimmed === '') continue; // blank event, nothing to parse

      // Compare the whole payload so that model output merely containing "[DONE]" is not dropped.
      const payload = trimmed.replace(CloudflareStreamTransformer.DATA_PREFIX, '').trim();
      if (payload === CloudflareStreamTransformer.DONE_MARKER) {
        this.buffer = '';
        return;
      }
      this.parseChunk(trimmed, controller);
    }
  }
}

/**
 * Agent runtime that talks to the Cloudflare REST API directly.
 *
 * Chat requests are POSTed to `{baseURL}{model}`; the streamed reply is re-framed by
 * `CloudflareStreamTransformer` and wrapped in a `StreamingResponse`.
 */
export class LobeCloudflareAI implements LobeRuntimeAI {
  baseURL: string;
  accountID: string;
  apiKey?: string;

  /**
   * @param params - Account identifier and API token.
   */
  constructor({ accountID, apiKey }: LobeCloudflareParams) {
    // NOTE(review): `accountID` is optional in LobeCloudflareParams but is not validated here;
    // an omitted value ends up as the literal text "undefined" in the request URLs.
    this.accountID = accountID!;
    this.baseURL = fillUrl(accountID!);
    this.apiKey = apiKey;
  }

  /** Wraps `error` into the runtime's chat error with a desensitized endpoint. */
  private createBizError(error: object) {
    return AgentRuntimeError.chat({
      endpoint: desensitizeCloudflareUrl(this.baseURL),
      error,
      errorType: AgentRuntimeErrorType.ProviderBizError,
      provider: ModelProvider.Cloudflare,
    });
  }

  /** Reads the body of a failed response as text; yields '' if it cannot be read. */
  private async readErrorBody(response: Response): Promise<string> {
    try {
      return await response.text();
    } catch {
      return '';
    }
  }

  /**
   * Sends a chat request for `payload.model` and returns the streamed answer.
   *
   * @param payload - Chat payload; `tools` are forwarded as their `function` definitions.
   * @param options - Optional per-request settings; `headers` are copied, never mutated.
   * @returns A streaming response carrying `event: text` frames.
   * @throws The value produced by `AgentRuntimeError.chat` (ProviderBizError) when the
   *   request cannot be sent, the server answers with a non-2xx status, or the body is empty.
   */
  async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions): Promise<Response> {
    const { model, tools, ...restPayload } = payload;
    const functions = tools?.map((tool) => tool.function);

    // Build a fresh header object so the caller's `options.headers` is left untouched.
    // Precedence: default Content-Type < caller headers < Authorization from the API key.
    const headers = {
      'Content-Type': 'application/json',
      ...options?.headers,
      ...(this.apiKey ? { Authorization: `Bearer ${this.apiKey}` } : {}),
    };

    let response: Response;
    try {
      response = await fetch(new URL(model, this.baseURL), {
        body: JSON.stringify({ tools: functions, ...restPayload }),
        headers,
        method: 'POST',
      });
    } catch (error) {
      // Only transport/serialization failures are wrapped here; HTTP-level failures are
      // reported below so they are not wrapped a second time.
      throw this.createBizError(error as object);
    }

    if (!response.ok) {
      // Pass plain data rather than the Response object, whose body would otherwise be lost.
      throw this.createBizError({
        message: await this.readErrorBody(response),
        status: response.status,
        statusText: response.statusText,
      });
    }

    if (!response.body) {
      throw this.createBizError({ message: 'Empty response body', status: response.status });
    }

    return StreamingResponse(
      response.body.pipeThrough(new TransformStream(new CloudflareStreamTransformer())),
    );
  }

  /**
   * Fetches the account's model catalogue and keeps the "Text Generation" models.
   *
   * Best effort: any network, HTTP, or parsing problem results in an empty list.
   *
   * @returns Model cards; beta models are listed but not enabled by default.
   */
  async models(): Promise<ChatModelCard[]> {
    try {
      const url = `${DEFAULT_BASE_URL_PREFIX}/client/v4/accounts/${this.accountID}/ai/models/search`;
      const response = await fetch(url, {
        headers: {
          'Content-Type': 'application/json',
          // Avoid sending the literal text "Bearer undefined" when no key is configured.
          ...(this.apiKey ? { Authorization: `Bearer ${this.apiKey}` } : {}),
        },
        method: 'GET',
      });
      if (!response.ok) return [];

      const payload = await response.json();
      const result: unknown = payload?.result;
      if (!Array.isArray(result)) return [];

      return result
        .filter((model) => model?.task?.name === 'Text Generation')
        .map((model): ChatModelCard => {
          const modelBeta = getModelBeta(model);
          return {
            description: model['description'],
            displayName: getModelDisplayName(model, modelBeta),
            enabled: !modelBeta,
            functionCall: getModelFunctionCalling(model),
            id: model['name'],
            tokens: getModelTokens(model),
          };
        });
    } catch {
      return [];
    }
  }
}
4 changes: 4 additions & 0 deletions src/store/user/slices/modelList/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ export const createModelListSlice: StateCreator<

const togetherai = draft.find((d) => d.id === ModelProvider.TogetherAI);
if (togetherai) togetherai.chatModels = mergeModels('togetherai', togetherai.chatModels);

const cloudflare = draft.find((d) => d.id === ModelProvider.Cloudflare);
if (cloudflare)
cloudflare.chatModels = mergeModels('cloudflare', cloudflare.chatModels);
});

set({ defaultModelProviderList }, false, `refreshDefaultModelList - ${params?.trigger}`);
Expand Down