VinF Hybrid Inference: infer expected inputs from prompt (exploration) #8989

Open · wants to merge 1 commit into base: vaihi-exp-google-ai

packages/vertexai/src/methods/chrome-adapter.test.ts (8 changes: 6 additions & 2 deletions)
@@ -53,7 +53,7 @@ async function toStringArray(

 describe('ChromeAdapter', () => {
   describe('constructor', () => {
-    it('sets image as expected input type by default', async () => {
+    it('determines expected inputs by request inspection', async () => {
       const languageModelProvider = {
         availability: () => Promise.resolve(Availability.available)
       } as LanguageModel;
@@ -69,7 +69,11 @@ describe('ChromeAdapter', () => {
         contents: [
           {
             role: 'user',
-            parts: [{ text: 'hi' }]
+            parts: [
+              { text: 'hi' },
+              // Triggers image as expected type.
+              { inlineData: { mimeType: 'image/asd', data: 'asd' } }
+            ]
           }
         ]
       });
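
For orientation, the behavior this renamed test targets can be sketched outside the test harness roughly as follows. This is a hypothetical sketch, not the actual test: the isAvailable method name, the stub provider shape, and the import paths are assumptions; only the constructor arguments, the 'only_on_device' mode value, and the request shape are taken from the diffs on this page.

// Hypothetical sketch; assumes ChromeAdapter exposes an isAvailable(request) entry point.
import { ChromeAdapter } from './chrome-adapter';
import {
  Availability,
  LanguageModel,
  LanguageModelCreateOptions
} from '../types/language-model';

async function sketch(): Promise<void> {
  // Records the options the adapter forwards to the provider's availability check.
  let seenOptions: LanguageModelCreateOptions | undefined;
  const languageModelProvider = {
    availability: (options?: LanguageModelCreateOptions) => {
      seenOptions = options;
      return Promise.resolve(Availability.available);
    }
  } as LanguageModel;

  const adapter = new ChromeAdapter(languageModelProvider, 'only_on_device');
  await adapter.isAvailable({
    contents: [
      {
        role: 'user',
        parts: [
          { text: 'hi' },
          // The inline image part should surface { type: 'image' } as an expected input.
          { inlineData: { mimeType: 'image/png', data: 'aaaa' } }
        ]
      }
    ]
  });

  // With this change, the provider is queried with inferred inputs, e.g. [{ type: 'image' }],
  // instead of an unconditional image default.
  console.log(seenOptions?.expectedInputs);
}

sketch().catch(console.error);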

packages/vertexai/src/methods/chrome-adapter.ts (54 changes: 43 additions & 11 deletions)
@@ -27,7 +27,9 @@ import {
   Availability,
   LanguageModel,
   LanguageModelCreateOptions,
-  LanguageModelMessageContent
+  LanguageModelExpectedInput,
+  LanguageModelMessageContent,
+  LanguageModelMessageType
 } from '../types/language-model';

 /**
@@ -44,9 +46,7 @@ export class ChromeAdapter {
     private languageModelProvider?: LanguageModel,
     private mode?: InferenceMode,
     private onDeviceParams: LanguageModelCreateOptions = {}
-  ) {
-    this.addImageTypeAsExpectedInput();
-  }
+  ) {}

   /**
    * Checks if a given request can be made on-device.
@@ -68,8 +68,10 @@
       return false;
     }

+    const expectedInputs = ChromeAdapter.extractExpectedInputs(request);
+
     // Triggers out-of-band download so model will eventually become available.
-    const availability = await this.downloadIfAvailable();
+    const availability = await this.downloadIfAvailable(expectedInputs);

     if (this.mode === 'only_on_device') {
       return true;
@@ -129,6 +131,33 @@
     );
   }

+  /**
+   * Maps
+   * <a href="https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#blob">
+   * Vertex's input mime types</a> to
+   * <a href="https://github.com/webmachinelearning/prompt-api?tab=readme-ov-file#full-api-surface-in-web-idl">
+   * Chrome's expected types</a>.
+   *
+   * <p>Chrome's API checks availability by type. It's tedious to specify the types in advance, so
+   * this method infers the types.</p>
+   */
+  private static extractExpectedInputs(
+    request: GenerateContentRequest
+  ): LanguageModelExpectedInput[] {
+    const inputSet = new Set<LanguageModelExpectedInput>();
+    for (const content of request.contents) {
+      for (const part of content.parts) {
+        if (part.inlineData) {
+          const type = part.inlineData.mimeType.split(
+            '/'
+          )[0] as LanguageModelMessageType;
+          inputSet.add({ type });
+        }
+      }
+    }
+    return Array.from(inputSet);
+  }
+
   /**
    * Asserts inference for the given request can be performed by an on-device model.
    */
@@ -164,12 +193,20 @@
   /**
    * Encapsulates logic to get availability and download a model if one is downloadable.
    */
-  private async downloadIfAvailable(): Promise<Availability | undefined> {
+  private async downloadIfAvailable(
+    expectedInputs: LanguageModelExpectedInput[]
+  ): Promise<Availability | undefined> {
+    // Side-effect: updates construction-time params with request-time params.
+    // This is required because params are referenced through multiple flows.
+    Object.assign(this.onDeviceParams, { expectedInputs });
+
     const availability = await this.languageModelProvider?.availability(
       this.onDeviceParams
     );

     if (availability === Availability.downloadable) {
+      // Side-effect: triggers out-of-band model download.
+      // This is required because Chrome manages the model download.
       this.download();
     }

@@ -252,11 +289,6 @@
     return newSession;
   }

-  private addImageTypeAsExpectedInput(): void {
-    // Defaults to support image inputs for convenience.
-    this.onDeviceParams.expectedInputs ??= [{ type: 'image' }];
-  }
-
   /**
    * Formats string returned by Chrome as a {@link Response} returned by Vertex.
    */
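
Taken on its own, the inference step added above amounts to the self-contained sketch below. The type shapes are mirrored from language-model.ts and the request slice from the Vertex content types; expectedInputsFromContents is a hypothetical helper name used only for illustration, and it deliberately dedupes on the top-level mime type string, so repeated image parts still yield a single entry.

// Shapes mirrored from packages/vertexai/src/types/language-model.ts.
type LanguageModelMessageType = 'text' | 'image' | 'audio';
interface LanguageModelExpectedInput {
  type: LanguageModelMessageType;
  languages?: string[];
}

// Minimal stand-in for the relevant slice of GenerateContentRequest.
interface Part {
  text?: string;
  inlineData?: { mimeType: string; data: string };
}
interface Content {
  role: string;
  parts: Part[];
}

// Hypothetical helper mirroring the logic of ChromeAdapter.extractExpectedInputs.
function expectedInputsFromContents(
  contents: Content[]
): LanguageModelExpectedInput[] {
  const types = new Set<LanguageModelMessageType>();
  for (const content of contents) {
    for (const part of content.parts) {
      if (part.inlineData) {
        // 'image/png' -> 'image', 'audio/wav' -> 'audio', etc.
        types.add(
          part.inlineData.mimeType.split('/')[0] as LanguageModelMessageType
        );
      }
    }
  }
  return Array.from(types).map(type => ({ type }));
}

// Example: a prompt with text plus an inline PNG yields [{ type: 'image' }].
console.log(
  expectedInputsFromContents([
    {
      role: 'user',
      parts: [
        { text: 'hi' },
        { inlineData: { mimeType: 'image/png', data: '...' } }
      ]
    }
  ])
);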

packages/vertexai/src/types/language-model.ts (4 changes: 2 additions & 2 deletions)
@@ -52,7 +52,7 @@ export interface LanguageModelCreateOptions
 interface LanguageModelPromptOptions {
   signal?: AbortSignal;
 }
-interface LanguageModelExpectedInput {
+export interface LanguageModelExpectedInput {
   type: LanguageModelMessageType;
   languages?: string[];
 }
@@ -74,7 +74,7 @@ export interface LanguageModelMessageContent {
   content: LanguageModelMessageContentValue;
 }
 type LanguageModelMessageRole = 'system' | 'user' | 'assistant';
-type LanguageModelMessageType = 'text' | 'image' | 'audio';
+export type LanguageModelMessageType = 'text' | 'image' | 'audio';
 type LanguageModelMessageContentValue =
   | ImageBitmapSource
   | AudioBuffer
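
These two exports exist so that chrome-adapter.ts can import the types directly, as shown in its diff above. A minimal consumer looks roughly like this; the relative import path assumes the same package layout as the adapter file:

import {
  LanguageModelExpectedInput,
  LanguageModelMessageType
} from '../types/language-model';

// Build an expected-input entry for image prompts, mirroring what the adapter infers.
const imageType: LanguageModelMessageType = 'image';
const expectedImageInput: LanguageModelExpectedInput = { type: imageType };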