Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the Content caching feature #167

Merged
merged 15 commits into from
Jun 17, 2024
Merged
Next Next commit
add cache, streamline request functions
hsubox76 committed Jun 5, 2024
commit d069d67a76f5e6c47c0a497e6d0809b102d82e94
12 changes: 0 additions & 12 deletions packages/main/api-extractor.files.json

This file was deleted.

12 changes: 12 additions & 0 deletions packages/main/api-extractor.server.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"extends": "../../config/api-extractor.json",
"mainEntryPointFilePath": "<projectFolder>/dist/src/server/index.d.ts",
"dtsRollup": {
"enabled": true,
"untrimmedFilePath": "<projectFolder>/dist/server/server.d.ts"
},
"docModel": {
"enabled": true,
"apiJsonFilePath": "<projectFolder>/temp/server/<unscopedPackageName>-server.api.json"
}
}
8 changes: 0 additions & 8 deletions packages/main/files/package.json

This file was deleted.

8 changes: 8 additions & 0 deletions packages/main/server/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"name": "@google/generative-ai-server",
"description": "GoogleAI file upload manager",
DellaBitta marked this conversation as resolved.
Show resolved Hide resolved
"main": "../dist/server/index.js",
"browser": "../dist/server/index.mjs",
"module": "../dist/server/index.mjs",
"typings": "../dist/server/server.d.ts"
}
4 changes: 2 additions & 2 deletions packages/main/src/methods/count-tokens.ts
Original file line number Diff line number Diff line change
@@ -20,15 +20,15 @@ import {
CountTokensResponse,
RequestOptions,
} from "../../types";
import { Task, makeRequest } from "../requests/request";
import { Task, makeModelRequest } from "../requests/request";

