forked from langchain-ai/langchainjs
Feature: Google Vertex AI PaLM Support (langchain-ai#1273)
* Initial work implementing access to the Google Vertex AI API and the LLM models available through it.
* Refactor into a base class
* Prettier
* Add chat version of Google Vertex AI LLM.
* Add Google Vertex AI classes
* Example for Google Vertex AI LLM
* Add Google Vertex AI LLM
* Documentation
* Google Vertex AI Chat example
* Fix capitalization typo
* Make Google's SDK a peer dependency, move out of main entrypoint
* Add convenience method for type equality.
* Extensive refactoring to make the Google Vertex AI Chat implementation compliant
* Legacy prettier cleanup
* Prettier cleanup of refactored code
* Refactor to move connection and common types into separate modules.
* Refactor to move connection and common types into separate modules.
* Refactor to move connection and common types into separate modules.
* Change maxTokens to maxOutputTokens. Change configuration for the connection object. Prettier
* ChatGoogleVertexAi entry point
* GoogleVertexAi tests
* GoogleVertexAi example refactoring
* Example for ChatGoogleVertexAi. Fix and test for determining context and examples.
* Documentation for Google VertexAI text and chat models.
* Refactor some of ChatGoogleVertexAi to break out some components to be better unit tested. Add unit tests. Fix bugs with context and example generation. Formatting.
* Descope PR to just GoogleVertexAI LLM, normalize class names
* Small fixes and renames, increase default max token output size
* Fix typo in documentation import from examples. I think.
* Fix docs

Co-authored-by: Jacob Lee <jacoblee93@gmail.com>
Commit 5e42176, 1 parent (69a691b). Showing 11 changed files with 479 additions and 3 deletions.
@@ -0,0 +1,23 @@

import { GoogleVertexAI } from "langchain/llms/googlevertexai";

/*
 * Before running this, you should make sure you have created a
 * Google Cloud Project that is permitted to use the Vertex AI API.
 *
 * You will also need permission to access this project / API.
 * Typically, this is done in one of three ways:
 * - You are logged into an account permitted to that project.
 * - You are running this on a machine using a service account permitted to
 *   the project.
 * - The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is set to the
 *   path of a credentials file for a service account permitted to the project.
 */
export const run = async () => {
  const model = new GoogleVertexAI({
    temperature: 0.7,
  });
  const res = await model.call(
    "What would be a good company name for a company that makes colorful socks?"
  );
  console.log({ res });
};
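
Building on the example above, here is a minimal sketch of what passing the other documented model parameters might look like. It assumes the same `langchain/llms/googlevertexai` entrypoint; the function name `runWithParams` and the specific values are illustrative only, and the fields come from the `GoogleVertexAITextInput` / `GoogleVertexAIModelParams` interfaces added later in this commit.

import { GoogleVertexAI } from "langchain/llms/googlevertexai";

export const runWithParams = async () => {
  // All of these fields are optional; the class falls back to its defaults
  // (model "text-bison", temperature 0.7, maxOutputTokens 1024, topP 0.8, topK 40).
  const model = new GoogleVertexAI({
    model: "text-bison",
    temperature: 0.2,
    maxOutputTokens: 256,
    topP: 0.8,
    topK: 40,
  });
  const res = await model.call(
    "Suggest a name for a company that makes colorful socks."
  );
  console.log({ res });
};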
@@ -0,0 +1,118 @@

import { BaseLLM } from "./base.js";
import { Generation, LLMResult } from "../schema/index.js";
import { GoogleVertexAIConnection } from "../util/googlevertexai-connection.js";
import {
  GoogleVertexAIBaseLLMInput,
  GoogleVertexAIBasePrediction,
  GoogleVertexAILLMResponse,
  GoogleVertexAIModelParams,
} from "../types/googlevertexai-types.js";

export interface GoogleVertexAITextInput extends GoogleVertexAIBaseLLMInput {}

interface GoogleVertexAILLMTextInstance {
  content: string;
}

/**
 * Models the data returned from the API call
 */
interface TextPrediction extends GoogleVertexAIBasePrediction {
  content: string;
}

/**
 * Enables calls to Google Cloud's Vertex AI API to access
 * Large Language Models.
 *
 * To use, you will need to have one of the following authentication
 * methods in place:
 * - You are logged into an account permitted to the Google Cloud project
 *   using Vertex AI.
 * - You are running this on a machine using a service account permitted to
 *   the Google Cloud project using Vertex AI.
 * - The `GOOGLE_APPLICATION_CREDENTIALS` environment variable is set to the
 *   path of a credentials file for a service account permitted to the
 *   Google Cloud project using Vertex AI.
 */
export class GoogleVertexAI extends BaseLLM implements GoogleVertexAITextInput {
  model = "text-bison";

  temperature = 0.7;

  maxOutputTokens = 1024;

  topP = 0.8;

  topK = 40;

  private connection: GoogleVertexAIConnection<
    this["CallOptions"],
    GoogleVertexAILLMTextInstance,
    TextPrediction
  >;

  constructor(fields?: GoogleVertexAITextInput) {
    super(fields ?? {});

    this.model = fields?.model ?? this.model;
    this.temperature = fields?.temperature ?? this.temperature;
    this.maxOutputTokens = fields?.maxOutputTokens ?? this.maxOutputTokens;
    this.topP = fields?.topP ?? this.topP;
    this.topK = fields?.topK ?? this.topK;

    this.connection = new GoogleVertexAIConnection(
      { ...fields, ...this },
      this.caller
    );
  }

  _llmType(): string {
    return "googlevertexai";
  }

  async _generate(
    prompts: string[],
    options: this["ParsedCallOptions"]
  ): Promise<LLMResult> {
    const generations: Generation[][] = await Promise.all(
      prompts.map((prompt) => this._generatePrompt(prompt, options))
    );
    return { generations };
  }

  async _generatePrompt(
    prompt: string,
    options: this["ParsedCallOptions"]
  ): Promise<Generation[]> {
    const instance = this.formatInstance(prompt);
    const parameters: GoogleVertexAIModelParams = {
      temperature: this.temperature,
      topK: this.topK,
      topP: this.topP,
      maxOutputTokens: this.maxOutputTokens,
    };
    const result = await this.connection.request(
      [instance],
      parameters,
      options
    );
    const prediction = this.extractPredictionFromResponse(result);
    return [
      {
        text: prediction.content,
        generationInfo: prediction,
      },
    ];
  }

  formatInstance(prompt: string): GoogleVertexAILLMTextInstance {
    return { content: prompt };
  }

  extractPredictionFromResponse(
    result: GoogleVertexAILLMResponse<TextPrediction>
  ): TextPrediction {
    return result?.data?.predictions[0];
  }
}
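
As a rough illustration of how `_generate` surfaces its results (a sketch, not part of the commit; it assumes the `langchain/llms/googlevertexai` entrypoint used in the example above, and the prompts are made up): each prompt produces one `Generation`, whose `generationInfo` is the raw `TextPrediction`, so fields such as `safetyAttributes` travel alongside the generated text.

import { GoogleVertexAI } from "langchain/llms/googlevertexai";

export const inspectGenerations = async () => {
  const model = new GoogleVertexAI({ maxOutputTokens: 256 });

  // generate() fans each prompt out through _generatePrompt() and collects
  // one Generation[] per prompt, in the same order as the input.
  const result = await model.generate([
    "Summarize the plot of Hamlet in one sentence.",
    "1 + 1 = ",
  ]);

  result.generations.forEach((generations, i) => {
    console.log(`Prompt ${i}:`, generations[0].text);
    // generationInfo carries the whole prediction object, so any
    // safetyAttributes returned by the API ride along with the text.
    console.log(
      "Safety attributes:",
      generations[0].generationInfo?.safetyAttributes
    );
  });
};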
@@ -0,0 +1,28 @@

import { test } from "@jest/globals";
import { GoogleVertexAI } from "../googlevertexai.js";

test("Test Google Vertex", async () => {
  const model = new GoogleVertexAI({ maxOutputTokens: 50 });
  const res = await model.call("1 + 1 = ");
  console.log({ res });
});

test("Test Google Vertex generation", async () => {
  const model = new GoogleVertexAI({ maxOutputTokens: 50 });
  const res = await model.generate(["1 + 1 = "]);
  console.log(JSON.stringify(res, null, 2));
});

test("Test Google Vertex generation", async () => {
  const model = new GoogleVertexAI({ maxOutputTokens: 50 });
  const res = await model.generate(["Print hello world."]);
  console.log(JSON.stringify(res, null, 2));
});

test("Test Google Vertex generation", async () => {
  const model = new GoogleVertexAI({ maxOutputTokens: 50 });
  const res = await model.generate([
    `Translate "I love programming" into Korean.`,
  ]);
  console.log(JSON.stringify(res, null, 2));
});
@@ -0,0 +1,62 @@

import { BaseLLMParams } from "../llms/index.js";

export interface GoogleVertexAIConnectionParams {
  /** Hostname for the API call */
  endpoint?: string;

  /** Region where the LLM is stored */
  location?: string;

  /** Model to use */
  model?: string;
}

export interface GoogleVertexAIModelParams {
  /** Sampling temperature to use */
  temperature?: number;

  /**
   * Maximum number of tokens to generate in the completion.
   */
  maxOutputTokens?: number;

  /**
   * Top-p changes how the model selects tokens for output.
   *
   * Tokens are selected from most probable to least until the sum
   * of their probabilities equals the top-p value.
   *
   * For example, if tokens A, B, and C have a probability of
   * .3, .2, and .1 and the top-p value is .5, then the model will
   * select either A or B as the next token (using temperature).
   */
  topP?: number;

  /**
   * Top-k changes how the model selects tokens for output.
   *
   * A top-k of 1 means the selected token is the most probable among
   * all tokens in the model's vocabulary (also called greedy decoding),
   * while a top-k of 3 means that the next token is selected from
   * among the 3 most probable tokens (using temperature).
   */
  topK?: number;
}

export interface GoogleVertexAIBaseLLMInput
  extends BaseLLMParams,
    GoogleVertexAIConnectionParams,
    GoogleVertexAIModelParams {}

export interface GoogleVertexAIBasePrediction {
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  safetyAttributes?: any;
}

export interface GoogleVertexAILLMResponse<
  PredictionType extends GoogleVertexAIBasePrediction
> {
  data: {
    predictions: PredictionType[];
  };
}
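
To make the relationship between these interfaces concrete, here is a hedged sketch of a fully populated input object. The import path assumes a sibling module, and the endpoint and location values are illustrative placeholders rather than defaults defined in this commit; only the field names and the topP/topK semantics come from the interfaces and doc comments above.

import { GoogleVertexAIBaseLLMInput } from "./googlevertexai-types.js"; // path is an assumption

const input: GoogleVertexAIBaseLLMInput = {
  // GoogleVertexAIConnectionParams
  endpoint: "us-central1-aiplatform.googleapis.com", // illustrative hostname
  location: "us-central1", // illustrative region
  model: "text-bison",

  // GoogleVertexAIModelParams
  temperature: 0.7,
  maxOutputTokens: 1024,
  topP: 0.8, // sample only from the smallest token set whose probabilities sum to 0.8
  topK: 40, // ...drawn from at most the 40 most probable tokens
};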