Skip to content

AI Hybrid Inference: extract expected inputs from prompt #8989

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: firebase-ai-hybridinference
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/ai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"@firebase/component": "0.6.14",
"@firebase/logger": "0.4.4",
"@firebase/util": "1.11.1",
"deepmerge": "4.3.1",
"tslib": "^2.1.0"
},
"license": "Apache-2.0",
Expand Down
78 changes: 41 additions & 37 deletions packages/ai/src/methods/chrome-adapter.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,30 +54,6 @@ async function toStringArray(

describe('ChromeAdapter', () => {
describe('constructor', () => {
it('sets image as expected input type by default', async () => {
const languageModelProvider = {
availability: () => Promise.resolve(Availability.available)
} as LanguageModel;
const availabilityStub = stub(
languageModelProvider,
'availability'
).resolves(Availability.available);
const adapter = new ChromeAdapter(
languageModelProvider,
'prefer_on_device'
);
await adapter.isAvailable({
contents: [
{
role: 'user',
parts: [{ text: 'hi' }]
}
]
});
expect(availabilityStub).to.have.been.calledWith({
expectedInputs: [{ type: 'image' }]
});
});
it('honors explicitly set expected inputs', async () => {
const languageModelProvider = {
availability: () => Promise.resolve(Availability.available)
Expand Down Expand Up @@ -299,6 +275,39 @@ describe('ChromeAdapter', () => {
})
).to.be.false;
});
it('extracts and merges expected inputs from the request', async () => {
const languageModelProvider = {
availability: () => Promise.resolve(Availability.available)
} as LanguageModel;
const availabilityStub = stub(
languageModelProvider,
'availability'
).resolves(Availability.available);
const adapter = new ChromeAdapter(
languageModelProvider,
'prefer_on_device',
{
createOptions: {
expectedInputs: [{ type: 'text' }]
}
}
);
await adapter.isAvailable({
contents: [
{
role: 'user',
parts: [
{ text: 'hi' },
// Triggers image as expected type.
{ inlineData: { mimeType: 'image/jpeg', data: 'asd' } }
]
}
]
});
expect(availabilityStub).to.have.been.calledWith({
expectedInputs: [{ type: 'text' }, { type: 'image' }]
});
});
});
describe('generateContent', () => {
it('throws if Chrome API is undefined', async () => {
Expand Down Expand Up @@ -378,14 +387,9 @@ describe('ChromeAdapter', () => {
);
const promptOutput = 'hi';
const promptStub = stub(languageModel, 'prompt').resolves(promptOutput);
const createOptions = {
systemPrompt: 'be yourself',
expectedInputs: [{ type: 'image' }]
} as LanguageModelCreateOptions;
const adapter = new ChromeAdapter(
languageModelProvider,
'prefer_on_device',
{ createOptions }
'prefer_on_device'
);
const request = {
contents: [
Expand All @@ -405,7 +409,9 @@ describe('ChromeAdapter', () => {
} as GenerateContentRequest;
const response = await adapter.generateContent(request);
// Asserts initialization params are proxied.
expect(createStub).to.have.been.calledOnceWith(createOptions);
expect(createStub).to.have.been.calledOnceWith({
expectedInputs: [{ type: 'image' }]
});
// Asserts Vertex input type is mapped to Chrome type.
expect(promptStub).to.have.been.calledOnceWith([
{
Expand Down Expand Up @@ -606,13 +612,9 @@ describe('ChromeAdapter', () => {
}
})
);
const createOptions = {
expectedInputs: [{ type: 'image' }]
} as LanguageModelCreateOptions;
const adapter = new ChromeAdapter(
languageModelProvider,
'prefer_on_device',
{ createOptions }
'prefer_on_device'
);
const request = {
contents: [
Expand All @@ -631,7 +633,9 @@ describe('ChromeAdapter', () => {
]
} as GenerateContentRequest;
const response = await adapter.generateContentStream(request);
expect(createStub).to.have.been.calledOnceWith(createOptions);
expect(createStub).to.have.been.calledOnceWith({
expectedInputs: [{ type: 'image' }]
});
expect(promptStub).to.have.been.calledOnceWith([
{
role: request.contents[0].role,
Expand Down
92 changes: 72 additions & 20 deletions packages/ai/src/methods/chrome-adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@ import {
import {
Availability,
LanguageModel,
LanguageModelCreateOptions,
LanguageModelExpected,
LanguageModelMessage,
LanguageModelMessageContent,
LanguageModelMessageRole
LanguageModelMessageRole,
LanguageModelMessageType
} from '../types/language-model';
import deepMerge from 'deepmerge';

/**
* Defines an inference "backend" that uses Chrome's on-device model,
Expand All @@ -48,12 +52,7 @@ export class ChromeAdapter {
constructor(
private languageModelProvider?: LanguageModel,
private mode?: InferenceMode,
private onDeviceParams: OnDeviceParams = {
createOptions: {
// Defaults to support image inputs for convenience.
expectedInputs: [{ type: 'image' }]
}
}
private onDeviceParams: OnDeviceParams = {}
) {}

/**
Expand Down Expand Up @@ -85,8 +84,11 @@ export class ChromeAdapter {
return false;
}

const extractedOptions = this.extractCreateOptions(request);
const mergedOptions = this.mergeCreateOptions(extractedOptions);

// Triggers out-of-band download so model will eventually become available.
const availability = await this.downloadIfAvailable();
const availability = await this.downloadIfAvailable(mergedOptions);

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I now see the same error in Edge that we used to have in Chrome when we were migrated to the new type: the LanguageModel.prompt method exists, but throws an unsupported input error with the new rich type. Edge Canary works. I'll take a look at detecting the Edge version in isAvailable.

if (this.mode === 'only_on_device') {
return true;
Expand Down Expand Up @@ -118,7 +120,9 @@ export class ChromeAdapter {
* @returns {@link Response}, so we can reuse common response formatting.
*/
async generateContent(request: GenerateContentRequest): Promise<Response> {
const session = await this.createSession();
const extractedOptions = this.extractCreateOptions(request);
const mergedOptions = this.mergeCreateOptions(extractedOptions);
const session = await this.createSession(mergedOptions);
const contents = await Promise.all(
request.contents.map(ChromeAdapter.toLanguageModelMessage)
);
Expand All @@ -140,7 +144,9 @@ export class ChromeAdapter {
async generateContentStream(
request: GenerateContentRequest
): Promise<Response> {
const session = await this.createSession();
const extractedOptions = this.extractCreateOptions(request);
const mergedOptions = this.mergeCreateOptions(extractedOptions);
const session = await this.createSession(mergedOptions);
const contents = await Promise.all(
request.contents.map(ChromeAdapter.toLanguageModelMessage)
);
Expand All @@ -158,6 +164,48 @@ export class ChromeAdapter {
);
}

/**
* Extracts session creation options specified at request-time.
*
* <p>In particular, this method maps
* <a href="https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#blob">
* Vertex's input mime types</a> to
* <a href="https://github.com/webmachinelearning/prompt-api?tab=readme-ov-file#full-api-surface-in-web-idl">
* Chrome's expected input types</a>.</p>
*
* <p>Chrome's API checks availability by type. It's tedious to specify the types in advance, so
* this method infers the types.</p>
*/
private extractCreateOptions(
request: GenerateContentRequest
): LanguageModelCreateOptions {
const inputSet = new Set<LanguageModelExpected>();
for (const content of request.contents) {
for (const part of content.parts) {
if (part.inlineData) {
const type = part.inlineData.mimeType.split(
'/'
)[0] as LanguageModelMessageType;
inputSet.add({ type });
}
}
}

return {
expectedInputs: Array.from(inputSet)
};
}

/**
* Assembles a unified {@link LanguageModelCreateOptions} from create- and request-time options.
* Request-time options take priority over create-time options.
*/
private mergeCreateOptions(
requestOptions: LanguageModelCreateOptions
): LanguageModelCreateOptions {
return deepMerge(this.onDeviceParams.createOptions || {}, requestOptions);
}

/**
* Asserts inference for the given request can be performed by an on-device model.
*/
Expand Down Expand Up @@ -196,13 +244,17 @@ export class ChromeAdapter {
/**
* Encapsulates logic to get availability and download a model if one is downloadable.
*/
private async downloadIfAvailable(): Promise<Availability | undefined> {
private async downloadIfAvailable(
createOptions: LanguageModelCreateOptions
): Promise<Availability | undefined> {
const availability = await this.languageModelProvider?.availability(
this.onDeviceParams.createOptions
createOptions
);

if (availability === Availability.downloadable) {
this.download();
// Side-effect: triggers out-of-band model download.
// This is required because Chrome manages the model download.
this.download(createOptions);
}

return availability;
Expand All @@ -212,18 +264,18 @@ export class ChromeAdapter {
* Triggers out-of-band download of an on-device model.
*
* <p>Chrome only downloads models as needed. Chrome knows a model is needed when code calls
* LanguageModel.create.</p>
* {@link LanguageModel.create}.</p>
*
* <p>Since Chrome manages the download, the SDK can only avoid redundant download requests by
* tracking if a download has previously been requested.</p>
*/
private download(): void {
private download(createOptions: LanguageModelCreateOptions): void {
if (this.isDownloading) {
return;
}
this.isDownloading = true;
this.downloadPromise = this.languageModelProvider
?.create(this.onDeviceParams.createOptions)
?.create(createOptions)
.then(() => {
this.isDownloading = false;
});
Expand Down Expand Up @@ -291,16 +343,16 @@ export class ChromeAdapter {
* <p>Chrome will remove a model from memory if it's no longer in use, so this method ensures a
* new session is created before an old session is destroyed.</p>
*/
private async createSession(): Promise<LanguageModel> {
private async createSession(
createOptions: LanguageModelCreateOptions
): Promise<LanguageModel> {
if (!this.languageModelProvider) {
throw new AIError(
AIErrorCode.REQUEST_ERROR,
'Chrome AI requested for unsupported browser version.'
);
}
const newSession = await this.languageModelProvider.create(
this.onDeviceParams.createOptions
);
const newSession = await this.languageModelProvider.create(createOptions);
if (this.oldSession) {
this.oldSession.destroy();
}
Expand Down
2 changes: 1 addition & 1 deletion yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -6250,7 +6250,7 @@ deep-is@^0.1.3:
resolved "https://registry.npmjs.org/deep-is/-/deep-is-0.1.4.tgz#a6f2dce612fadd2ef1f519b73551f17e85199831"
integrity sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==

deepmerge@^4.2.2:
deepmerge@4.3.1, deepmerge@^4.2.2:
version "4.3.1"
resolved "https://registry.npmjs.org/deepmerge/-/deepmerge-4.3.1.tgz#44b5f2147cd3b00d4b56137685966f26fd25dd4a"
integrity sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==
Expand Down
Loading