
Commit

Integrated tokenizer (#47)
yzhangcs committed Apr 1, 2022
1 parent 6005eca commit 3d4b341
Showing 10 changed files with 107 additions and 69 deletions.
13 changes: 10 additions & 3 deletions setup.py
@@ -4,7 +4,7 @@

setup(
name='supar',
version='1.0.1-a1',
version='1.0.1',
author='Yu Zhang',
author_email='yzhang.cs@outlook.com',
description='Syntactic Parsing Models',
@@ -22,14 +22,21 @@
setup_requires=[
'setuptools>=18.0',
],
install_requires=['torch>=1.7.0', 'transformers>=3.1.0', 'nltk'],
install_requires=[
'torch>=1.7.0',
'transformers>=3.1.0',
'nltk',
'stanza',
'dill'],
entry_points={
'console_scripts': [
'biaffine-dependency=supar.cmds.biaffine_dependency:main',
'crfnp-dependency=supar.cmds.crfnp_dependency:main',
'crf-dependency=supar.cmds.crf_dependency:main',
'crf2o-dependency=supar.cmds.crf2o_dependency:main',
'crf-constituency=supar.cmds.crf_constituency:main'
'crf-constituency=supar.cmds.crf_constituency:main',
'biaffine-semantic-dependency=supar.cmds.biaffine_semantic_dependency:main',
'vi-semantic-dependency=supar.cmds.vi_semantic_dependency:main'
]
},
python_requires='>=3.6',
2 changes: 1 addition & 1 deletion supar/__init__.py
@@ -15,7 +15,7 @@
'VISemanticDependencyParser',
'Parser']

__version__ = '1.0.1-a1'
__version__ = '1.0.1'

PARSER = {parser.NAME: parser for parser in [BiaffineDependencyParser,
CRFNPDependencyParser,
14 changes: 9 additions & 5 deletions supar/parsers/constituency.py
@@ -104,13 +104,17 @@ def evaluate(self, data, buckets=8, batch_size=5000, mbr=True,

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, mbr=True, verbose=True, **kwargs):
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False, mbr=True, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
@@ -241,15 +245,15 @@ def build(cls, path,
if args.feat == 'char':
FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, eos=eos, fix_len=args.fix_len)
elif args.feat == 'bert':
from transformers import AutoTokenizer
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
tokenizer = AutoTokenizer.from_pretrained(args.bert)
FEAT = SubwordField('bert',
pad=tokenizer.pad_token,
unk=tokenizer.unk_token,
bos=tokenizer.cls_token or tokenizer.cls_token,
eos=tokenizer.sep_token or tokenizer.sep_token,
bos=tokenizer.bos_token or tokenizer.cls_token,
fix_len=args.fix_len,
tokenize=tokenizer.tokenize)
tokenize=tokenizer.tokenize,
fn=lambda x: ' '+x if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)) else None)
FEAT.vocab = tokenizer.get_vocab()
else:
FEAT = Field('tags', bos=bos, eos=eos)
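With this change predict can take raw text and tokenize it with stanza according to lang. A minimal usage sketch, assuming a trained CRF constituency model is available (the released 'crf-con-en' checkpoint is used here only for illustration):

from supar import Parser

parser = Parser.load('crf-con-en')

# Raw text: the new stanza-based Tokenizer splits it into words for the given lang.
dataset = parser.predict('I saw Sarah with a telescope.', lang='en', verbose=False)
print(dataset[0])

# Pre-tokenized input still works; pass lang=None to skip tokenization.
dataset = parser.predict([['I', 'saw', 'Sarah', 'with', 'a', 'telescope', '.']],
                         lang=None, verbose=False)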
29 changes: 23 additions & 6 deletions supar/parsers/dependency.py
@@ -102,14 +102,18 @@ def evaluate(self, data, buckets=8, batch_size=5000,

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000,
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000,
prob=False, tree=True, proj=False, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
@@ -248,14 +252,15 @@ def build(cls, path,
if args.feat == 'char':
FEAT = SubwordField('chars', pad=pad, unk=unk, bos=bos, fix_len=args.fix_len)
elif args.feat == 'bert':
from transformers import AutoTokenizer
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
tokenizer = AutoTokenizer.from_pretrained(args.bert)
FEAT = SubwordField('bert',
pad=tokenizer.pad_token,
unk=tokenizer.unk_token,
bos=tokenizer.bos_token or tokenizer.cls_token,
fix_len=args.fix_len,
tokenize=tokenizer.tokenize)
tokenize=tokenizer.tokenize,
fn=lambda x: ' '+x if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)) else None)
FEAT.vocab = tokenizer.get_vocab()
else:
FEAT = Field('tags', bos=bos)
@@ -372,14 +377,18 @@ def evaluate(self, data, buckets=8, batch_size=5000, punct=False,

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False,
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False,
mbr=True, tree=True, proj=False, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
@@ -554,14 +563,18 @@ def evaluate(self, data, buckets=8, batch_size=5000, punct=False,

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False,
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False,
mbr=True, tree=True, proj=True, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
@@ -742,14 +755,18 @@ def evaluate(self, data, buckets=8, batch_size=5000, punct=False,

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False,
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False,
mbr=True, tree=True, proj=True, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
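The new fn hook is there because GPT-2's byte-level BPE folds the preceding space into the token, so a word tokenized in isolation yields different subwords than the same word in running text. A small illustration (assumes transformers can download the gpt2 vocabulary; the exact pieces depend on the vocab):

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Without a leading space the word is treated as if it started the text;
# with one, the pieces carry the 'Ġ' space marker used mid-sentence.
print(tokenizer.tokenize('hello'))    # e.g. ['hello']
print(tokenizer.tokenize(' hello'))   # e.g. ['Ġhello']

# fn=lambda x: ' '+x ... therefore prepends a space to every word before
# SubwordField calls tokenizer.tokenize; for other tokenizers fn stays None.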
7 changes: 4 additions & 3 deletions supar/parsers/parser.py
@@ -3,6 +3,7 @@
import os
from datetime import datetime, timedelta

import dill
import supar
import torch
import torch.distributed as dist
@@ -96,7 +97,7 @@ def evaluate(self, data, buckets=8, batch_size=5000, **kwargs):

return loss, metric

def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, **kwargs):
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, prob=False, **kwargs):
args = self.args.update(locals())
init_logger(logger, verbose=args.verbose)

@@ -105,7 +106,7 @@ def predict(self, data, pred=None, buckets=8, batch_size=5000, prob=False, **kwargs):
self.transform.append(Field('probs'))

logger.info("Loading the data")
dataset = Dataset(self.transform, data)
dataset = Dataset(self.transform, data, lang=lang)
dataset.build(args.batch_size, args.buckets)
logger.info(f"\n{dataset}")

@@ -185,4 +186,4 @@ def save(self, path):
'state_dict': state_dict,
'pretrained': pretrained,
'transform': self.transform}
torch.save(state, path)
torch.save(state, path, pickle_module=dill)
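pickle_module=dill is needed because the transform stored in the checkpoint now holds a SubwordField whose fn is a lambda, which the standard pickle module cannot serialize. A quick sketch of the difference (requires dill):

import pickle

import dill

fn = lambda x: ' ' + x  # the kind of callable now attached to the transform

try:
    pickle.dumps(fn)
except Exception as e:                 # PicklingError: lambdas are not picklable
    print('pickle failed:', type(e).__name__)

blob = dill.dumps(fn)                  # dill serializes lambdas fine
print(repr(dill.loads(blob)('word')))  # ' word'

# torch.save (and torch.load) accept pickle_module, which is how the state
# above gets written with dill instead of the default pickle.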
24 changes: 18 additions & 6 deletions supar/parsers/semantic_dependency.py
@@ -79,13 +79,17 @@ def evaluate(self, data, buckets=8, batch_size=5000, verbose=True, **kwargs):

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, verbose=True, **kwargs):
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
@@ -219,14 +223,16 @@ def build(cls,
if 'lemma' in args.feat:
LEMMA = Field('lemmas', pad=pad, unk=unk, bos=bos, lower=True)
if 'bert' in args.feat:
from transformers import AutoTokenizer
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
tokenizer = AutoTokenizer.from_pretrained(args.bert)
BERT = SubwordField('bert',
pad=tokenizer.pad_token,
unk=tokenizer.unk_token,
bos=tokenizer.bos_token or tokenizer.cls_token,
eos=tokenizer.eos_token or tokenizer.sep_token,
fix_len=args.fix_len,
tokenize=tokenizer.tokenize)
tokenize=tokenizer.tokenize,
fn=lambda x: ' '+x if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)) else None)
BERT.vocab = tokenizer.get_vocab()
EDGE = ChartField('edges', use_vocab=False, fn=CoNLL.get_edges)
LABEL = ChartField('labels', fn=CoNLL.get_labels)
@@ -324,13 +330,17 @@ def evaluate(self, data, buckets=8, batch_size=5000, verbose=True, **kwargs):

return super().evaluate(**Config().update(locals()))

def predict(self, data, pred=None, buckets=8, batch_size=5000, verbose=True, **kwargs):
def predict(self, data, pred=None, lang='en', buckets=8, batch_size=5000, verbose=True, **kwargs):
r"""
Args:
data (list[list] or str):
The data for prediction, both a list of instances and filename are allowed.
pred (str):
If specified, the predicted results will be saved to the file. Default: ``None``.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
buckets (int):
The number of buckets that sentences are assigned to. Default: 32.
batch_size (int):
@@ -465,14 +475,16 @@ def build(cls,
if 'lemma' in args.feat:
LEMMA = Field('lemmas', pad=pad, unk=unk, bos=bos, lower=True)
if 'bert' in args.feat:
from transformers import AutoTokenizer
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
tokenizer = AutoTokenizer.from_pretrained(args.bert)
BERT = SubwordField('bert',
pad=tokenizer.pad_token,
unk=tokenizer.unk_token,
bos=tokenizer.bos_token or tokenizer.cls_token,
eos=tokenizer.eos_token or tokenizer.sep_token,
fix_len=args.fix_len,
tokenize=tokenizer.tokenize)
tokenize=tokenizer.tokenize,
fn=lambda x: ' '+x if isinstance(tokenizer, (GPT2Tokenizer, GPT2TokenizerFast)) else None)
BERT.vocab = tokenizer.get_vocab()
EDGE = ChartField('edges', use_vocab=False, fn=CoNLL.get_edges)
LABEL = ChartField('labels', fn=CoNLL.get_labels)
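The or fallbacks (bos_token or cls_token, eos_token or sep_token) exist because pretrained tokenizers expose different special tokens. A sketch (assumes transformers can download the vocabularies; printed values are the usual ones for these checkpoints):

from transformers import AutoTokenizer

bert = AutoTokenizer.from_pretrained('bert-base-cased')
print(bert.bos_token, bert.cls_token)  # None [CLS] -> SubwordField uses [CLS]
print(bert.eos_token, bert.sep_token)  # None [SEP] -> SubwordField uses [SEP]

gpt2 = AutoTokenizer.from_pretrained('gpt2')
print(gpt2.bos_token, gpt2.cls_token)  # <|endoftext|> None -> bos_token is used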
2 changes: 1 addition & 1 deletion supar/utils/field.py
@@ -319,7 +319,7 @@ def transform(self, sequences):
self.fix_len = max(len(token) for seq in sequences for token in seq)
if self.use_vocab:
sequences = [[[self.vocab[i] if i in self.vocab else self.unk_index for i in token] if token else [self.unk_index]
for token in seq] for seq in sequences]
for token in seq] for seq in sequences]
if self.bos:
sequences = [[[self.bos_index]] + seq for seq in sequences]
if self.eos:
15 changes: 15 additions & 0 deletions supar/utils/tokenizer.py
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-


class Tokenizer:

def __init__(self, lang='en'):
import stanza
try:
self.pipeline = stanza.Pipeline(lang=lang, processors='tokenize', tokenize_no_ssplit=True)
except Exception:
stanza.download(lang=lang, resources_url='stanford')
self.pipeline = stanza.Pipeline(lang=lang, processors='tokenize', tokenize_no_ssplit=True)

def __call__(self, text):
return [i.text for i in self.pipeline(text).sentences[0].tokens]
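A minimal usage sketch of the new class (assumes stanza is installed; the first run may download the English tokenization model):

from supar.utils.tokenizer import Tokenizer

tokenize = Tokenizer(lang='en')
print(tokenize('I saw Sarah with a telescope.'))
# ['I', 'saw', 'Sarah', 'with', 'a', 'telescope', '.']

# tokenize_no_ssplit=True treats the whole input as a single sentence, so the
# call returns the tokens of that one sentence only.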
30 changes: 24 additions & 6 deletions supar/utils/transform.py
@@ -1,9 +1,11 @@
# -*- coding: utf-8 -*-

import os
from collections.abc import Iterable

import nltk
from supar.utils.logging import get_logger, progress_bar
from supar.utils.tokenizer import Tokenizer

logger = get_logger(__name__)

@@ -343,14 +345,18 @@ def istree(cls, sequence, proj=False, multiroot=False):
return False
return next(tarjan(sequence), None) is None

def load(self, data, proj=False, max_len=None, **kwargs):
def load(self, data, lang='en', proj=False, max_len=None, **kwargs):
r"""
Loads the data in CoNLL-X format.
Also supports for loading data from CoNLL-U file with comments and non-integer IDs.
Args:
data (list[list] or str):
A list of instances or a filename.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
proj (bool):
If ``True``, discards all non-projective sentences. Default: ``False``.
max_len (int):
@@ -360,11 +366,15 @@ def load(self, data, proj=False, max_len=None, **kwargs):
A list of :class:`CoNLLSentence` instances.
"""

if isinstance(data, str):
if isinstance(data, str) and os.path.exists(data):
with open(data, 'r') as f:
lines = [line.strip() for line in f]
else:
data = [data] if isinstance(data[0], str) else data
if lang is not None:
tokenizer = Tokenizer(lang)
data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)]
else:
data = [data] if isinstance(data[0], str) else data
lines = '\n'.join([self.toconll(i) for i in data]).split('\n')

i, start, sentences = 0, 0, []
@@ -680,23 +690,31 @@ def track(node):
return [tree]
return nltk.Tree(root, track(iter(sequence)))

def load(self, data, max_len=None, **kwargs):
def load(self, data, lang='en', max_len=None, **kwargs):
r"""
Args:
data (list[list] or str):
A list of instances or a filename.
lang (str):
Language code (e.g., 'en') or language name (e.g., 'English') for the text to tokenize.
``None`` if tokenization is not required.
Default: ``en``.
max_len (int):
Sentences exceeding the length will be discarded. Default: ``None``.
Returns:
A list of :class:`TreeSentence` instances.
"""
if isinstance(data, str):
if isinstance(data, str) and os.path.exists(data):
with open(data, 'r') as f:
trees = [nltk.Tree.fromstring(string) for string in f]
self.root = trees[0].label()
else:
data = [data] if isinstance(data[0], str) else data
if lang is not None:
tokenizer = Tokenizer(lang)
data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)]
else:
data = [data] if isinstance(data[0], str) else data
trees = [self.totree(i, self.root) for i in data]

i, sentences = 0, []
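After this change load dispatches on the input: an existing path is read as an annotated file, a raw string (or list of strings) is tokenized with stanza when lang is given, and a pre-tokenized word list is used as-is. In outline (transform stands for the CoNLL or Tree transform of a trained parser; the file path is hypothetical):

transform.load('data/ptb/test.conllx')                   # existing file: parsed directly
transform.load('She enjoys playing tennis.', lang='en')  # raw text: stanza tokenization
transform.load(['She', 'enjoys', 'playing', 'tennis', '.'],
               lang=None)                                # already tokenized: wrapped as one sentence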