Skip to content

Commit

Permalink
Allow building one field from several datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed Aug 31, 2023
1 parent 06ea307 commit 831df04
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions supar/utils/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,18 +177,18 @@ def preprocess(self, data: Union[str, Iterable]) -> Iterable:

def build(
self,
dataset: Dataset,
data: Union[Dataset, Iterable[Dataset]],
min_freq: int = 1,
embed: Optional[Embedding] = None,
norm: Callable = None
) -> Field:
r"""
Constructs a :class:`~supar.utils.vocab.Vocab` object for this field from the dataset.
Constructs a :class:`~supar.utils.vocab.Vocab` object for this field from one or more datasets.
If the vocabulary already exists, this function has no effect.
Args:
dataset (Dataset):
A :class:`~supar.utils.data.Dataset` object.
data (Union[Dataset, Iterable[Dataset]]):
One or more :class:`~supar.utils.data.Dataset` objects.
One of the attributes should be named after the name of this field.
min_freq (int):
The minimum frequency needed to include a token in the vocabulary. Default: 1.
Expand All @@ -202,14 +202,18 @@ def build(
return

@wait
def build_vocab(dataset):
def build_vocab(data):
return Vocab(counter=Counter(token
for seq in progress_bar(getattr(dataset, self.name))
for seq in progress_bar(getattr(data, self.name))
for token in self.preprocess(seq)),
min_freq=min_freq,
specials=self.specials,
unk_index=self.unk_index)
self.vocab = build_vocab(dataset)
if isinstance(data, Dataset):
data = [data]
self.vocab = build_vocab(data[0])
for i in data[1:]:
self.vocab.update(build_vocab(i))

if not embed:
self.embed = None
Expand Down Expand Up @@ -305,7 +309,7 @@ def __init__(self, *args, **kwargs):

def build(
self,
dataset: Dataset,
data: Union[Dataset, Iterable[Dataset]],
min_freq: int = 1,
embed: Optional[Embedding] = None,
norm: Callable = None
Expand All @@ -314,15 +318,19 @@ def build(
return

@wait
def build_vocab(dataset):
def build_vocab(data):
return Vocab(counter=Counter(piece
for seq in progress_bar(getattr(dataset, self.name))
for seq in progress_bar(getattr(data, self.name))
for token in seq
for piece in self.preprocess(token)),
min_freq=min_freq,
specials=self.specials,
unk_index=self.unk_index)
self.vocab = build_vocab(dataset)
if isinstance(data, Dataset):
data = [data]
self.vocab = build_vocab(data[0])
for i in data[1:]:
self.vocab.update(build_vocab(i))

if not embed:
self.embed = None
Expand Down Expand Up @@ -377,19 +385,23 @@ class ChartField(Field):

def build(
self,
dataset: Dataset,
data: Union[Dataset, Iterable[Dataset]],
min_freq: int = 1
) -> ChartField:
@wait
def build_vocab(dataset):
def build_vocab(data):
return Vocab(counter=Counter(i
for chart in progress_bar(getattr(dataset, self.name))
for chart in progress_bar(getattr(data, self.name))
for row in self.preprocess(chart)
for i in row if i is not None),
min_freq=min_freq,
specials=self.specials,
unk_index=self.unk_index)
self.vocab = build_vocab(dataset)
if isinstance(data, Dataset):
data = [data]
self.vocab = build_vocab(data[0])
for i in data[1:]:
self.vocab.update(build_vocab(i))
return self

def transform(self, charts: Iterable[List[List]]) -> Iterable[torch.Tensor]:
Expand Down

0 comments on commit 831df04

Please sign in to comment.