Skip to content

Commit 7b37cdc

Browse files
committed
Add support for token classification (e.g., named entity recognition)
1 parent 7298c1d commit 7b37cdc

File tree

3 files changed

+193
-33
lines changed

3 files changed

+193
-33
lines changed

src/models.js

Lines changed: 117 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,23 @@ class BertForSequenceClassification extends BertPreTrainedModel {
912912
}
913913
}
914914

915+
/**
 * BERT model with a token-classification head (e.g., for named entity recognition).
 * @extends BertPreTrainedModel
 */
class BertForTokenClassification extends BertPreTrainedModel {
    /**
     * Runs the underlying model and wraps its logits in a token-classifier output.
     *
     * @param {Object} model_inputs - The inputs to pass through the model.
     * @returns {Promise<TokenClassifierOutput>} The per-token classification logits.
     */
    async _call(model_inputs) {
        const output = await super._call(model_inputs);
        return new TokenClassifierOutput(output.logits);
    }
}
931+
915932
/**
916933
* BertForQuestionAnswering is a class representing a BERT model for question answering.
917934
* @extends BertPreTrainedModel
@@ -944,14 +961,32 @@ class DistilBertForSequenceClassification extends DistilBertPreTrainedModel {
944961
* Calls the model on new inputs.
945962
*
946963
* @param {Object} model_inputs - The inputs to the model.
947-
* @returns {Promise<SequenceClassifierOutput>} - An object containing the model's output logits for question answering.
964+
* @returns {Promise<SequenceClassifierOutput>} - An object containing the model's output logits for sequence classification.
948965
*/
949966
async _call(model_inputs) {
950967
let logits = (await super._call(model_inputs)).logits;
951968
return new SequenceClassifierOutput(logits)
952969
}
953970
}
954971

972+
/**
 * DistilBERT model with a token-classification head (e.g., for named entity recognition).
 * @extends DistilBertPreTrainedModel
 */
class DistilBertForTokenClassification extends DistilBertPreTrainedModel {
    /**
     * Runs the underlying model and wraps its logits in a token-classifier output.
     *
     * @param {Object} model_inputs - The inputs to pass through the model.
     * @returns {Promise<TokenClassifierOutput>} The per-token classification logits.
     */
    async _call(model_inputs) {
        const output = await super._call(model_inputs);
        return new TokenClassifierOutput(output.logits);
    }
}
988+
989+
955990
/**
956991
* DistilBertForQuestionAnswering is a class representing a DistilBERT model for question answering.
957992
* @extends DistilBertPreTrainedModel
@@ -1197,10 +1232,9 @@ class T5ForConditionalGeneration extends T5PreTrainedModel {
11971232
}
11981233

11991234
/**
1200-
* Runs the beam search for a given beam.
1201-
* @async
1202-
* @param {any} beam - The current beam.
1203-
* @returns {Promise<any>} The model output.
1235+
* Runs a single step of the beam search generation algorithm.
1236+
* @param {any} beam - The current beam being generated.
1237+
* @returns {Promise<any>} - The updated beam after a single generation step.
12041238
*/
12051239
async runBeam(beam) {
12061240
return await seq2seqRunBeam(this, beam);
@@ -1298,11 +1332,10 @@ class MT5ForConditionalGeneration extends MT5PreTrainedModel {
12981332
}
12991333

13001334
/**
1301-
* Runs the given beam through
1302-
* the model and returns the next token prediction.
1303-
* @param {any} beam - The beam to run.
1304-
* @returns {Promise<number>} - A Promise that resolves to the index of the predicted token.
1305-
*/
1335+
* Runs a single step of the beam search generation algorithm.
1336+
* @param {any} beam - The current beam being generated.
1337+
* @returns {Promise<any>} - The updated beam after a single generation step.
1338+
*/
13061339
async runBeam(beam) {
13071340
return await seq2seqRunBeam(this, beam);
13081341
}
@@ -1602,9 +1635,9 @@ class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
16021635
}
16031636

