Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the Content caching feature #167

Merged
merged 15 commits into from
Jun 17, 2024
Merged
Prev Previous commit
Next Next commit
address PR comments, fix rollup config
  • Loading branch information
hsubox76 committed Jun 6, 2024
commit c1283bfaf49d5e8dd70ff1053e06de67d69e3e68
19 changes: 13 additions & 6 deletions common/api-review/generative-ai-server.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@

// @public
export interface CachedContent extends CachedContentBase {
// (undocumented)
createTime?: string;
expireTime?: string;
// (undocumented)
name?: string;
// (undocumented)
ttl?: string;
// (undocumented)
updateTime?: string;
}

Expand All @@ -32,7 +30,6 @@ export interface CachedContentBase {

// @public
export interface CachedContentCreateParams extends CachedContentBase {
// (undocumented)
ttlSeconds?: number;
}

Expand Down Expand Up @@ -147,12 +144,22 @@ export interface FunctionCall {
export interface FunctionCallingConfig {
// (undocumented)
allowedFunctionNames?: string[];
// Warning: (ae-forgotten-export) The symbol "FunctionCallingMode" needs to be exported by the entry point index.d.ts
//
// (undocumented)
mode?: FunctionCallingMode;
}

// @public (undocumented)
export enum FunctionCallingMode {
// (undocumented)
ANY = "ANY",
// (undocumented)
AUTO = "AUTO",
// (undocumented)
MODE_UNSPECIFIED = "MODE_UNSPECIFIED",
// (undocumented)
NONE = "NONE"
}

// @public
export interface FunctionCallPart {
// (undocumented)
Expand Down
4 changes: 1 addition & 3 deletions common/api-review/generative-ai.api.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,11 @@ export enum BlockReason {

// @public
export interface CachedContent extends CachedContentBase {
// (undocumented)
createTime?: string;
expireTime?: string;
// (undocumented)
name?: string;
// (undocumented)
ttl?: string;
// (undocumented)
updateTime?: string;
}

Expand Down
5 changes: 1 addition & 4 deletions packages/main/rollup.config.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import pkg from "./package.json" assert { type: "json" };

const es2017BuildPlugins = [
typescriptPlugin({
clean: true,
typescript,
tsconfigOverride: {
compilerOptions: {
Expand All @@ -41,7 +42,6 @@ const es2017BuildPlugins = [

const esmBuilds = [
{
clean: true,
input: "src/index.ts",
output: {
file: pkg.module,
Expand All @@ -55,7 +55,6 @@ const esmBuilds = [

const cjsBuilds = [
{
clean: true,
input: "src/index.ts",
output: [{ file: pkg.main, format: "cjs", sourcemap: true }],
external: ["fs"],
Expand All @@ -65,7 +64,6 @@ const cjsBuilds = [

const serverBuilds = [
{
clean: true,
input: "src/server/index.ts",
output: [
{ file: pkg.exports["./server"].import, format: "es", sourcemap: true },
Expand All @@ -74,7 +72,6 @@ const serverBuilds = [
plugins: [...es2017BuildPlugins],
},
{
clean: true,
input: "src/server/index.ts",
output: [
{ file: pkg.exports["./server"].require, format: "cjs", sourcemap: true },
Expand Down
2 changes: 1 addition & 1 deletion packages/main/server/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@google/generative-ai-server",
"description": "GoogleAI file upload manager",
"description": "GoogleAI JS server-environment-only features",
"main": "../dist/server/index.js",
"browser": "../dist/server/index.mjs",
"module": "../dist/server/index.mjs",
Expand Down
15 changes: 14 additions & 1 deletion packages/main/src/gen-ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
* limitations under the License.
*/

import { GoogleGenerativeAIError } from "./errors";
import {
GoogleGenerativeAIError,
GoogleGenerativeAIRequestInputError,
} from "./errors";
import { CachedContent, ModelParams, RequestOptions } from "../types";
import { GenerativeModel } from "./models/generative-model";

Expand Down Expand Up @@ -52,6 +55,16 @@ export class GoogleGenerativeAI {
cachedContent: CachedContent,
requestOptions?: RequestOptions,
): GenerativeModel {
if (!cachedContent.name) {
throw new GoogleGenerativeAIRequestInputError(
"Cached content must contain a `name` field.",
);
}
if (!cachedContent.model) {
throw new GoogleGenerativeAIRequestInputError(
"Cached content must contain a `model` field.",
);
}
const modelParamsFromCache: ModelParams = {
model: cachedContent.model,
tools: cachedContent.tools,
Expand Down
4 changes: 2 additions & 2 deletions packages/main/src/requests/request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ export async function makeRequest(
return response;
}

export function handleResponseError(e: Error, url: string): void {
function handleResponseError(e: Error, url: string): void {
let err = e;
if (
!(
Expand All @@ -186,7 +186,7 @@ export function handleResponseError(e: Error, url: string): void {
throw err;
}

export async function handleResponseNotOk(
async function handleResponseNotOk(
response: Response,
url: string,
): Promise<void> {
Expand Down
31 changes: 18 additions & 13 deletions packages/main/src/server/cache-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ import {
ListParams,
} from "../../types/server";
import { RpcTask } from "./constants";
import { GoogleGenerativeAIError } from "../errors";
import { GoogleGenerativeAIError, GoogleGenerativeAIRequestInputError } from "../errors";

/**
* Class for managing GoogleAI file uploads.
* Class for managing GoogleAI content caches.
* @public
*/
export class GoogleAICacheManager {
Expand All @@ -65,6 +65,11 @@ export class GoogleAICacheManager {
newCachedContent.ttl = createOptions.ttlSeconds.toString() + "s";
delete (newCachedContent as CachedContentCreateParams).ttlSeconds;
}
if (!newCachedContent.model) {
throw new GoogleGenerativeAIRequestInputError(
"Cached content must contain a `model` field.",
);
}
if (!newCachedContent.model.includes("/")) {
DellaBitta marked this conversation as resolved.
Show resolved Hide resolved
// If path is not included, assume it's a non-tuned model.
newCachedContent.model = `models/${newCachedContent.model}`;
Expand All @@ -75,11 +80,11 @@ export class GoogleAICacheManager {
this._requestOptions,
);

const uploadHeaders = getHeaders(url);
const headers = getHeaders(url);

const response = await makeServerRequest(
url,
uploadHeaders,
headers,
JSON.stringify(newCachedContent),
);
return response.json();
Expand All @@ -100,8 +105,8 @@ export class GoogleAICacheManager {
if (listParams?.pageToken) {
url.appendParam("pageToken", listParams.pageToken);
}
const uploadHeaders = getHeaders(url);
const response = await makeServerRequest(url, uploadHeaders);
const headers = getHeaders(url);
const response = await makeServerRequest(url, headers);
return response.json();
}

Expand All @@ -115,8 +120,8 @@ export class GoogleAICacheManager {
this._requestOptions,
);
url.appendPath(parseCacheName(name));
const uploadHeaders = getHeaders(url);
const response = await makeServerRequest(url, uploadHeaders);
const headers = getHeaders(url);
const response = await makeServerRequest(url, headers);
return response.json();
}

Expand All @@ -133,10 +138,10 @@ export class GoogleAICacheManager {
this._requestOptions,
);
url.appendPath(parseCacheName(name));
const uploadHeaders = getHeaders(url);
const headers = getHeaders(url);
const response = await makeServerRequest(
url,
uploadHeaders,
headers,
JSON.stringify(updateParams),
);
return response.json();
Expand All @@ -152,13 +157,13 @@ export class GoogleAICacheManager {
this._requestOptions,
);
url.appendPath(parseCacheName(name));
const uploadHeaders = getHeaders(url);
await makeServerRequest(url, uploadHeaders);
const headers = getHeaders(url);
await makeServerRequest(url, headers);
}
}

/**
* If fileId is prepended with "files/", remove prefix
* If cache name is prepended with "cachedContents/", remove prefix
*/
function parseCacheName(name: string): string {
if (name.startsWith("cachedContents/")) {
Expand Down
1 change: 1 addition & 0 deletions packages/main/types/server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ export * from "./shared";

export { RequestOptions } from "../../types/requests";
export * from "../../types/content";
export { FunctionCallingMode } from '../../types/enums';