
Commit

integrate tokenizer (#47)
yzhangcs committed Mar 16, 2021
1 parent 23fb605 commit d6b8b6d
Showing 7 changed files with 110 additions and 30 deletions.
7 changes: 6 additions & 1 deletion setup.py
@@ -22,7 +22,12 @@
setup_requires=[
'setuptools>=18.0',
],
install_requires=['torch>=1.4.0', 'transformers>=3.1.0', 'nltk'],
install_requires=[
'torch>=1.7.1',
'transformers>=3.1.0',
'nltk',
'stanza',
'dill'],
entry_points={
'console_scripts': [
'biaffine-dependency=supar.cmds.biaffine_dependency:main',
17 changes: 13 additions & 4 deletions supar/parsers/constituency.py
@@ -97,13 +97,17 @@ def evaluate(self, data, buckets=8, batch_size=5000, mbr=True,

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, mbr=True, verbose=True, **kwargs):
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False, mbr=True, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
@@ -235,15 +239,16 @@ def build(cls, path,
if 'char' in args.feat:
CHAR = SubwordField('chars', pad=pad, unk=unk, bos=bos, eos=eos, fix_len=args.fix_len)
if 'bert' in args.feat:
from transformers import AutoTokenizer
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
tokenizer = AutoTokenizer.from_pretrained(args.bert)
BERT = SubwordField('bert',
pad=tokenizer.pad_token,
unk=tokenizer.unk_token,
bos=tokenizer.cls_token or tokenizer.cls_token,
eos=tokenizer.sep_token or tokenizer.sep_token,
fix_len=args.fix_len,
tokenize=tokenizer.tokenize)
tokenize=tokenizer.tokenize,
fn=lambda x: ' '+x if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)) else None)
BERT.vocab = tokenizer.get_vocab()
TREE = RawField('trees')
CHART = ChartField('charts')
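
An aside on the new fn argument (illustrative, not part of the diff): GPT-2's byte-level BPE marks word-initial tokens with a leading space, so a bare word and the same word inside a sentence split into different subwords. Prepending ' ' to every token keeps the subword pieces consistent with how words appear in running text. A minimal sketch, assuming the transformers package is installed:

    # Minimal sketch of the GPT-2 leading-space behaviour; not part of this commit.
    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    print(tokenizer.tokenize('world'))   # ['world']   -- word-initial form
    print(tokenizer.tokenize(' world'))  # ['Ġworld']  -- mid-sentence form, which the fn above reproduces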
Expand Down Expand Up @@ -345,13 +350,17 @@ def evaluate(self, data, buckets=8, batch_size=5000,

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, verbose=True, **kwargs):
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
40 changes: 31 additions & 9 deletions supar/parsers/dependency.py
@@ -97,14 +97,18 @@ def evaluate(self, data, buckets=8, batch_size=5000,

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, comp=False,
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False, comp=False,
tree=True, proj=False, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
@@ -255,14 +259,15 @@ def build(cls, path,
if 'char' in args.feat:
CHAR = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=args.fix_len)
if 'bert' in args.feat:
from transformers import AutoTokenizer
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
tokenizer = AutoTokenizer.from_pretrained(args.bert)
BERT = SubwordField('bert',
pad=tokenizer.pad_token,
unk=tokenizer.unk_token,
bos=tokenizer.bos_token or tokenizer.cls_token,
fix_len=args.fix_len,
tokenize=tokenizer.tokenize)
tokenize=tokenizer.tokenize,
fn=lambda x: ' '+x if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)) else None)
BERT.vocab = tokenizer.get_vocab()
ARC = Field('arcs', bos=bos, use_vocab=False, fn=CoNLL.get_arcs)
REL = Field('rels', bos=bos)
Expand Down Expand Up @@ -364,14 +369,18 @@ def evaluate(self, data, buckets=8, batch_size=5000, punct=False,

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False,
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False,
mbr=True, tree=True, proj=False, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
@@ -548,14 +557,18 @@ def evaluate(self, data, buckets=8, batch_size=5000, punct=False,

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, comp=False,
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False, comp=False,
mbr=True, tree=True, proj=True, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
@@ -740,14 +753,18 @@ def evaluate(self, data, buckets=8, batch_size=5000, punct=False,

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, comp=False,
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False, comp=False,
mbr=True, tree=True, proj=True, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
@@ -903,14 +920,15 @@ def build(cls, path,
if 'char' in args.feat:
CHAR = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=args.fix_len)
if 'bert' in args.feat:
from transformers import AutoTokenizer
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
tokenizer = AutoTokenizer.from_pretrained(args.bert)
BERT = SubwordField('bert',
pad=tokenizer.pad_token,
unk=tokenizer.unk_token,
bos=tokenizer.bos_token or tokenizer.cls_token,
fix_len=args.fix_len,
tokenize=tokenizer.tokenize)
tokenize=tokenizer.tokenize,
fn=lambda x: ' '+x if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)) else None)
BERT.vocab = tokenizer.get_vocab()
ARC = Field('arcs', bos=bos, use_vocab=False, fn=CoNLL.get_arcs)
SIB = ChartField('sibs', bos=bos, use_vocab=False, fn=CoNLL.get_sibs)
@@ -1015,14 +1033,18 @@ def evaluate(self, data, buckets=8, batch_size=5000, punct=False,

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, comp=False,
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False, comp=False,
tree=True, proj=True, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
9 changes: 5 additions & 4 deletions supar/parsers/parser.py
@@ -3,6 +3,7 @@
import os
from datetime import datetime, timedelta

import dill
import supar
import torch
import torch.distributed as dist
@@ -59,7 +60,7 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, clip=5.0, epochs=
logger.info(f"{'test:':5} loss: {loss:.4f} - {test_metric}")

t = datetime.now() - start
if dev_metric > best_metric and epoch >= args.patience:
if dev_metric > best_metric:
best_e, best_metric = epoch, dev_metric
if is_master():
self.save(args.path)
@@ -95,15 +96,15 @@ def evaluate(self, data, buckets=8, batch_size=5000, **kwargs):

return loss, metric

def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, **kwargs):
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False, **kwargs):
args = self.args.update(locals())
init_logger(logger, verbose=args.verbose)

if args.prob:
self.transform.append(Field('probs'))

logger.info("Loading the data")
dataset = Dataset(self.transform, data)
dataset = Dataset(self.transform, data, lang=lang)
dataset.build(args.batch_size, args.buckets)
logger.info(f"\n{dataset}")
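
For illustration (not part of the diff): with lang now forwarded to Dataset, raw text can be passed straight to predict and is tokenized by stanza before parsing. A hedged usage sketch; the model path is a placeholder:

    # Hypothetical usage sketch; 'path/to/model' is a placeholder, not a real checkpoint.
    from supar import Parser

    parser = Parser.load('path/to/model')
    # A plain string is tokenized with the English stanza pipeline first;
    # pass lang=None to supply pre-tokenized lists of words instead.
    dataset = parser.predict('She enjoys playing tennis.', lang='en')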

@@ -183,4 +184,4 @@ def save(self, path):
'state_dict': state_dict,
'pretrained': pretrained,
'transform': self.transform}
torch.save(state, path)
torch.save(state, path, pickle_module=dill)
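
Why pickle_module=dill (an aside, not part of the diff): the transform now holds lambdas such as the GPT-2 fn above, which the standard pickle module refuses to serialize, while dill handles them. A minimal sketch:

    # Minimal sketch; not part of this commit.
    import pickle

    import dill

    fn = lambda x: ' ' + x
    try:
        pickle.dumps(fn)              # the stdlib pickler rejects lambdas
    except (pickle.PicklingError, AttributeError) as e:
        print(f'pickle failed: {e}')
    payload = dill.dumps(fn)          # dill serializes the lambda by value
    assert dill.loads(payload)('x') == ' x'
    # torch.load also accepts pickle_module=dill for reading such checkpoints back.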
22 changes: 16 additions & 6 deletions supar/parsers/semantic_dependency.py
@@ -72,13 +72,17 @@ def evaluate(self, data, buckets=8, batch_size=5000, verbose=True, **kwargs):

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, verbose=True, **kwargs):
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
@@ -209,14 +213,15 @@ def build(cls,
if 'lemma' in args.feat:
LEMMA = Field('lemmas', pad=pad, unk=unk, bos=bos, lower=True)
if 'bert' in args.feat:
from transformers import AutoTokenizer
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
tokenizer = AutoTokenizer.from_pretrained(args.bert)
BERT = SubwordField('bert',
pad=tokenizer.pad_token,
unk=tokenizer.unk_token,
bos=tokenizer.bos_token or tokenizer.cls_token,
fix_len=args.fix_len,
tokenize=tokenizer.tokenize)
tokenize=tokenizer.tokenize,
fn=lambda x: ' '+x if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)) else None)
BERT.vocab = tokenizer.get_vocab()
EDGE = ChartField('edges', use_vocab=False, fn=CoNLL.get_edges)
LABEL = ChartField('labels', fn=CoNLL.get_labels)
Expand Down Expand Up @@ -310,13 +315,17 @@ def evaluate(self, data, buckets=8, batch_size=5000, verbose=True, **kwargs):

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, verbose=True, **kwargs):
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
@@ -448,14 +457,15 @@ def build(cls,
if 'lemma' in args.feat:
LEMMA = Field('lemmas', pad=pad, unk=unk, bos=bos, lower=True)
if 'bert' in args.feat:
from transformers import AutoTokenizer
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
tokenizer = AutoTokenizer.from_pretrained(args.bert)
BERT = SubwordField('bert',
pad=tokenizer.pad_token,
unk=tokenizer.unk_token,
bos=tokenizer.bos_token or tokenizer.cls_token,
fix_len=args.fix_len,
tokenize=tokenizer.tokenize)
tokenize=tokenizer.tokenize,
fn=lambda x: ' '+x if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)) else None)
BERT.vocab = tokenizer.get_vocab()
EDGE = ChartField('edges', use_vocab=False, fn=CoNLL.get_edges)
LABEL = ChartField('labels', fn=CoNLL.get_labels)
15 changes: 15 additions & 0 deletions supar/utils/tokenizer.py
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-


class Tokenizer:

def __init__(self, lang='en'):
import stanza
try:
self.pipeline = stanza.Pipeline(lang=lang, processors='tokenize', tokenize_no_ssplit=True)
except Exception:
stanza.download(lang=lang, resources_url='stanford')
self.pipeline = stanza.Pipeline(lang=lang, processors='tokenize', tokenize_no_ssplit=True)

def __call__(self, text):
return [i.text for i in self.pipeline(text).sentences[0].tokens]
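
A quick usage sketch for the new class (illustrative, not part of the diff); the first call may trigger a one-off download of stanza's English models:

    # Illustrative sketch; assumes the stanza package and network access for the model download.
    from supar.utils.tokenizer import Tokenizer

    tokenize = Tokenizer(lang='en')
    print(tokenize('I saw Sarah with a telescope.'))
    # ['I', 'saw', 'Sarah', 'with', 'a', 'telescope', '.']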
