Multilayer ner #1289

Merged · 12 commits · Oct 3, 2023
12 changes: 0 additions & 12 deletions stanza/models/common/utils.py
@@ -396,18 +396,6 @@ def set_random_seed(seed):
torch.cuda.manual_seed(seed)
return seed

def get_known_tags(known_tags):
"""
Turns either a list or a list of lists into a single sorted list

Actually this is not at all necessarily about tags
"""
if isinstance(known_tags, list) and isinstance(known_tags[0], list):
known_tags = sorted(set(x for y in known_tags for x in y))
else:
known_tags = sorted(known_tags)
return known_tags

def find_missing_tags(known_tags, test_tags):
if isinstance(known_tags, list) and isinstance(known_tags[0], list):
known_tags = set(x for y in known_tags for x in y)
21 changes: 16 additions & 5 deletions stanza/models/common/vocab.py
@@ -1,5 +1,6 @@
from copy import copy
from collections import Counter, OrderedDict
from collections.abc import Iterable
import os
import pickle

@@ -12,6 +13,7 @@
ROOT = '<ROOT>'
ROOT_ID = 3
VOCAB_PREFIX = [PAD, UNK, EMPTY, ROOT]
VOCAB_PREFIX_SIZE = len(VOCAB_PREFIX)

class BaseVocab:
""" A base class for common vocabulary operations. Each subclass should at least
@@ -111,7 +113,7 @@ def __init__(self, data=None, lang="", idx=0, sep="", keyed=False):

def unit2parts(self, unit):
# unpack parts of a unit
if self.sep == "":
if not self.sep:
parts = [x for x in unit]
else:
parts = unit.split(self.sep)
@@ -137,17 +139,23 @@ def unit2id(self, unit):
return [self._unit2id[i].get(parts[i], UNK_ID) if i < len(parts) else EMPTY_ID for i in range(len(self._unit2id))]

def id2unit(self, id):
# special case: allow single ids for vocabs with length 1
if len(self._id2unit) == 1 and not isinstance(id, Iterable):
id = (id,)
items = []
for v, k in zip(id, self._id2unit.keys()):
if v == EMPTY_ID: continue
if self.keyed:
items.append("{}={}".format(k, self._id2unit[k][v]))
else:
items.append(self._id2unit[k][v])
res = self.sep.join(items)
if res == "":
res = "_"
return res
if self.sep:
res = self.sep.join(items)
if res == "":
res = "_"
return res
else:
return items

def build_vocab(self):
allunits = [w[self.idx] for sent in self.data for w in sent]
@@ -191,6 +199,9 @@ def build_vocab(self):
def lens(self):
return [len(self._unit2id[k]) for k in self._unit2id]

def items(self, idx):
return self._id2unit[idx]

class BaseMultiVocab:
""" A convenient vocab container that can store multiple BaseVocab instances, and support
safe serialization of all instances via state dicts. Each subclass of this base class
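
Note on the CompositeVocab changes above: with a falsy sep the vocab treats each unit as an already-split sequence of parts, and id2unit now returns the decoded parts as a list instead of a joined string. A minimal sketch of that behavior (standalone, non-keyed case; not Stanza's actual class), assuming the usual prefix ids where EMPTY_ID is 2:

def id2unit_sketch(ids, id2unit, sep, EMPTY_ID=2):
    # mirror of the new logic: skip empty slots, then either join or return a list
    items = []
    for v, part_vocab in zip(ids, id2unit):
        if v == EMPTY_ID:
            continue
        items.append(part_vocab[v])
    if sep:
        return sep.join(items) or "_"
    return items

# hypothetical two-layer tag vocab, one surface list per tag column
id2unit = [["<PAD>", "<UNK>", "<EMPTY>", "<ROOT>", "O", "B-PER"],
           ["<PAD>", "<UNK>", "<EMPTY>", "<ROOT>", "O", "B-ORG"]]
print(id2unit_sketch([5, 4], id2unit, sep=None))  # ['B-PER', 'O']
print(id2unit_sketch([5, 2], id2unit, sep="|"))   # 'B-PER' (empty second layer skipped)
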
19 changes: 12 additions & 7 deletions stanza/models/ner/data.py
@@ -5,10 +5,10 @@
from stanza.models.common.bert_embedding import filter_data
from stanza.models.common.data import map_to_ids, get_long_tensor, sort_all
from stanza.models.common.vocab import PAD_ID, VOCAB_PREFIX
from stanza.models.pos.vocab import CharVocab, WordVocab
from stanza.models.ner.vocab import TagVocab, MultiVocab
from stanza.models.pos.vocab import CharVocab, CompositeVocab, WordVocab
from stanza.models.ner.vocab import MultiVocab
from stanza.models.common.doc import *
from stanza.models.ner.utils import process_tags
from stanza.models.ner.utils import process_tags, normalize_empty_tags

logger = logging.getLogger('stanza')

@@ -68,7 +68,8 @@ def from_model(model_filename):
else:
charvocab = CharVocab(data, self.args['shorthand'])
wordvocab = self.pretrain.vocab
tagvocab = TagVocab(data, self.args['shorthand'], idx=1)
tag_data = [[(x[1],) for x in sentence] for sentence in data]
tagvocab = CompositeVocab(tag_data, self.args['shorthand'], idx=0, sep=None)
ignore = None
if self.args['emb_finetune_known_only']:
if self.args['lowercase']:
@@ -90,7 +91,7 @@ def preprocess(self, data, vocab, args):
char_case = lambda x: x.lower()
else:
char_case = lambda x: x
for sent in data:
for sent_idx, sent in enumerate(data):
processed_sent = [[w[0] for w in sent]]
processed_sent += [[vocab['char'].map([char_case(x) for x in w[0]]) for w in sent]]
processed_sent += [vocab['tag'].map([w[1] for w in sent])]
@@ -109,7 +110,7 @@ def __getitem__(self, key):
batch = self.data[key]
batch_size = len(batch)
batch = list(zip(*batch))
assert len(batch) == 3 # words: List[List[int]], chars: List[List[List[int]]], tags: List[List[int]]
assert len(batch) == 3 # words: List[List[int]], chars: List[List[List[int]]], tags: List[List[List[int]]]

# sort sentences by lens for easy RNN operations
sentlens = [len(x) for x in batch[0]]
@@ -146,9 +147,13 @@ def __iter__(self):
yield self.__getitem__(i)

def load_doc(self, doc):
data = doc.get([TEXT, NER], as_sentences=True, from_token=True)
# preferentially load the MULTI_NER in case we are training /
# testing a model with multiple layers of tags
data = doc.get([TEXT, NER, MULTI_NER], as_sentences=True, from_token=True)
data = [[[token[0], token[2]] if token[2] else [token[0], (token[1],)] for token in sentence] for sentence in data]
if self.preprocess_tags: # preprocess tags
data = process_tags(data, self.args.get('scheme', 'bio'))
data = normalize_empty_tags(data)
return data

def process_chars(self, sents):
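
To make the data.py changes above concrete: each token's tag is now a tuple with one entry per NER layer, so single-layer data becomes one-element tuples and multi-layer data comes straight from MULTI_NER. A rough illustration with made-up tokens (plain tuples standing in for the Document API output):

# hypothetical (TEXT, NER, MULTI_NER) triples as returned by doc.get(...)
sentences = [[("John", "B-PER", ("B-PER", "B-ORG")),
              ("works", "O", ("O", "O"))],
             [("Stanford", "B-ORG", None)]]

# load_doc reshaping: prefer MULTI_NER, otherwise wrap the single NER tag
data = [[[token[0], token[2]] if token[2] else [token[0], (token[1],)]
         for token in sentence] for sentence in sentences]
# -> [[['John', ('B-PER', 'B-ORG')], ['works', ('O', 'O')]],
#     [['Stanford', ('B-ORG',)]]]

# vocab construction then keeps only the tag tuples; CompositeVocab with
# idx=0 and sep=None iterates each tuple, giving one sub-vocab per layer
tag_data = [[(x[1],) for x in sentence] for sentence in data]
# -> [[(('B-PER', 'B-ORG'),), (('O', 'O'),)], [(('B-ORG',),)]]
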
40 changes: 31 additions & 9 deletions stanza/models/ner/model.py
@@ -13,7 +13,7 @@
from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel
from stanza.models.common.crf import CRFLoss
from stanza.models.common.foundation_cache import load_bert
from stanza.models.common.vocab import PAD_ID, UNK_ID
from stanza.models.common.vocab import PAD_ID, UNK_ID, EMPTY_ID
from stanza.models.common.bert_embedding import extract_bert_embeddings

logger = logging.getLogger('stanza')
@@ -120,12 +120,18 @@ def add_unsaved_module(name, module):
self.taggerlstm_c_init = nn.Parameter(torch.zeros(2 * self.args['num_layers'], 1, self.args['hidden_dim']), requires_grad=False)

# tag classifier
num_tag = len(self.vocab['tag'])
self.tag_clf = nn.Linear(self.args['hidden_dim']*2, num_tag)
self.tag_clf.bias.data.zero_()

# criterion
self.crit = CRFLoss(num_tag)
tag_lengths = self.vocab['tag'].lens()
self.num_output_layers = len(tag_lengths)
if self.args.get('connect_output_layers'):
tag_clfs = [nn.Linear(self.args['hidden_dim']*2, tag_lengths[0])]
for prev_length, next_length in zip(tag_lengths[:-1], tag_lengths[1:]):
tag_clfs.append(nn.Linear(self.args['hidden_dim']*2 + prev_length, next_length))
self.tag_clfs = nn.ModuleList(tag_clfs)
else:
self.tag_clfs = nn.ModuleList([nn.Linear(self.args['hidden_dim']*2, num_tag) for num_tag in tag_lengths])
for tag_clf in self.tag_clfs:
tag_clf.bias.data.zero_()
self.crits = nn.ModuleList([CRFLoss(num_tag) for num_tag in tag_lengths])

self.drop = nn.Dropout(args['dropout'])
self.worddrop = WordDropout(args['word_dropout'])
@@ -233,8 +239,24 @@ def pad(x):
lstm_outputs = pad(lstm_outputs)
lstm_outputs = self.lockeddrop(lstm_outputs)
lstm_outputs = pack(lstm_outputs).data
logits = pad(self.tag_clf(lstm_outputs)).contiguous()
loss, trans = self.crit(logits, word_mask, tags)

loss = 0
logits = []
trans = []
for idx, (tag_clf, crit) in enumerate(zip(self.tag_clfs, self.crits)):
if not self.args.get('connect_output_layers') or idx == 0:
next_logits = pad(tag_clf(lstm_outputs)).contiguous()
else:
# here we pack the output of the previous round, then append it
packed_logits = pack(next_logits).data
input_logits = torch.cat([lstm_outputs, packed_logits], axis=1)
next_logits = pad(tag_clf(input_logits)).contiguous()
# the tag_mask lets us avoid backprop on a blank tag
tag_mask = torch.eq(tags[:, :, idx], EMPTY_ID)
next_loss, next_trans = crit(next_logits, torch.bitwise_or(tag_mask, word_mask), tags[:, :, idx])
loss = loss + next_loss
logits.append(next_logits)
trans.append(next_trans)

return loss, logits, trans

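
A stripped-down sketch of the chained tag heads built above when connect_output_layers is set (made-up dimensions; the real forward pass additionally pads/packs the LSTM output and feeds each layer's logits into its own CRFLoss):

import torch
import torch.nn as nn

hidden_dim = 8
tag_lengths = [6, 5]                 # tag vocab size per output layer

heads = [nn.Linear(hidden_dim * 2, tag_lengths[0])]
for prev_len, next_len in zip(tag_lengths[:-1], tag_lengths[1:]):
    # later heads see the biLSTM features plus the previous layer's logits
    heads.append(nn.Linear(hidden_dim * 2 + prev_len, next_len))
heads = nn.ModuleList(heads)

feats = torch.randn(3, hidden_dim * 2)   # 3 word positions worth of features
logits = []
for idx, head in enumerate(heads):
    inp = feats if idx == 0 else torch.cat([feats, logits[-1]], dim=1)
    logits.append(head(inp))
print([tuple(l.shape) for l in logits])  # [(3, 6), (3, 5)]
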
45 changes: 38 additions & 7 deletions stanza/models/ner/trainer.py
@@ -9,7 +9,7 @@

from stanza.models.common.foundation_cache import NoTransformerFoundationCache
from stanza.models.common.trainer import Trainer as BaseTrainer
from stanza.models.common.vocab import VOCAB_PREFIX
from stanza.models.common.vocab import VOCAB_PREFIX, VOCAB_PREFIX_SIZE
from stanza.models.common import utils, loss
from stanza.models.ner.model import NERTagger
from stanza.models.ner.vocab import MultiVocab
@@ -72,6 +72,12 @@ def __init__(self, args=None, vocab=None, pretrain=None, model_file=None, device
self.vocab = vocab
self.model = NERTagger(args, vocab, emb_matrix=pretrain.emb, foundation_cache=foundation_cache)

# if this wasn't set anywhere, we use a default of the 0th tagset
# we don't set this as a default in the options so that
# we can distinguish "intentionally set to 0" and "not set at all"
if self.args.get('predict_tagset', None) is None:
self.args['predict_tagset'] = 0

if train_classifier_only:
logger.info('Disabling gradient for non-classifier layers')
exclude = ['tag_clf', 'crit']
@@ -111,13 +117,28 @@ def predict(self, batch, unsort=True):
_, logits, trans = self.model(word, wordchars, wordchars_mask, tags, word_orig_idx, sentlens, wordlens, chars, charoffsets, charlens, char_orig_idx)

# decode
trans = trans.data.cpu().numpy()
scores = logits.data.cpu().numpy()
bs = logits.size(0)
# TODO: might need to decode multiple columns of output for
# models with multiple layers
trans = [x.data.cpu().numpy() for x in trans]
logits = [x.data.cpu().numpy() for x in logits]
batch_size = logits[0].shape[0]
if any(x.shape[0] != batch_size for x in logits):
raise AssertionError("Expected all of the logits to have the same size")
tag_seqs = []
for i in range(bs):
tags, _ = viterbi_decode(scores[i, :sentlens[i]], trans)
predict_tagset = self.args['predict_tagset']
for i in range(batch_size):
# for each tag column in the output, decode the tag assignments
tags = [viterbi_decode(x[i, :sentlens[i]], y)[0] for x, y in zip(logits, trans)]
# TODO: this patches the fact that the model can sometimes predict an index below "O"
tags = [[x if x >= VOCAB_PREFIX_SIZE else VOCAB_PREFIX_SIZE for x in y] for y in tags]
# that gives us N lists of |sent| tags, whereas we want |sent| lists of N tags
tags = list(zip(*tags))
# now unmap that to the tags in the vocab
tags = self.vocab['tag'].unmap(tags)
# for now, allow either TagVocab or CompositeVocab
# TODO: we might want to return all of the predictions
# rather than a single column
tags = [x[predict_tagset] if isinstance(x, list) else x for x in tags]
tags = fix_singleton_tags(tags)
tag_seqs += [tags]

@@ -153,6 +174,11 @@ def load(self, filename, pretrain=None, args=None, foundation_cache=None):
raise
self.args = checkpoint['config']
if args: self.args.update(args)
# if predict_tagset was not explicitly set in the args,
# we use the value the model was trained with
if self.args.get('predict_tagset', None) is None:
self.args['predict_tagset'] = checkpoint['config'].get('predict_tagset', None)

self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])

emb_matrix=None
@@ -164,6 +190,11 @@ def load(self, filename, pretrain=None, args=None, foundation_cache=None):
logger.debug("Model %s has a finetuned transformer. Not using transformer cache to make sure the finetuned version of the transformer isn't accidentally used elsewhere", filename)
foundation_cache = NoTransformerFoundationCache(foundation_cache)
force_bert_saved = True
if any(x.startswith("crit.") for x in checkpoint['model'].keys()):
logger.debug("Old model format detected. Updating to the new format with one column of tags")
checkpoint['model']['crits.0._transitions'] = checkpoint['model'].pop('crit._transitions')
checkpoint['model']['tag_clfs.0.weight'] = checkpoint['model'].pop('tag_clf.weight')
checkpoint['model']['tag_clfs.0.bias'] = checkpoint['model'].pop('tag_clf.bias')
self.model = NERTagger(self.args, self.vocab, emb_matrix=emb_matrix, foundation_cache=foundation_cache, force_bert_saved=force_bert_saved)
self.model.load_state_dict(checkpoint['model'], strict=False)

@@ -185,7 +216,7 @@ def get_known_tags(self):
Removes the S-, B-, etc, and does not include O
"""
tags = set()
for tag in self.vocab['tag']:
for tag in self.vocab['tag'].items(0):
if tag in VOCAB_PREFIX:
continue
if tag == 'O':
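
For reference, the reshaping done in the new predict() above, shown in isolation (toy tag strings; the real code runs viterbi_decode per layer and unmaps through the tag vocab first):

# each layer's decode yields one tag per word: N lists of |sent| tags
layer_predictions = [["B-PER", "O", "O"],        # layer 0
                     ["B-ORG", "O", "B-ORG"]]    # layer 1

# transpose to |sent| lists of N tags, one tuple of layer tags per word
per_word = list(zip(*layer_predictions))
# -> [('B-PER', 'B-ORG'), ('O', 'O'), ('O', 'B-ORG')]

# report a single column, defaulting to the 0th tagset when unset
predict_tagset = 0
tags = [word_tags[predict_tagset] for word_tags in per_word]
print(tags)  # ['B-PER', 'O', 'O']
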