export async function countTokens(
apiKey: string,
model: string,
params: CountTokensRequest,
requestOptions?: RequestOptions,
): Promise<CountTokensResponse> {
const response = await makeRequest(
const response = await makeModelRequest(
model,
Task.COUNT_TOKENS,
apiKey,
6 changes: 3 additions & 3 deletions packages/main/src/methods/embed-content.ts
Original file line number Diff line number Diff line change
@@ -22,15 +22,15 @@ import {
EmbedContentResponse,
RequestOptions,
} from "../../types";
import { Task, makeRequest } from "../requests/request";
import { Task, makeModelRequest } from "../requests/request";

export async function embedContent(
apiKey: string,
model: string,
params: EmbedContentRequest,
requestOptions?: RequestOptions,
): Promise<EmbedContentResponse> {
const response = await makeRequest(
const response = await makeModelRequest(
model,
Task.EMBED_CONTENT,
apiKey,
@@ -52,7 +52,7 @@ export async function batchEmbedContents(
return { ...request, model };
},
);
const response = await makeRequest(
const response = await makeModelRequest(
model,
Task.BATCH_EMBED_CONTENTS,
apiKey,
6 changes: 3 additions & 3 deletions packages/main/src/methods/generate-content.ts
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ import {
GenerateContentStreamResult,
RequestOptions,
} from "../../types";
import { Task, makeRequest } from "../requests/request";
import { Task, makeModelRequest } from "../requests/request";
import { addHelpers } from "../requests/response-helpers";
import { processStream } from "../requests/stream-reader";

@@ -32,7 +32,7 @@ export async function generateContentStream(
params: GenerateContentRequest,
requestOptions?: RequestOptions,
): Promise<GenerateContentStreamResult> {
const response = await makeRequest(
const response = await makeModelRequest(
model,
Task.STREAM_GENERATE_CONTENT,
apiKey,
@@ -49,7 +49,7 @@ export async function generateContent(
params: GenerateContentRequest,
requestOptions?: RequestOptions,
): Promise<GenerateContentResult> {
const response = await makeRequest(
const response = await makeModelRequest(
model,
Task.GENERATE_CONTENT,
apiKey,
30 changes: 15 additions & 15 deletions packages/main/src/requests/request.test.ts
Original file line number Diff line number Diff line change
@@ -24,8 +24,8 @@ import {
DEFAULT_BASE_URL,
RequestUrl,
Task,
_makeRequestInternal,
constructRequest,
constructModelRequest,
makeModelRequest,
} from "./request";
import {
GoogleGenerativeAIFetchError,
@@ -112,7 +112,7 @@ describe("request methods", () => {
});
describe("constructRequest", () => {
it("handles basic request", async () => {
const request = await constructRequest(
const request = await constructModelRequest(
"model-name",
Task.GENERATE_CONTENT,
"key",
@@ -131,7 +131,7 @@ describe("request methods", () => {
).to.equal("application/json");
});
it("passes apiClient", async () => {
const request = await constructRequest(
const request = await constructModelRequest(
"model-name",
Task.GENERATE_CONTENT,
"key",
@@ -146,7 +146,7 @@ describe("request methods", () => {
).to.equal("client/version genai-js/__PACKAGE_VERSION__");
});
it("passes timeout", async () => {
const request = await constructRequest(
const request = await constructModelRequest(
"model-name",
Task.GENERATE_CONTENT,
"key",
@@ -159,7 +159,7 @@ describe("request methods", () => {
expect(request.fetchOptions.signal).to.be.instanceOf(AbortSignal);
});
it("passes custom headers", async () => {
const request = await constructRequest(
const request = await constructModelRequest(
"model-name",
Task.GENERATE_CONTENT,
"key",
@@ -175,7 +175,7 @@ describe("request methods", () => {
});
it("passes custom x-goog-api-client header", async () => {
await expect(
constructRequest("model-name", Task.GENERATE_CONTENT, "key", true, "", {
constructModelRequest("model-name", Task.GENERATE_CONTENT, "key", true, "", {
customHeaders: new Headers({
"x-goog-api-client": "client/version",
}),
@@ -184,7 +184,7 @@ describe("request methods", () => {
});
it("passes apiClient and custom x-goog-api-client header", async () => {
await expect(
constructRequest("model-name", Task.GENERATE_CONTENT, "key", true, "", {
constructModelRequest("model-name", Task.GENERATE_CONTENT, "key", true, "", {
apiClient: "client/version",
customHeaders: new Headers({
"x-goog-api-client": "client/version2",
@@ -193,12 +193,12 @@ describe("request methods", () => {
).to.be.rejectedWith(GoogleGenerativeAIRequestInputError);
});
});
describe("_makeRequestInternal", () => {
describe("makeModelRequest", () => {
it("no error", async () => {
const fetchStub = stub().resolves({
ok: true,
} as Response);
const response = await _makeRequestInternal(
const response = await makeModelRequest(
"model-name",
Task.GENERATE_CONTENT,
"key",
@@ -222,7 +222,7 @@ describe("request methods", () => {
} as Response);

try {
await _makeRequestInternal(
await makeModelRequest(
"model-name",
Task.GENERATE_CONTENT,
"key",
@@ -251,7 +251,7 @@ describe("request methods", () => {
statusText: "Server Error",
} as Response);
try {
await _makeRequestInternal(
await makeModelRequest(
"model-name",
Task.GENERATE_CONTENT,
"key",
@@ -280,7 +280,7 @@ describe("request methods", () => {
} as Response);

try {
await _makeRequestInternal(
await makeModelRequest(
"model-name",
Task.GENERATE_CONTENT,
"key",
@@ -321,7 +321,7 @@ describe("request methods", () => {
} as Response);

try {
await _makeRequestInternal(
await makeModelRequest(
"model-name",
Task.GENERATE_CONTENT,
"key",
@@ -349,7 +349,7 @@ describe("request methods", () => {
it("has invalid custom header", async () => {
const fetchStub = stub();
await expect(
_makeRequestInternal(
makeModelRequest(
"model-name",
Task.GENERATE_CONTENT,
"key",
117 changes: 57 additions & 60 deletions packages/main/src/requests/request.ts
Original file line number Diff line number Diff line change
@@ -110,7 +110,7 @@ export async function getHeaders(url: RequestUrl): Promise<Headers> {
return headers;
}

export async function constructRequest(
export async function constructModelRequest(
model: string,
task: Task,
apiKey: string,
@@ -130,89 +130,86 @@ export async function constructRequest(
};
}

/**
* Wrapper for _makeRequestInternal that automatically uses native fetch,
* allowing _makeRequestInternal to be tested with a mocked fetch function.
*/
export async function makeRequest(
export async function makeModelRequest(
model: string,
task: Task,
apiKey: string,
stream: boolean,
body: string,
requestOptions?: RequestOptions,
// Allows this to be stubbed for tests
fetchFn = fetch,
): Promise<Response> {
return _makeRequestInternal(
const { url, fetchOptions } = await constructModelRequest(
model,
task,
apiKey,
stream,
body,
requestOptions,
fetch,
);
return makeRequest(url, fetchOptions, fetchFn);
}

export async function _makeRequestInternal(
model: string,
task: Task,
apiKey: string,
stream: boolean,
body: string,
requestOptions?: RequestOptions,
// Allows this to be stubbed for tests
export async function makeRequest(
url: string,
fetchOptions: RequestInit,
fetchFn = fetch,
): Promise<Response> {
const url = new RequestUrl(model, task, apiKey, stream, requestOptions);
let response;
try {
const request = await constructRequest(
model,
task,
apiKey,
stream,
body,
requestOptions,
response = await fetchFn(url, fetchOptions);
} catch (e) {
handleResponseError(e, url);
}

if (!response.ok) {
await handleResponseNotOk(response, url);
}

return response;
}

export function handleResponseError(e: Error, url: string): void {
DellaBitta marked this conversation as resolved.
Show resolved Hide resolved
let err = e;
if (
!(
e instanceof GoogleGenerativeAIFetchError ||
e instanceof GoogleGenerativeAIRequestInputError
)
) {
err = new GoogleGenerativeAIError(
`Error fetching from ${url.toString()}: ${e.message}`,
);
response = await fetchFn(request.url, request.fetchOptions);
if (!response.ok) {
let message = "";
let errorDetails;
try {
const json = await response.json();
message = json.error.message;
if (json.error.details) {
message += ` ${JSON.stringify(json.error.details)}`;
errorDetails = json.error.details;
}
} catch (e) {
// ignored
}
throw new GoogleGenerativeAIFetchError(
`Error fetching from ${url.toString()}: [${response.status} ${
response.statusText
}] ${message}`,
response.status,
response.statusText,
errorDetails,
);
err.stack = e.stack;
}
throw err;
}

export async function handleResponseNotOk(
DellaBitta marked this conversation as resolved.
Show resolved Hide resolved
response: Response,
url: string,
): Promise<void> {
let message = "";
let errorDetails;
try {
const json = await response.json();
message = json.error.message;
if (json.error.details) {
message += ` ${JSON.stringify(json.error.details)}`;
errorDetails = json.error.details;
}
} catch (e) {
let err = e;
if (
!(
e instanceof GoogleGenerativeAIFetchError ||
e instanceof GoogleGenerativeAIRequestInputError
)
) {
err = new GoogleGenerativeAIError(
`Error fetching from ${url.toString()}: ${e.message}`,
);
err.stack = e.stack;
}
throw err;
// ignored
}
return response;
throw new GoogleGenerativeAIFetchError(
`Error fetching from ${url.toString()}: [${response.status} ${
response.statusText
}] ${message}`,
response.status,
response.statusText,
errorDetails,
);
}

/**
Loading