Skip to content

Commit

Permalink
add support for embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
hbsgithub committed May 2, 2023
1 parent 8e7b584 commit 71bb363
Showing 1 changed file with 73 additions and 31 deletions.
104 changes: 73 additions & 31 deletions main.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { serve } from "https://deno.land/std@0.181.0/http/server.ts";
import { pooledMap } from "https://deno.land/std@0.182.0/async/pool.ts";

// The name of your Azure OpenAI Resource.
const resourceName:string = Deno.env.get("RESOURCE_NAME");
Expand All @@ -19,63 +20,60 @@ async function handleRequest(request:Request):Promise<Response> {
const url = new URL(request.url);
let path:string;
if (url.pathname === '/v1/chat/completions') {
path="chat/completions"
return handleDirect(request, "chat/completions");
} else if (url.pathname === '/v1/completions') {
path="completions"
return handleDirect(request, "completions");
} else if (url.pathname === '/v1/models') {
return handleModels(request)
} else if (url.pathname === '/v1/embeddings') {
return handleEmbedding(request, "embeddings");
} else {
return new Response('404 Not Found', { status: 404 })
}
}


async function requestAzure(method: string, body: any, path: string, authKey?: string) {
if (!authKey) {
return new Response("Not allowed", { status: 403 });
}

// Get the value of the model field and perform mapping.
let deployName:string = '';
let body:any;
if (request.method === 'POST') {
body = await request.json();
}

const modelName:string = body?.model;

if (modelName) {
deployName = mapper[modelName] || modelName;
}
if (method === 'POST') {
const modelName: string | undefined = body?.model;
if (modelName) {
deployName = mapper[modelName] || modelName;
}
}

const fetchAPI:string = `https://${resourceName}.openai.azure.com/openai/deployments/${deployName}/${path}?api-version=${apiVersion}`;

const authKey:string|null = request.headers.get('Authorization');
if (!authKey) {
return new Response("Not allowed", {status: 403});
}

const payload:RequestInit = {
method: request.method,
method: method,
headers: {
"Content-Type": "application/json",
"api-key": authKey.replace('Bearer ', ''),
},
body: JSON.stringify(body),
};

const response:Response = await fetch(fetchAPI, payload);

return await fetch(fetchAPI, payload);
// const response:Response = await fetch(fetchAPI, payload);
}
async function handleDirect(request: Request, path: string) {
const [key, body] = await extractRequest(request);
const response: Response = await requestAzure(request.method, body, path, key);
if (body?.stream != true){
return response
}

const { readable, writable } = new TransformStream();

if (response.body) {
stream(response.body, writable);
return new Response(readable, response);
const { readable, writable } = new TransformStream();
stream(response.body, writable);
return new Response(readable, response);
} else {
throw new Error('Response body is null');
}

throw new Error('Response body is null');
}
}


function sleep(ms:number):Promise<void> {
return new Promise(resolve => setTimeout(resolve, ms));
}
Expand Down Expand Up @@ -117,6 +115,50 @@ async function stream(readable:ReadableStream<Uint8Array>, writable:WritableStre
await writer.close();
}

async function extractRequest(request: Request) {
const key = request.headers.get('Authorization')?.replace('Bearer ', '');
const body = request.method === "POST" ? await request.json() : null;
return [key, body]
}

async function handleEmbedding(request: Request, path: string) {
const [key, body] = await extractRequest(request);
const input = body.input;
if (typeof input === "string") {
return await requestAzure(request.method, body, path, key);
} else if (Array.isArray(input)) {
const resps = pooledMap(3,
input, x => {
return requestAzure(request.method, { ...body, input: x }, path, key);
});
const retbody = {
object: "list",
data: [] as any[],
model: body.model,
usage: {
prompt_tokens: 0,
total_tokens: 0
}
};
let i = 0;
for await (const r of resps) {
const ret = await r.json();
for (const data of ret.data) {
retbody.data.push({ ...data, index: i });
i++;
}
retbody.usage.prompt_tokens += ret.usage.prompt_tokens;
retbody.usage.total_tokens += ret.usage.total_tokens;
}
const json: string = JSON.stringify(retbody, null, 2);
return new Response(json, {
headers: { 'Content-Type': 'application/json' },
});
} else {
throw new Error('Invalid input type');
}
}

async function handleModels(request:Request):Promise<Response> {
const data:any = {
"object": "list",
Expand Down

0 comments on commit 71bb363

Please sign in to comment.