16041637
/**
1605-
* Runs a beam for generating outputs.
1606-
* @param {Object} beam - Beam object.
1607-
* @returns {Promise<Object>} Promise object represents the generated outputs for the beam.
1638+
* Runs a single step of the beam search generation algorithm.
1639+
* @param {any} beam - The current beam being generated.
1640+
* @returns {Promise<any>} - The updated beam after a single generation step.
16081641
*/
16091642
async runBeam(beam) {
16101643
return await seq2seqRunBeam(this, beam, {
@@ -1693,10 +1726,9 @@ class VisionEncoderDecoderModel extends PreTrainedModel {
16931726
}
16941727

16951728
/**
1696-
* Generate the next beam step for the given beam.
1697-
*
1698-
* @param {any} beam - The current beam.
1699-
* @returns {Promise<any>} The updated beam with the additional predicted token ID.
1729+
* Runs a single step of the beam search generation algorithm.
1730+
* @param {any} beam - The current beam being generated.
1731+
* @returns {Promise<any>} - The updated beam after a single generation step.
17001732
*/
17011733
async runBeam(beam) {
17021734
return seq2seqRunBeam(this, beam, {
@@ -1792,9 +1824,9 @@ class GPT2LMHeadModel extends GPT2PreTrainedModel {
17921824
}
17931825

17941826
/**
1795-
* Runs beam search for text generation given a beam.
1796-
* @param {any} beam - The Beam object representing the beam.
1797-
* @returns {Promise<any>} A Beam object representing the updated beam after running beam search.
1827+
* Runs a single step of the beam search generation algorithm.
1828+
* @param {any} beam - The current beam being generated.
1829+
* @returns {Promise<any>} - The updated beam after a single generation step.
17981830
*/
17991831
async runBeam(beam) {
18001832
return await textgenRunBeam(this, beam);
@@ -1867,9 +1899,9 @@ class GPTNeoForCausalLM extends GPTNeoPreTrainedModel {
18671899
}
18681900

18691901
/**
1870-
* Runs beam search for text generation given a beam.
1871-
* @param {any} beam - The Beam object representing the beam.
1872-
* @returns {Promise<any>} A Beam object representing the updated beam after running beam search.
1902+
* Runs a single step of the beam search generation algorithm.
1903+
* @param {any} beam - The current beam being generated.
1904+
* @returns {Promise<any>} - The updated beam after a single generation step.
18731905
*/
18741906
async runBeam(beam) {
18751907
return await textgenRunBeam(this, beam);
@@ -1952,9 +1984,9 @@ class CodeGenForCausalLM extends CodeGenPreTrainedModel {
19521984
}
19531985

19541986
/**
1955-
* Runs beam search for text generation given a beam.
1956-
* @param {any} beam - The Beam object representing the beam.
1957-
* @returns {Promise<any>} A Beam object representing the updated beam after running beam search.
1987+
* Runs a single step of the beam search generation algorithm.
1988+
* @param {any} beam - The current beam being generated.
1989+
* @returns {Promise<any>} - The updated beam after a single generation step.
19581990
*/
19591991
async runBeam(beam) {
19601992
return await textgenRunBeam(this, beam);
@@ -2080,8 +2112,9 @@ class MarianMTModel extends MarianPreTrainedModel {
20802112
}
20812113

20822114
/**
2083-
* @param {any} beam
2084-
* @returns {Promise<any>}
2115+
* Runs a single step of the beam search generation algorithm.
2116+
* @param {any} beam - The current beam being generated.
2117+
* @returns {Promise<any>} - The updated beam after a single generation step.
20852118
*/
20862119
async runBeam(beam) {
20872120
return await seq2seqRunBeam(this, beam);
@@ -2203,6 +2236,46 @@ class AutoModelForSequenceClassification {
22032236
}
22042237
}
22052238

2239+
2240+
/**
 * Helper class for loading token classification models from pretrained checkpoints.
 * Dispatches on `config.model_type` to the appropriate concrete class.
 */
class AutoModelForTokenClassification {

    // Maps `config.model_type` to the concrete model class that implements it.
    static MODEL_CLASS_MAPPING = {
        'bert': BertForTokenClassification,
        'distilbert': DistilBertForTokenClassification,
    }

    /**
     * Load a token classification model from a pretrained checkpoint.
     * @param {string} modelPath - The path to the model checkpoint directory
     * @param {function} [progressCallback=null] - An optional callback function to receive progress updates
     * @returns {Promise<PreTrainedModel>} A promise that resolves to a pre-trained token classification model
     * @throws {Error} if an unsupported model type is encountered
     */
    static async from_pretrained(modelPath, progressCallback = null) {

        // Fetch the config and construct the inference session in parallel.
        let [config, session] = await Promise.all([
            fetchJSON(modelPath, 'config.json', progressCallback),
            constructSession(modelPath, 'model.onnx', progressCallback)
        ]);

        // Called when all parts are loaded
        dispatchCallback(progressCallback, {
            status: 'loaded',
            name: modelPath
        });

        let cls = this.MODEL_CLASS_MAPPING[config.model_type];
        if (!cls) {
            // Use `new Error(...)` (idiomatic) rather than the bare `Error(...)` call.
            throw new Error(`Unsupported model type: ${config.model_type}`);
        }
        return new cls(config, session);
    }
}
2277+
2278+
22062279
/**
22072280
* Class representing an automatic sequence-to-sequence language model.
22082281
*/
@@ -2248,7 +2321,7 @@ class AutoModelForCausalLM {
22482321
* Loads a pre-trained model from the given path and returns an instance of the appropriate class.
22492322
* @param {string} modelPath - The path to the pre-trained model.
22502323
* @param {function} [progressCallback=null] - An optional callback function to track the progress of the loading process.
2251-
* @returns {Promise<GPT2LMHeadModel|CodeGenForCausalLM>} An instance of the appropriate class for the loaded model.
2324+
* @returns {Promise<GPT2LMHeadModel|GPTNeoForCausalLM|CodeGenForCausalLM>} An instance of the appropriate class for the loaded model.
22522325
* @throws {Error} If the loaded model type is not supported.
22532326
*/
22542327
static async from_pretrained(modelPath, progressCallback = null) {
@@ -2369,7 +2442,7 @@ class AutoModelForVision2Seq {
23692442
* Loads a pretrained model from a given path.
23702443
* @param {string} modelPath - The path to the pretrained model.
23712444
* @param {function} progressCallback - Optional callback function to track progress of the model loading.
2372-
* @returns {Promise<VisionEncoderDecoderModel>} - A Promise that resolves to a new instance of VisionEncoderDecoderModel.
2445+
* @returns {Promise<PreTrainedModel>} - A Promise that resolves to a new instance of VisionEncoderDecoderModel.
23732446
*/
23742447
static async from_pretrained(modelPath, progressCallback = null) {
23752448

@@ -2406,7 +2479,7 @@ class AutoModelForImageClassification {
24062479
* Loads a pre-trained image classification model from a given directory path.
24072480
* @param {string} modelPath - The path to the directory containing the pre-trained model.
24082481
* @param {function} [progressCallback=null] - A callback function to monitor the loading progress.
2409-
* @returns {Promise<ViTForImageClassification>} A Promise that resolves with an instance of the ViTForImageClassification class.
2482+
* @returns {Promise<PreTrainedModel>} A Promise that resolves with an instance of the ViTForImageClassification class.
24102483
* @throws {Error} If the specified model type is not supported.
24112484
*/
24122485
static async from_pretrained(modelPath, progressCallback = null) {
@@ -2441,7 +2514,7 @@ class AutoModelForObjectDetection {
24412514
* Loads a pre-trained image classification model from a given directory path.
24422515
* @param {string} modelPath - The path to the directory containing the pre-trained model.
24432516
* @param {function} [progressCallback=null] - A callback function to monitor the loading progress.
2444-
* @returns {Promise<any>} A Promise that resolves with an instance of the ViTForImageClassification class.
2517+
* @returns {Promise<PreTrainedModel>} A Promise that resolves with an instance of the ViTForImageClassification class.
24452518
* @throws {Error} If the specified model type is not supported.
24462519
*/
24472520
static async from_pretrained(modelPath, progressCallback = null) {
@@ -2491,6 +2564,17 @@ class SequenceClassifierOutput extends ModelOutput {
24912564
}
24922565
}
24932566

2567+
/**
 * Output type for token-classification models: a thin wrapper around the logits tensor.
 */
class TokenClassifierOutput extends ModelOutput {
    /**
     * @param {Tensor} logits - Per-token classification scores.
     */
    constructor(logits) {
        super();
        this.logits = logits;
    }
}
2576+
2577+
24942578
class MaskedLMOutput extends ModelOutput {
24952579
/**
24962580
* @param {Tensor} logits
@@ -2517,6 +2601,7 @@ module.exports = {
25172601
AutoModel,
25182602
AutoModelForSeq2SeqLM,
25192603
AutoModelForSequenceClassification,
2604+
AutoModelForTokenClassification,
25202605
AutoModelForCausalLM,
25212606
AutoModelForMaskedLM,
25222607
AutoModelForQuestionAnswering,

src/pipelines.js

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
const {
22
Callable,
33
softmax,
4+
indexOfMax,
45
getTopItems,
56
cos_sim,
67
pathJoin,
@@ -15,6 +16,7 @@ const {
1516
const {
1617
AutoModel,
1718
AutoModelForSequenceClassification,
19+
AutoModelForTokenClassification,
1820
AutoModelForQuestionAnswering,
1921
AutoModelForMaskedLM,
2022
AutoModelForSeq2SeqLM,
@@ -137,6 +139,69 @@ class TextClassificationPipeline extends Pipeline {
137139
}
138140
}
139141

142+
143+
/**
 * TokenClassificationPipeline class for executing a token classification task.
 * @extends Pipeline
 */
class TokenClassificationPipeline extends Pipeline {
    /**
     * Executes the token classification task.
     * @param {any} texts - The input text (or array of texts) to be classified.
     * @param {object} options - Optional settings: `ignore_labels` (string[], default `['O']`) — predicted labels to filter out.
     * @returns {Promise<object[]|object>} - A promise that resolves to an array (or single object for a single input) of predicted entities with labels and scores.
     */
    async _call(texts, {
        ignore_labels = ['O'], // TODO init param?
    } = {}) {
        // Normalize to a batch; remember whether the caller passed a single input.
        const singleInput = !Array.isArray(texts);
        const batchTexts = singleInput ? [texts] : texts;

        const [inputs, outputs] = await super._call(batchTexts);

        const logits = outputs.logits;
        const id2label = this.model.config.id2label;

        const results = [];
        for (let row = 0; row < logits.dims[0]; ++row) {
            const tokenIds = inputs.input_ids.get(row);
            const rowLogits = logits.get(row);

            // Collect predictions for every token whose label is not ignored.
            const predictions = [];
            for (let pos = 0; pos < rowLogits.dims[0]; ++pos) {
                const tokenLogits = rowLogits.get(pos);
                const best = indexOfMax(tokenLogits.data);

                const entity = id2label[best];
                if (ignore_labels.includes(entity)) {
                    // We predicted a token that should be ignored. So, we skip it.
                    continue;
                }

                const probs = softmax(tokenLogits.data);

                predictions.push({
                    entity: entity,
                    score: probs[best],
                    index: pos,
                    word: this.tokenizer.decode([tokenIds.get(pos)], { skip_special_tokens: false }),

                    // TODO: null for now, but will add
                    start: null,
                    end: null,
                });
            }
            results.push(predictions);
        }
        return singleInput ? results[0] : results;
    }
}
140205
/**
141206
* QuestionAnsweringPipeline class for executing a question answering task.
142207
* @extends Pipeline
@@ -981,7 +1046,15 @@ const SUPPORTED_TASKS = {
9811046
},
9821047
"type": "text",
9831048
},
984-
1049+
"token-classification": {
1050+
"tokenizer": AutoTokenizer,
1051+
"pipeline": TokenClassificationPipeline,
1052+
"model": AutoModelForTokenClassification,
1053+
"default": {
1054+
"model": "Davlan/bert-base-multilingual-cased-ner-hrl",
1055+
},
1056+
"type": "text",
1057+
},
9851058
"question-answering": {
9861059
"tokenizer": AutoTokenizer,
9871060
"pipeline": QuestionAnsweringPipeline,

src/transformers.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ const {
99
const {
1010
AutoModel,
1111
AutoModelForSequenceClassification,
12+
AutoModelForTokenClassification,
1213
AutoModelForSeq2SeqLM,
1314
AutoModelForCausalLM,
1415
AutoModelForMaskedLM,
@@ -39,6 +40,7 @@ const moduleExports = {
3940
AutoModel,
4041
AutoModelForSeq2SeqLM,
4142
AutoModelForSequenceClassification,
43+
AutoModelForTokenClassification,
4244
AutoModelForCausalLM,
4345
AutoModelForMaskedLM,
4446
AutoModelForQuestionAnswering,

0 commit comments

Comments
 (0)