Skip to content

Commit cc7f10d

Browse files
committed
Janky patch to avoid errors from predicting states which are effectively OOV - EMPTY in particular is treated by the CompositeVocab as 'leave this blank'. The better fix would be to remove those states from the output layer entirely.
1 parent cf9aa2f commit cc7f10d

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

stanza/models/common/vocab.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
ROOT = '<ROOT>'
1414
ROOT_ID = 3
1515
VOCAB_PREFIX = [PAD, UNK, EMPTY, ROOT]
16+
VOCAB_PREFIX_SIZE = len(VOCAB_PREFIX)
1617

1718
class BaseVocab:
1819
""" A base class for common vocabulary operations. Each subclass should at least

stanza/models/ner/trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from stanza.models.common.foundation_cache import NoTransformerFoundationCache
1111
from stanza.models.common.trainer import Trainer as BaseTrainer
12-
from stanza.models.common.vocab import VOCAB_PREFIX
12+
from stanza.models.common.vocab import VOCAB_PREFIX, VOCAB_PREFIX_SIZE
1313
from stanza.models.common import utils, loss
1414
from stanza.models.ner.model import NERTagger
1515
from stanza.models.ner.vocab import MultiVocab
@@ -129,6 +129,8 @@ def predict(self, batch, unsort=True):
129129
for i in range(batch_size):
130130
# for each tag column in the output, decode the tag assignments
131131
tags = [viterbi_decode(x[i, :sentlens[i]], y)[0] for x, y in zip(logits, trans)]
132+
# TODO: this is to patch that the model can sometimes predict < "O"
133+
tags = [[x if x >= VOCAB_PREFIX_SIZE else VOCAB_PREFIX_SIZE for x in y] for y in tags]
132134
# that gives us N lists of |sent| tags, whereas we want |sent| lists of N tags
133135
tags = list(zip(*tags))
134136
# now unmap that to the tags in the vocab

0 commit comments

Comments
 (0)