diff --git a/src/models.js b/src/models.js index 57de2c689f1fd5..0ed9b7b697e2fa 100644 --- a/src/models.js +++ b/src/models.js @@ -4,6 +4,7 @@ const { fetchJSON, dispatchCallback, isIntegralNumber, + exists, } = require("./utils.js"); const { @@ -80,9 +81,25 @@ function boolTensor(value) { } // JS doesn't support mixings, so we define some reused functions here, and allow "this" to be passed in + +async function seq2seqLoadModel(modelPath, progressCallback) { + let info = await Promise.all([ + fetchJSON(modelPath, 'config.json', progressCallback), + constructSession(modelPath, 'encoder_model.onnx', progressCallback), + constructSession(modelPath, 'decoder_model_merged.onnx', progressCallback), + fetchJSON(modelPath, 'generation_config.json', progressCallback), + ]) + + // Called when all parts are loaded + dispatchCallback(progressCallback, { + status: 'loaded', + name: modelPath + }); + + return info; +} async function seq2seq_forward(self, model_inputs, { encoder_input_name = 'input_ids', - encoder_attention_mask = true, add_decoder_pkv = true } = {}) { let encoderOutputs = model_inputs.encoder_outputs; @@ -92,7 +109,8 @@ async function seq2seq_forward(self, model_inputs, { const encoderFeeds = { [encoder_input_name]: model_inputs[encoder_input_name], } - if (encoder_attention_mask) { + + if (self.session.inputNames.includes('attention_mask')) { encoderFeeds.attention_mask = model_inputs.attention_mask } const encoderResults = await sessionRun(self.session, encoderFeeds); @@ -103,7 +121,8 @@ async function seq2seq_forward(self, model_inputs, { encoder_hidden_states: encoderOutputs, use_cache_branch: boolTensor(pastKeyValues !== null) }; - if (encoder_attention_mask) { + + if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) { decoderFeeds.encoder_attention_mask = model_inputs.attention_mask } self.addPastKeyValues(decoderFeeds, pastKeyValues, add_decoder_pkv); @@ -255,282 +274,6 @@ function textgenUpdatebeam(beam, newTokenId) { } ////////////////////////////////////////////////// -////////////////////////////////////////////////// -// AutoModels, used to simplify construction of PreTrainedModels -// (uses config to instantiate correct class) -class AutoModel { - // Helper class to determine model type from config - - static async from_pretrained(modelPath, progressCallback = null) { - - let config = await fetchJSON(modelPath, 'config.json', progressCallback); - let modelName = config.is_encoder_decoder ? 'encoder_model.onnx' : 'model.onnx'; - - let session = await constructSession(modelPath, modelName, progressCallback); - - // Called when all parts are loaded - dispatchCallback(progressCallback, { - status: 'loaded', - name: modelPath - }); - - switch (config.model_type) { - case 'bert': - return new BertModel(config, session); - case 'albert': - return new AlbertModel(config, session); - case 'distilbert': - return new DistilBertModel(config, session); - case 't5': - return new T5Model(config, session); - case 'gpt2': - return new GPT2Model(config, session); - case 'codegen': - return new CodeGenModel(config, session); - case 'bart': - return new BartModel(config, session); - case 'roberta': - return new RobertaModel(config, session); - case 'whisper': - return new WhisperModel(config, session); - case 'clip': - return new CLIPModel(config, session); - - default: - console.warn(`Unknown model class "${config.model_type}", attempting to construct from base class.`); - return new PreTrainedModel(config, session); - } - } -} - -class AutoModelForSequenceClassification { - - static async from_pretrained(modelPath, progressCallback = null) { - - let [config, session] = await Promise.all([ - fetchJSON(modelPath, 'config.json', progressCallback), - constructSession(modelPath, 'model.onnx', progressCallback) - ]); - - // Called when all parts are loaded - dispatchCallback(progressCallback, { - status: 'loaded', - name: modelPath - }); - - switch (config.model_type) { - case 'bert': - return new BertForSequenceClassification(config, session); - case 'albert': - return new AlbertForSequenceClassification(config, session); - case 'distilbert': - return new DistilBertForSequenceClassification(config, session); - case 'roberta': - return new RobertaForSequenceClassification(config, session); - - default: - throw Error(`Unsupported model type: ${config.model_type}`) - } - } -} - -class AutoModelForSeq2SeqLM { - static async from_pretrained(modelPath, progressCallback = null) { - - let [config, session, decoder_merged_session] = await Promise.all([ - fetchJSON(modelPath, 'config.json', progressCallback), - constructSession(modelPath, 'encoder_model.onnx', progressCallback), - constructSession(modelPath, 'decoder_model_merged.onnx', progressCallback) - ]) - - // Called when all parts are loaded - dispatchCallback(progressCallback, { - status: 'loaded', - name: modelPath - }); - - switch (config.model_type) { - case 't5': - return new T5ForConditionalGeneration( - config, - session, - decoder_merged_session - ); - case 'bart': - return new BartForConditionalGeneration( - config, - session, - decoder_merged_session - ); - case 'whisper': - return new WhisperForConditionalGeneration( - config, - session, - decoder_merged_session - ) - default: - throw Error(`Unsupported model type: ${config.model_type}`) - } - } -} - -class AutoModelForCausalLM { - static async from_pretrained(modelPath, progressCallback = null) { - - let [config, session] = await Promise.all([ - fetchJSON(modelPath, 'config.json', progressCallback), - constructSession(modelPath, 'decoder_model_merged.onnx', progressCallback) - ]) - - // Called when all parts are loaded - dispatchCallback(progressCallback, { - status: 'loaded', - name: modelPath - }); - - switch (config.model_type) { - case 'gpt2': - return new GPT2LMHeadModel( - config, - session - ); - - case 'codegen': - return new CodeGenForCausalLM( - config, - session - ) - - default: - throw Error(`Unsupported model type: ${config.model_type}`) - } - } -} - -class AutoModelForMaskedLM { - - static async from_pretrained(modelPath, progressCallback = null) { - - let config = await fetchJSON(modelPath, 'config.json', progressCallback); - let modelName = config.is_encoder_decoder ? 'encoder_model.onnx' : 'model.onnx'; - - let session = await constructSession(modelPath, modelName, progressCallback); - - // Called when all parts are loaded - dispatchCallback(progressCallback, { - status: 'loaded', - name: modelPath - }); - - switch (config.model_type) { - case 'bert': - return new BertForMaskedLM(config, session); - case 'albert': - return new AlbertForMaskedLM(config, session); - case 'distilbert': - return new DistilBertForMaskedLM(config, session); - case 'roberta': - return new RobertaForMaskedLM(config, session); - - default: - console.warn(`Unknown model class "${config.model_type}", attempting to construct from base class.`); - return new PreTrainedModel(config, session); - } - } -} - - -class AutoModelForQuestionAnswering { - - static async from_pretrained(modelPath, progressCallback = null) { - - let [config, session] = await Promise.all([ - fetchJSON(modelPath, 'config.json', progressCallback), - constructSession(modelPath, 'model.onnx', progressCallback) - ]); - - // Called when all parts are loaded - dispatchCallback(progressCallback, { - status: 'loaded', - name: modelPath - }); - - switch (config.model_type) { - case 'bert': - return new BertForQuestionAnswering(config, session); - case 'albert': - return new AlbertForQuestionAnswering(config, session); - case 'distilbert': - return new DistilBertForQuestionAnswering(config, session); - case 'roberta': - return new RobertaForQuestionAnswering(config, session); - - default: - throw Error(`Unsupported model type: ${config.model_type}`) - } - } -} - -class AutoModelForVision2Seq { - static async from_pretrained(modelPath, progressCallback = null) { - - let [config, session, decoder_merged_session] = await Promise.all([ - fetchJSON(modelPath, 'config.json', progressCallback), - constructSession(modelPath, 'encoder_model.onnx', progressCallback), - constructSession(modelPath, 'decoder_model_merged.onnx', progressCallback) - ]) - - // Called when all parts are loaded - dispatchCallback(progressCallback, { - status: 'loaded', - name: modelPath - }); - - switch (config.model_type) { - case 'vision-encoder-decoder': - return new VisionEncoderDecoderModel( - config, - session, - decoder_merged_session - ); - default: - throw Error(`Unsupported model type: ${config.model_type}`) - } - } -} - -class AutoModelForImageClassification { - static async from_pretrained(modelPath, progressCallback = null) { - - let [config, session] = await Promise.all([ - fetchJSON(modelPath, 'config.json', progressCallback), - constructSession(modelPath, 'model.onnx', progressCallback), - ]) - - // Called when all parts are loaded - dispatchCallback(progressCallback, { - status: 'loaded', - name: modelPath - }); - - switch (config.model_type) { - case 'vit': - return new ViTForImageClassification( - config, - session, - ); - default: - throw Error(`Unsupported model type: ${config.model_type}`) - } - } - -} - -////////////////////////////////////////////////// - - - - ////////////////////////////////////////////////// // Base class class PreTrainedModel extends Callable { @@ -653,20 +396,8 @@ class PreTrainedModel extends Callable { } let output = await this.runBeam(beam); + this.applyLogitsProcessors(output.logits, beam.output_token_ids.length); - // Apply logits processor to each item in the batch: - for (let batch of output.logits) { - // NOTE: In future, generalise this - - let map = this.forced_decoder_ids_mapping[beam.output_token_ids.length]; - if (map !== undefined) { // There exists a mapping - // NOTE: - // - modifications affect original data - // - logits are of the shape [1, vocabSize] - batch.data.fill(-Infinity) - batch.data[map] = 0; - } - } let sampledTokens = sampler(output.logits); for (let [newTokenId, logProb] of sampledTokens) { @@ -768,6 +499,27 @@ class PreTrainedModel extends Callable { } } + applyLogitsProcessors(logits, index) { + // Apply logits processor to each item in the batch: + for (let batch of logits) { + // NOTE: In future, generalise this + // - modifications affect original data + // - logits are of the shape [1, vocabSize] + let map = this.forced_decoder_ids_mapping[index]; + if (exists(map)) { // There exists a mapping + batch.data.fill(-Infinity) + batch.data[map] = 0; + } + + if (exists(this.generation_config) && exists(this.generation_config.forced_bos_token_id) && index === 1) { + batch.data.fill(-Infinity) + batch.data[this.generation_config.forced_bos_token_id] = 0; + } + + } + } + + } ////////////////////////////////////////////////// @@ -856,10 +608,10 @@ class T5Model extends T5PreTrainedModel { } class T5ForConditionalGeneration extends T5PreTrainedModel { - constructor(config, session, decoder_merged_session) { + constructor(config, session, decoder_merged_session, generation_config) { super(config, session); this.decoder_merged_session = decoder_merged_session; - + this.generation_config = generation_config; this.num_decoder_layers = this.config.num_decoder_layers; this.num_decoder_heads = this.config.num_heads; @@ -871,20 +623,8 @@ class T5ForConditionalGeneration extends T5PreTrainedModel { } static async from_pretrained(modelPath, progressCallback = null) { - - let [config, session, decoder_merged_session] = await Promise.all([ - fetchJSON(modelPath, 'config.json', progressCallback), - constructSession(modelPath, 'encoder_model.onnx', progressCallback), - constructSession(modelPath, 'decoder_model_merged.onnx', progressCallback), - ]) - - // Called when all parts are loaded - dispatchCallback(progressCallback, { - status: 'loaded', - name: modelPath - }); - - return new this(config, session, decoder_merged_session); + let info = await seq2seqLoadModel(modelPath, progressCallback); + return new this(...info); } getStartBeams(inputs, numOutputTokens, ...args) { @@ -917,10 +657,10 @@ class BartModel extends BartPretrainedModel { } class BartForConditionalGeneration extends BartPretrainedModel { - constructor(config, session, decoder_merged_session) { + constructor(config, session, decoder_merged_session, generation_config) { super(config, session); this.decoder_merged_session = decoder_merged_session; - + this.generation_config = generation_config; this.num_decoder_layers = this.config.decoder_layers; this.num_decoder_heads = this.config.decoder_attention_heads; @@ -931,22 +671,9 @@ class BartForConditionalGeneration extends BartPretrainedModel { this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } - static async from_pretrained(modelPath, progressCallback = null) { - // TODO remove duplication between here and t5 - - let [config, session, decoder_merged_session] = await Promise.all([ - fetchJSON(modelPath, 'config.json', progressCallback), - constructSession(modelPath, 'encoder_model.onnx', progressCallback), - constructSession(modelPath, 'decoder_merged_session.onnx', progressCallback), - ]) - - // Called when all parts are loaded - dispatchCallback(progressCallback, { - status: 'loaded', - name: modelPath - }); - - return new this(config, session, decoder_merged_session); + static async from_pretrained(modelPath, progressCallback = null) { + let info = await seq2seqLoadModel(modelPath, progressCallback); + return new this(...info); } getStartBeams(inputs, numOutputTokens, ...args) { @@ -1004,9 +731,10 @@ class WhisperModel extends WhisperPreTrainedModel { } class WhisperForConditionalGeneration extends WhisperPreTrainedModel { - constructor(config, session, decoder_merged_session) { + constructor(config, session, decoder_merged_session, generation_config) { super(config, session); this.decoder_merged_session = decoder_merged_session; + this.generation_config = generation_config; this.num_decoder_layers = this.config.decoder_layers; this.num_decoder_heads = this.config.decoder_attention_heads; @@ -1018,20 +746,8 @@ class WhisperForConditionalGeneration extends WhisperPreTrainedModel { } static async from_pretrained(modelPath, progressCallback = null) { - - let [config, session, decoder_merged_session] = await Promise.all([ - fetchJSON(modelPath, 'config.json', progressCallback), - constructSession(modelPath, 'encoder_model.onnx', progressCallback), - constructSession(modelPath, 'decoder_merged_session.onnx', progressCallback), - ]) - - // Called when all parts are loaded - dispatchCallback(progressCallback, { - status: 'loaded', - name: modelPath - }); - - return new this(config, session, decoder_merged_session); + let info = await seq2seqLoadModel(modelPath, progressCallback); + return new this(...info); } getStartBeams(inputTokenIds, numOutputTokens, ...args) { @@ -1051,7 +767,6 @@ class WhisperForConditionalGeneration extends WhisperPreTrainedModel { async forward(model_inputs) { return await seq2seq_forward(this, model_inputs, { encoder_input_name: 'input_features', - encoder_attention_mask: false }); } } @@ -1101,7 +816,6 @@ class VisionEncoderDecoderModel extends PreTrainedModel { async forward(model_inputs) { return await seq2seq_forward(this, model_inputs, { encoder_input_name: 'pixel_values', - encoder_attention_mask: false, add_decoder_pkv: false }) } @@ -1218,6 +932,255 @@ class ViTForImageClassification extends PreTrainedModel { ////////////////////////////////////////////////// +////////////////////////////////////////////////// +// AutoModels, used to simplify construction of PreTrainedModels +// (uses config to instantiate correct class) +class AutoModel { + // Helper class to determine model type from config + + static async from_pretrained(modelPath, progressCallback = null) { + + let config = await fetchJSON(modelPath, 'config.json', progressCallback); + let modelName = config.is_encoder_decoder ? 'encoder_model.onnx' : 'model.onnx'; + + let session = await constructSession(modelPath, modelName, progressCallback); + + // Called when all parts are loaded + dispatchCallback(progressCallback, { + status: 'loaded', + name: modelPath + }); + + switch (config.model_type) { + case 'bert': + return new BertModel(config, session); + case 'albert': + return new AlbertModel(config, session); + case 'distilbert': + return new DistilBertModel(config, session); + case 't5': + return new T5Model(config, session); + case 'gpt2': + return new GPT2Model(config, session); + case 'codegen': + return new CodeGenModel(config, session); + case 'bart': + return new BartModel(config, session); + case 'roberta': + return new RobertaModel(config, session); + case 'whisper': + return new WhisperModel(config, session); + case 'clip': + return new CLIPModel(config, session); + + default: + console.warn(`Unknown model class "${config.model_type}", attempting to construct from base class.`); + return new PreTrainedModel(config, session); + } + } +} + +class AutoModelForSequenceClassification { + + static async from_pretrained(modelPath, progressCallback = null) { + + let [config, session] = await Promise.all([ + fetchJSON(modelPath, 'config.json', progressCallback), + constructSession(modelPath, 'model.onnx', progressCallback) + ]); + + // Called when all parts are loaded + dispatchCallback(progressCallback, { + status: 'loaded', + name: modelPath + }); + + switch (config.model_type) { + case 'bert': + return new BertForSequenceClassification(config, session); + case 'albert': + return new AlbertForSequenceClassification(config, session); + case 'distilbert': + return new DistilBertForSequenceClassification(config, session); + case 'roberta': + return new RobertaForSequenceClassification(config, session); + + default: + throw Error(`Unsupported model type: ${config.model_type}`) + } + } +} + +class AutoModelForSeq2SeqLM { + static modelClassMapping = { + 't5': T5ForConditionalGeneration, + 'bart': BartForConditionalGeneration, + 'whisper': WhisperForConditionalGeneration, + } + static async from_pretrained(modelPath, progressCallback = null) { + let info = await seq2seqLoadModel(modelPath, progressCallback); + let config = info[0]; + let cls = this.modelClassMapping[config.model_type]; + if (!cls) { + throw Error(`Unsupported model type: ${config.model_type}`) + } + return new cls(...info) + } +} + +class AutoModelForCausalLM { + static async from_pretrained(modelPath, progressCallback = null) { + + let [config, session] = await Promise.all([ + fetchJSON(modelPath, 'config.json', progressCallback), + constructSession(modelPath, 'decoder_model_merged.onnx', progressCallback) + ]) + + // Called when all parts are loaded + dispatchCallback(progressCallback, { + status: 'loaded', + name: modelPath + }); + + switch (config.model_type) { + case 'gpt2': + return new GPT2LMHeadModel( + config, + session + ); + + case 'codegen': + return new CodeGenForCausalLM( + config, + session + ) + + default: + throw Error(`Unsupported model type: ${config.model_type}`) + } + } +} + +class AutoModelForMaskedLM { + + static async from_pretrained(modelPath, progressCallback = null) { + + let config = await fetchJSON(modelPath, 'config.json', progressCallback); + let modelName = config.is_encoder_decoder ? 'encoder_model.onnx' : 'model.onnx'; + + let session = await constructSession(modelPath, modelName, progressCallback); + + // Called when all parts are loaded + dispatchCallback(progressCallback, { + status: 'loaded', + name: modelPath + }); + + switch (config.model_type) { + case 'bert': + return new BertForMaskedLM(config, session); + case 'albert': + return new AlbertForMaskedLM(config, session); + case 'distilbert': + return new DistilBertForMaskedLM(config, session); + case 'roberta': + return new RobertaForMaskedLM(config, session); + + default: + console.warn(`Unknown model class "${config.model_type}", attempting to construct from base class.`); + return new PreTrainedModel(config, session); + } + } +} + +class AutoModelForQuestionAnswering { + + static async from_pretrained(modelPath, progressCallback = null) { + + let [config, session] = await Promise.all([ + fetchJSON(modelPath, 'config.json', progressCallback), + constructSession(modelPath, 'model.onnx', progressCallback) + ]); + + // Called when all parts are loaded + dispatchCallback(progressCallback, { + status: 'loaded', + name: modelPath + }); + + switch (config.model_type) { + case 'bert': + return new BertForQuestionAnswering(config, session); + case 'albert': + return new AlbertForQuestionAnswering(config, session); + case 'distilbert': + return new DistilBertForQuestionAnswering(config, session); + case 'roberta': + return new RobertaForQuestionAnswering(config, session); + + default: + throw Error(`Unsupported model type: ${config.model_type}`) + } + } +} + +class AutoModelForVision2Seq { + static async from_pretrained(modelPath, progressCallback = null) { + + let [config, session, decoder_merged_session] = await Promise.all([ + fetchJSON(modelPath, 'config.json', progressCallback), + constructSession(modelPath, 'encoder_model.onnx', progressCallback), + constructSession(modelPath, 'decoder_model_merged.onnx', progressCallback) + ]) + + // Called when all parts are loaded + dispatchCallback(progressCallback, { + status: 'loaded', + name: modelPath + }); + + switch (config.model_type) { + case 'vision-encoder-decoder': + return new VisionEncoderDecoderModel( + config, + session, + decoder_merged_session + ); + default: + throw Error(`Unsupported model type: ${config.model_type}`) + } + } +} + +class AutoModelForImageClassification { + static async from_pretrained(modelPath, progressCallback = null) { + + let [config, session] = await Promise.all([ + fetchJSON(modelPath, 'config.json', progressCallback), + constructSession(modelPath, 'model.onnx', progressCallback), + ]) + + // Called when all parts are loaded + dispatchCallback(progressCallback, { + status: 'loaded', + name: modelPath + }); + + switch (config.model_type) { + case 'vit': + return new ViTForImageClassification( + config, + session, + ); + default: + throw Error(`Unsupported model type: ${config.model_type}`) + } + } + +} +////////////////////////////////////////////////// + +////////////////////////////////////////////////// class Seq2SeqLMOutput { constructor(logits, past_key_values, encoder_outputs) { this.logits = logits;