Feature: Google Vertex AI PaLM Support (langchain-ai#1273)
* Initial work implementing access to the Google Vertex AI API and the LLM models available through it.

* Refactor into a base class

* Prettier

* Add chat version of Google Vertex AI LLM.

* Add Google Vertex AI classes

* Example for Google Vertex AI LLM

* Add Google Vertex AI LLM

* Documentation

* Google Vertex AI Chat example

* Fix capitalization typo

* Make Google's SDK a peer dependency, move out of main entrypoint

* Add convenience method for type equality.

* Extensive refactoring to make the Google Vertex AI Chat implementation compliant

* Legacy prettier cleanup

* Prettier cleanup of refactored code

* Refactor to move connection and common types into separate modules.

* Refactor to move connection and common types into separate modules.

* Refactor to move connection and common types into separate modules.

* Change maxTokens to maxOutputTokens.
Change configuration for the connection object.
Prettier

* ChatGoogleVertexAi entry point

* GoogleVertexAi tests

* GoogleVertexAi example refactoring

* Example for ChatGoogleVertexAi
Fix and test for determining context and examples.

* Documentation for Google VertexAI text and chat models.

* Refactor some of ChatGoogleVertexAi to break out some components to be better unit tested.
Add unit tests.
Fix bugs with context and example generation.
Formatting.

* Descope PR to just GoogleVertexAI LLM, normalize class names

* Small fixes and renames, increase default max token output size

* Fix typo in documentation import from examples. I think.

* Fix docs

---------

Co-authored-by: Jacob Lee <jacoblee93@gmail.com>
afirstenberg and jacoblee93 authored May 23, 2023
1 parent 69a691b commit 5e42176
Showing 11 changed files with 479 additions and 3 deletions.
25 changes: 25 additions & 0 deletions docs/docs/modules/models/llms/integrations.mdx
@@ -42,6 +42,31 @@ const res = await model.call(
console.log({ res });
```

## Google Vertex AI

The Vertex AI implementation is meant to be used in Node.js, not
directly in a browser, since it requires a service account to authenticate.

Before running this code, you should make sure the Vertex AI API is
enabled for the relevant project in your Google Cloud dashboard and that you've authenticated to
Google Cloud using one of these methods:

- You are logged in (via `gcloud auth application-default login`) to an
  account that has access to that project.
- You are running on a machine using a service account that has access
  to the project.
- You have downloaded the credentials for a service account that has access
  to the project and set the `GOOGLE_APPLICATION_CREDENTIALS` environment
  variable to the path of this file.
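
For the credentials-file option, the setup can be sketched as follows (the key path below is a placeholder, not a value from this PR):

```shell
# Interactive alternative: `gcloud auth application-default login`.
# Point Google's auth library at a downloaded service-account key file
# by exporting the path (placeholder shown):
export GOOGLE_APPLICATION_CREDENTIALS="/path/to/service-account.json"
```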

```bash npm2yarn
npm install google-auth-library
```

import GoogleVertexAIExample from "@examples/llms/googlevertexai.ts";

<CodeBlock language="typescript">{GoogleVertexAIExample}</CodeBlock>

## `HuggingFaceInference`

```bash npm2yarn
23 changes: 23 additions & 0 deletions examples/src/llms/googlevertexai.ts
@@ -0,0 +1,23 @@
import { GoogleVertexAI } from "langchain/llms/googlevertexai";

/*
 * Before running this, you should make sure you have created a
 * Google Cloud project with the Vertex AI API enabled.
*
* You will also need permission to access this project / API.
* Typically, this is done in one of three ways:
* - You are logged into an account permitted to that project.
* - You are running this on a machine using a service account permitted to
* the project.
* - The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is set to the
* path of a credentials file for a service account permitted to the project.
*/
export const run = async () => {
const model = new GoogleVertexAI({
temperature: 0.7,
});
const res = await model.call(
    "What would be a good company name for a company that makes colorful socks?"
);
console.log({ res });
};
3 changes: 3 additions & 0 deletions langchain/.gitignore
@@ -73,6 +73,9 @@ llms/hf.d.ts
llms/replicate.cjs
llms/replicate.js
llms/replicate.d.ts
llms/googlevertexai.cjs
llms/googlevertexai.js
llms/googlevertexai.d.ts
llms/sagemaker_endpoint.cjs
llms/sagemaker_endpoint.js
llms/sagemaker_endpoint.d.ts
10 changes: 10 additions & 0 deletions langchain/package.json
@@ -85,6 +85,9 @@
"llms/replicate.cjs",
"llms/replicate.js",
"llms/replicate.d.ts",
"llms/googlevertexai.cjs",
"llms/googlevertexai.js",
"llms/googlevertexai.d.ts",
"llms/sagemaker_endpoint.cjs",
"llms/sagemaker_endpoint.js",
"llms/sagemaker_endpoint.d.ts",
@@ -390,6 +393,7 @@
"eslint-plugin-no-instanceof": "^1.0.1",
"eslint-plugin-prettier": "^4.2.1",
"faiss-node": "^0.1.1",
"google-auth-library": "^8.8.0",
"graphql": "^16.6.0",
"hnswlib-node": "^1.4.2",
"html-to-text": "^9.0.5",
@@ -437,6 +441,7 @@
"d3-dsv": "^2.0.0",
"epub2": "^3.0.1",
"faiss-node": "^0.1.1",
"google-auth-library": "^8.8.0",
"hnswlib-node": "^1.4.2",
"html-to-text": "^9.0.5",
"ignore": "^5.2.0",
@@ -736,6 +741,11 @@
"import": "./llms/replicate.js",
"require": "./llms/replicate.cjs"
},
"./llms/googlevertexai": {
"types": "./llms/googlevertexai.d.ts",
"import": "./llms/googlevertexai.js",
"require": "./llms/googlevertexai.cjs"
},
"./llms/sagemaker_endpoint": {
"types": "./llms/sagemaker_endpoint.d.ts",
"import": "./llms/sagemaker_endpoint.js",
2 changes: 2 additions & 0 deletions langchain/scripts/create-entrypoints.js
@@ -38,6 +38,7 @@ const entrypoints = {
"llms/cohere": "llms/cohere",
"llms/hf": "llms/hf",
"llms/replicate": "llms/replicate",
"llms/googlevertexai": "llms/googlevertexai",
"llms/sagemaker_endpoint": "llms/sagemaker_endpoint",
// prompts
prompts: "prompts/index",
@@ -161,6 +162,7 @@ const requiresOptionalDependency = [
"embeddings/hf",
"llms/load",
"llms/cohere",
"llms/googlevertexai",
"llms/hf",
"llms/replicate",
"llms/sagemaker_endpoint",
118 changes: 118 additions & 0 deletions langchain/src/llms/googlevertexai.ts
@@ -0,0 +1,118 @@
import { BaseLLM } from "./base.js";
import { Generation, LLMResult } from "../schema/index.js";
import { GoogleVertexAIConnection } from "../util/googlevertexai-connection.js";
import {
GoogleVertexAIBaseLLMInput,
GoogleVertexAIBasePrediction,
GoogleVertexAILLMResponse,
GoogleVertexAIModelParams,
} from "../types/googlevertexai-types.js";

export interface GoogleVertexAITextInput extends GoogleVertexAIBaseLLMInput {}

interface GoogleVertexAILLMTextInstance {
content: string;
}

/**
* Models the data returned from the API call
*/
interface TextPrediction extends GoogleVertexAIBasePrediction {
content: string;
}

/**
* Enables calls to the Google Cloud's Vertex AI API to access
* Large Language Models.
*
* To use, you will need to have one of the following authentication
* methods in place:
* - You are logged into an account permitted to the Google Cloud project
* using Vertex AI.
* - You are running this on a machine using a service account permitted to
* the Google Cloud project using Vertex AI.
* - The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is set to the
* path of a credentials file for a service account permitted to the
* Google Cloud project using Vertex AI.
*/
export class GoogleVertexAI extends BaseLLM implements GoogleVertexAITextInput {
model = "text-bison";

temperature = 0.7;

maxOutputTokens = 1024;

topP = 0.8;

topK = 40;

private connection: GoogleVertexAIConnection<
this["CallOptions"],
GoogleVertexAILLMTextInstance,
TextPrediction
>;

constructor(fields?: GoogleVertexAITextInput) {
super(fields ?? {});

this.model = fields?.model ?? this.model;
this.temperature = fields?.temperature ?? this.temperature;
this.maxOutputTokens = fields?.maxOutputTokens ?? this.maxOutputTokens;
this.topP = fields?.topP ?? this.topP;
this.topK = fields?.topK ?? this.topK;

this.connection = new GoogleVertexAIConnection(
{ ...fields, ...this },
this.caller
);
}

_llmType(): string {
return "googlevertexai";
}

async _generate(
prompts: string[],
options: this["ParsedCallOptions"]
): Promise<LLMResult> {
const generations: Generation[][] = await Promise.all(
prompts.map((prompt) => this._generatePrompt(prompt, options))
);
return { generations };
}

async _generatePrompt(
prompt: string,
options: this["ParsedCallOptions"]
): Promise<Generation[]> {
const instance = this.formatInstance(prompt);
const parameters: GoogleVertexAIModelParams = {
temperature: this.temperature,
topK: this.topK,
topP: this.topP,
maxOutputTokens: this.maxOutputTokens,
};
const result = await this.connection.request(
[instance],
parameters,
options
);
const prediction = this.extractPredictionFromResponse(result);
return [
{
text: prediction.content,
generationInfo: prediction,
},
];
}

formatInstance(prompt: string): GoogleVertexAILLMTextInstance {
return { content: prompt };
}

extractPredictionFromResponse(
result: GoogleVertexAILLMResponse<TextPrediction>
): TextPrediction {
return result?.data?.predictions[0];
}
}
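
The `formatInstance` / `extractPredictionFromResponse` split above can be exercised without hitting the API. A minimal self-contained sketch, using local stand-ins for the PR's types rather than the real langchain imports:

```typescript
// Local stand-ins for the PR's prediction/response shapes.
interface TextPrediction {
  content: string;
  safetyAttributes?: unknown;
}

interface LLMResponse<P> {
  data: { predictions: P[] };
}

// Mirrors GoogleVertexAI.formatInstance: a text prompt becomes an
// instance object with a `content` field.
function formatInstance(prompt: string): { content: string } {
  return { content: prompt };
}

// Mirrors extractPredictionFromResponse: the first prediction in the
// response payload is the one surfaced as the generation.
function extractPrediction(res: LLMResponse<TextPrediction>): TextPrediction {
  return res.data.predictions[0];
}

const response: LLMResponse<TextPrediction> = {
  data: { predictions: [{ content: "2" }] },
};
console.log(formatInstance("1 + 1 = ").content); // "1 + 1 = "
console.log(extractPrediction(response).content); // "2"
```

Keeping these two steps as plain methods is what lets the PR unit-test the request/response plumbing separately from the network call.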
28 changes: 28 additions & 0 deletions langchain/src/llms/tests/googlevertexai.int.test.ts
@@ -0,0 +1,28 @@
import { test } from "@jest/globals";
import { GoogleVertexAI } from "../googlevertexai.js";

test("Test Google Vertex", async () => {
const model = new GoogleVertexAI({ maxOutputTokens: 50 });
const res = await model.call("1 + 1 = ");
console.log({ res });
});

test("Test Google Vertex generation", async () => {
const model = new GoogleVertexAI({ maxOutputTokens: 50 });
const res = await model.generate(["1 + 1 = "]);
console.log(JSON.stringify(res, null, 2));
});

test("Test Google Vertex generation: hello world", async () => {
const model = new GoogleVertexAI({ maxOutputTokens: 50 });
const res = await model.generate(["Print hello world."]);
console.log(JSON.stringify(res, null, 2));
});

test("Test Google Vertex generation: translation", async () => {
const model = new GoogleVertexAI({ maxOutputTokens: 50 });
const res = await model.generate([
`Translate "I love programming" into Korean.`,
]);
console.log(JSON.stringify(res, null, 2));
});
62 changes: 62 additions & 0 deletions langchain/src/types/googlevertexai-types.ts
@@ -0,0 +1,62 @@
import { BaseLLMParams } from "../llms/index.js";

export interface GoogleVertexAIConnectionParams {
/** Hostname for the API call */
endpoint?: string;

/** Region where the LLM is stored */
location?: string;

/** Model to use */
model?: string;
}

export interface GoogleVertexAIModelParams {
/** Sampling temperature to use */
temperature?: number;

/**
* Maximum number of tokens to generate in the completion.
*/
maxOutputTokens?: number;

/**
* Top-p changes how the model selects tokens for output.
*
* Tokens are selected from most probable to least until the sum
* of their probabilities equals the top-p value.
*
* For example, if tokens A, B, and C have a probability of
* .3, .2, and .1 and the top-p value is .5, then the model will
* select either A or B as the next token (using temperature).
*/
topP?: number;

/**
* Top-k changes how the model selects tokens for output.
*
* A top-k of 1 means the selected token is the most probable among
* all tokens in the model’s vocabulary (also called greedy decoding),
* while a top-k of 3 means that the next token is selected from
* among the 3 most probable tokens (using temperature).
*/
topK?: number;
}

export interface GoogleVertexAIBaseLLMInput
extends BaseLLMParams,
GoogleVertexAIConnectionParams,
GoogleVertexAIModelParams {}

export interface GoogleVertexAIBasePrediction {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
safetyAttributes?: any;
}

export interface GoogleVertexAILLMResponse<
PredictionType extends GoogleVertexAIBasePrediction
> {
data: {
predictions: PredictionType[];
};
}
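
The `topP` doc comment above describes nucleus sampling. A self-contained sketch of that cutoff rule (illustrative only — the actual token selection happens server-side in Vertex AI):

```typescript
// Returns the tokens eligible under a top-p (nucleus) cutoff:
// tokens are taken from most to least probable until their
// cumulative probability reaches topP.
function topPCandidates(
  probs: Record<string, number>,
  topP: number
): string[] {
  const sorted = Object.entries(probs).sort((a, b) => b[1] - a[1]);
  const candidates: string[] = [];
  let cumulative = 0;
  for (const [token, p] of sorted) {
    candidates.push(token);
    cumulative += p;
    if (cumulative >= topP) break;
  }
  return candidates;
}

// The example from the comment: probabilities .3, .2, and .1 with a
// top-p of .5 leave A and B as candidates; temperature then picks one.
console.log(topPCandidates({ A: 0.3, B: 0.2, C: 0.1 }, 0.5)); // [ "A", "B" ]
```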