@@ -912,6 +912,23 @@ class BertForSequenceClassification extends BertPreTrainedModel {
912
912
}
913
913
}
914
914
915
+ /**
916
+ * BertForTokenClassification is a class representing a BERT model for token classification.
917
+ * @extends BertPreTrainedModel
918
+ */
919
+ class BertForTokenClassification extends BertPreTrainedModel {
920
+ /**
921
+ * Calls the model on new inputs.
922
+ *
923
+ * @param {Object } model_inputs - The inputs to the model.
924
+ * @returns {Promise<TokenClassifierOutput> } - An object containing the model's output logits for token classification.
925
+ */
926
+ async _call ( model_inputs ) {
927
+ let logits = ( await super . _call ( model_inputs ) ) . logits ;
928
+ return new TokenClassifierOutput ( logits )
929
+ }
930
+ }
931
+
915
932
/**
916
933
* BertForQuestionAnswering is a class representing a BERT model for question answering.
917
934
* @extends BertPreTrainedModel
@@ -944,14 +961,32 @@ class DistilBertForSequenceClassification extends DistilBertPreTrainedModel {
944
961
* Calls the model on new inputs.
945
962
*
946
963
* @param {Object } model_inputs - The inputs to the model.
947
- * @returns {Promise<SequenceClassifierOutput> } - An object containing the model's output logits for question answering .
964
+ * @returns {Promise<SequenceClassifierOutput> } - An object containing the model's output logits for sequence classification .
948
965
*/
949
966
async _call ( model_inputs ) {
950
967
let logits = ( await super . _call ( model_inputs ) ) . logits ;
951
968
return new SequenceClassifierOutput ( logits )
952
969
}
953
970
}
954
971
972
+ /**
973
+ * DistilBertForTokenClassification is a class representing a DistilBERT model for token classification.
974
+ * @extends DistilBertPreTrainedModel
975
+ */
976
+ class DistilBertForTokenClassification extends DistilBertPreTrainedModel {
977
+ /**
978
+ * Calls the model on new inputs.
979
+ *
980
+ * @param {Object } model_inputs - The inputs to the model.
981
+ * @returns {Promise<TokenClassifierOutput> } - An object containing the model's output logits for token classification.
982
+ */
983
+ async _call ( model_inputs ) {
984
+ let logits = ( await super . _call ( model_inputs ) ) . logits ;
985
+ return new TokenClassifierOutput ( logits )
986
+ }
987
+ }
988
+
989
+
955
990
/**
956
991
* DistilBertForQuestionAnswering is a class representing a DistilBERT model for question answering.
957
992
* @extends DistilBertPreTrainedModel
@@ -1197,10 +1232,9 @@ class T5ForConditionalGeneration extends T5PreTrainedModel {
1197
1232
}
1198
1233
1199
1234
/**
1200
- * Runs the beam search for a given beam.
1201
- * @async
1202
- * @param {any } beam - The current beam.
1203
- * @returns {Promise<any> } The model output.
1235
+ * Runs a single step of the beam search generation algorithm.
1236
+ * @param {any } beam - The current beam being generated.
1237
+ * @returns {Promise<any> } - The updated beam after a single generation step.
1204
1238
*/
1205
1239
async runBeam ( beam ) {
1206
1240
return await seq2seqRunBeam ( this , beam ) ;
@@ -1298,11 +1332,10 @@ class MT5ForConditionalGeneration extends MT5PreTrainedModel {
1298
1332
}
1299
1333
1300
1334
/**
1301
- * Runs the given beam through
1302
- * the model and returns the next token prediction.
1303
- * @param {any } beam - The beam to run.
1304
- * @returns {Promise<number> } - A Promise that resolves to the index of the predicted token.
1305
- */
1335
+ * Runs a single step of the beam search generation algorithm.
1336
+ * @param {any } beam - The current beam being generated.
1337
+ * @returns {Promise<any> } - The updated beam after a single generation step.
1338
+ */
1306
1339
async runBeam ( beam ) {
1307
1340
return await seq2seqRunBeam ( this , beam ) ;
1308
1341
}
@@ -1602,9 +1635,9 @@ class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
1602
1635
}
1603
1636
1604
1637
/**
1605
- * Runs a beam for generating outputs .
1606
- * @param {Object } beam - Beam object .
1607
- * @returns {Promise<Object > } Promise object represents the generated outputs for the beam .
1638
+ * Runs a single step of the beam search generation algorithm .
1639
+ * @param {any } beam - The current beam being generated .
1640
+ * @returns {Promise<any > } - The updated beam after a single generation step .
1608
1641
*/
1609
1642
async runBeam ( beam ) {
1610
1643
return await seq2seqRunBeam ( this , beam , {
@@ -1693,10 +1726,9 @@ class VisionEncoderDecoderModel extends PreTrainedModel {
1693
1726
}
1694
1727
1695
1728
/**
1696
- * Generate the next beam step for the given beam.
1697
- *
1698
- * @param {any } beam - The current beam.
1699
- * @returns {Promise<any> } The updated beam with the additional predicted token ID.
1729
+ * Runs a single step of the beam search generation algorithm.
1730
+ * @param {any } beam - The current beam being generated.
1731
+ * @returns {Promise<any> } - The updated beam after a single generation step.
1700
1732
*/
1701
1733
async runBeam ( beam ) {
1702
1734
return seq2seqRunBeam ( this , beam , {
@@ -1792,9 +1824,9 @@ class GPT2LMHeadModel extends GPT2PreTrainedModel {
1792
1824
}
1793
1825
1794
1826
/**
1795
- * Runs beam search for text generation given a beam .
1796
- * @param {any } beam - The Beam object representing the beam .
1797
- * @returns {Promise<any> } A Beam object representing the updated beam after running beam search .
1827
+ * Runs a single step of the beam search generation algorithm .
1828
+ * @param {any } beam - The current beam being generated .
1829
+ * @returns {Promise<any> } - The updated beam after a single generation step .
1798
1830
*/
1799
1831
async runBeam ( beam ) {
1800
1832
return await textgenRunBeam ( this , beam ) ;
@@ -1867,9 +1899,9 @@ class GPTNeoForCausalLM extends GPTNeoPreTrainedModel {
1867
1899
}
1868
1900
1869
1901
/**
1870
- * Runs beam search for text generation given a beam .
1871
- * @param {any } beam - The Beam object representing the beam .
1872
- * @returns {Promise<any> } A Beam object representing the updated beam after running beam search .
1902
+ * Runs a single step of the beam search generation algorithm .
1903
+ * @param {any } beam - The current beam being generated .
1904
+ * @returns {Promise<any> } - The updated beam after a single generation step .
1873
1905
*/
1874
1906
async runBeam ( beam ) {
1875
1907
return await textgenRunBeam ( this , beam ) ;
@@ -1952,9 +1984,9 @@ class CodeGenForCausalLM extends CodeGenPreTrainedModel {
1952
1984
}
1953
1985
1954
1986
/**
1955
- * Runs beam search for text generation given a beam .
1956
- * @param {any } beam - The Beam object representing the beam .
1957
- * @returns {Promise<any> } A Beam object representing the updated beam after running beam search .
1987
+ * Runs a single step of the beam search generation algorithm .
1988
+ * @param {any } beam - The current beam being generated .
1989
+ * @returns {Promise<any> } - The updated beam after a single generation step .
1958
1990
*/
1959
1991
async runBeam ( beam ) {
1960
1992
return await textgenRunBeam ( this , beam ) ;
@@ -2080,8 +2112,9 @@ class MarianMTModel extends MarianPreTrainedModel {
2080
2112
}
2081
2113
2082
2114
/**
2083
- * @param {any } beam
2084
- * @returns {Promise<any> }
2115
+ * Runs a single step of the beam search generation algorithm.
2116
+ * @param {any } beam - The current beam being generated.
2117
+ * @returns {Promise<any> } - The updated beam after a single generation step.
2085
2118
*/
2086
2119
async runBeam ( beam ) {
2087
2120
return await seq2seqRunBeam ( this , beam ) ;
@@ -2203,6 +2236,46 @@ class AutoModelForSequenceClassification {
2203
2236
}
2204
2237
}
2205
2238
2239
+
2240
+ /**
2241
+ * Helper class for loading token classification models from pretrained checkpoints
2242
+ */
2243
+ class AutoModelForTokenClassification {
2244
+
2245
+ static MODEL_CLASS_MAPPING = {
2246
+ 'bert' : BertForTokenClassification ,
2247
+ 'distilbert' : DistilBertForTokenClassification ,
2248
+ }
2249
+
2250
+ /**
2251
+ * Load a token classification model from a pretrained checkpoint
2252
+ * @param {string } modelPath - The path to the model checkpoint directory
2253
+ * @param {function } [progressCallback=null] - An optional callback function to receive progress updates
2254
+ * @returns {Promise<PreTrainedModel> } A promise that resolves to a pre-trained token classification model
2255
+ * @throws {Error } if an unsupported model type is encountered
2256
+ */
2257
+ static async from_pretrained ( modelPath , progressCallback = null ) {
2258
+
2259
+ let [ config , session ] = await Promise . all ( [
2260
+ fetchJSON ( modelPath , 'config.json' , progressCallback ) ,
2261
+ constructSession ( modelPath , 'model.onnx' , progressCallback )
2262
+ ] ) ;
2263
+
2264
+ // Called when all parts are loaded
2265
+ dispatchCallback ( progressCallback , {
2266
+ status : 'loaded' ,
2267
+ name : modelPath
2268
+ } ) ;
2269
+
2270
+ let cls = this . MODEL_CLASS_MAPPING [ config . model_type ] ;
2271
+ if ( ! cls ) {
2272
+ throw Error ( `Unsupported model type: ${ config . model_type } ` )
2273
+ }
2274
+ return new cls ( config , session ) ;
2275
+ }
2276
+ }
2277
+
2278
+
2206
2279
/**
2207
2280
* Class representing an automatic sequence-to-sequence language model.
2208
2281
*/
@@ -2248,7 +2321,7 @@ class AutoModelForCausalLM {
2248
2321
* Loads a pre-trained model from the given path and returns an instance of the appropriate class.
2249
2322
* @param {string } modelPath - The path to the pre-trained model.
2250
2323
* @param {function } [progressCallback=null] - An optional callback function to track the progress of the loading process.
2251
- * @returns {Promise<GPT2LMHeadModel|CodeGenForCausalLM> } An instance of the appropriate class for the loaded model.
2324
+ * @returns {Promise<GPT2LMHeadModel|GPTNeoForCausalLM|CodeGenForCausalLM > } An instance of the appropriate class for the loaded model.
2252
2325
* @throws {Error } If the loaded model type is not supported.
2253
2326
*/
2254
2327
static async from_pretrained ( modelPath , progressCallback = null ) {
@@ -2369,7 +2442,7 @@ class AutoModelForVision2Seq {
2369
2442
* Loads a pretrained model from a given path.
2370
2443
* @param {string } modelPath - The path to the pretrained model.
2371
2444
* @param {function } [progressCallback=null] - Optional callback function to track progress of the model loading.
2372
- * @returns {Promise<VisionEncoderDecoderModel > } - A Promise that resolves to a new instance of VisionEncoderDecoderModel.
2445
+ * @returns {Promise<PreTrainedModel > } - A Promise that resolves to a new instance of VisionEncoderDecoderModel.
2373
2446
*/
2374
2447
static async from_pretrained ( modelPath , progressCallback = null ) {
2375
2448
@@ -2406,7 +2479,7 @@ class AutoModelForImageClassification {
2406
2479
* Loads a pre-trained image classification model from a given directory path.
2407
2480
* @param {string } modelPath - The path to the directory containing the pre-trained model.
2408
2481
* @param {function } [progressCallback=null] - A callback function to monitor the loading progress.
2409
- * @returns {Promise<ViTForImageClassification > } A Promise that resolves with an instance of the ViTForImageClassification class.
2482
+ * @returns {Promise<PreTrainedModel > } A Promise that resolves with an instance of the ViTForImageClassification class.
2410
2483
* @throws {Error } If the specified model type is not supported.
2411
2484
*/
2412
2485
static async from_pretrained ( modelPath , progressCallback = null ) {
@@ -2441,7 +2514,7 @@ class AutoModelForObjectDetection {
2441
2514
* Loads a pre-trained object detection model from a given directory path.
2442
2515
* @param {string } modelPath - The path to the directory containing the pre-trained model.
2443
2516
* @param {function } [progressCallback=null] - A callback function to monitor the loading progress.
2444
- * @returns {Promise<any > } A Promise that resolves with an instance of the ViTForImageClassification class.
2517
+ * @returns {Promise<PreTrainedModel > } A Promise that resolves with an instance of the pre-trained object detection model.
2445
2518
* @throws {Error } If the specified model type is not supported.
2446
2519
*/
2447
2520
static async from_pretrained ( modelPath , progressCallback = null ) {
@@ -2491,6 +2564,17 @@ class SequenceClassifierOutput extends ModelOutput {
2491
2564
}
2492
2565
}
2493
2566
2567
+ class TokenClassifierOutput extends ModelOutput {
2568
+ /**
2569
+ * @param {Tensor } logits - The model's output logits for token classification.
2570
+ */
2571
+ constructor ( logits ) {
2572
+ super ( ) ;
2573
+ this . logits = logits ;
2574
+ }
2575
+ }
2576
+
2577
+
2494
2578
class MaskedLMOutput extends ModelOutput {
2495
2579
/**
2496
2580
* @param {Tensor } logits
@@ -2517,6 +2601,7 @@ module.exports = {
2517
2601
AutoModel,
2518
2602
AutoModelForSeq2SeqLM,
2519
2603
AutoModelForSequenceClassification,
2604
+ AutoModelForTokenClassification,
2520
2605
AutoModelForCausalLM,
2521
2606
AutoModelForMaskedLM,
2522
2607
AutoModelForQuestionAnswering,
0 commit comments