This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Add flag to control merging of source/target vocab
Summary:
Add flag to control merging of source/target vocab
  - We may not want to merge the source/target vocab for all use cases.

Reviewed By: Peyman-Heidari

Differential Revision: D37667016

fbshipit-source-id: 9e95c55c7e76ee11e0d5bf64d2102dde19adbf4f
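
For context, a minimal usage sketch of the new flag. The class name DecoupledSeq2SeqData and the idea of setting both flags directly on the config are assumptions for illustration, not taken from this commit:

    # Hypothetical sketch: merge_vocab=False skips the source/target vocab
    # merge entirely, so merge_source_vocab has no effect in that case.
    data_config = DecoupledSeq2SeqData.Config(
        merge_vocab=False,
        merge_source_vocab=False,
    )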
Shashank Jain authored and facebook-github-bot committed Jul 7, 2022
1 parent 6de2829 commit 019355f
Showing 1 changed file with 31 additions and 23 deletions.

pytext/data/decoupled_data.py
@@ -140,6 +140,7 @@ class Config(Data.Config):
         # For cloze-style parsing, ontology tokens appear in the source sequence, thus
         # this controls whether the source tensorizer will receive the merged vocab.
         merge_source_vocab: bool = False
+        merge_vocab: bool = True
 
     @classmethod
     def from_config(
@@ -164,6 +165,7 @@ def from_config(
             noisy_decoupling=config.noisy_decoupling,
             filter_target_ood_slots=config.filter_target_ood_slots,
             merge_source_vocab=config.merge_source_vocab,
+            merge_vocab=config.merge_vocab,
             unk_token=config.unk_token,
             pad_token=config.pad_token,
             bos_token=config.bos_token,
@@ -204,6 +206,7 @@ def __init__(
         noisy_decoupling: bool = False,
         filter_target_ood_slots: bool = True,
         merge_source_vocab: bool = False,
+        merge_vocab: bool = True,
         unk_token: str = Config.unk_token,
         pad_token: str = Config.pad_token,
         bos_token: str = Config.bos_token,
@@ -221,6 +224,8 @@ def __init__(
         )
         self.filter_target_ood_slots = filter_target_ood_slots
         self.merge_source_vocab = merge_source_vocab
+        self.merge_vocab = merge_vocab
+
         if decoupled_source and noisy_decoupling:
             self.decoupled_func_source = get_noisy_decoupled
         elif decoupled_source:
@@ -251,31 +256,34 @@ def __init__(
 
         # Merge source and target vocabs, keeping them aligned. This is required
         # by the implementation of the pointer mechanism in the model's decoder.
-        src_vocab = self.tensorizers["src_seq_tokens"].vocab
-        trg_vocab = self.tensorizers["trg_seq_tokens"].vocab
-        tokens_not_in_src = set(trg_vocab._vocab).difference(set(src_vocab._vocab))
-        merged_tokens = src_vocab._vocab.copy() + [
-            w for w in trg_vocab._vocab if w in tokens_not_in_src
-        ]  # Order stays consistent with trg vocab. No randomness from set.
-        merged_vocab = Vocabulary(
-            vocab_list=merged_tokens,
-            replacements=None,
-            unk_token=unk_token,
-            pad_token=pad_token,
-            bos_token=bos_token,
-            eos_token=eos_token,
-            mask_token=mask_token,
-        )
-        print(f"Source vocab: {len(src_vocab)} entries.")
-        print(f"Target vocab: {len(trg_vocab)} entries.")
-        print(f"Merged vocab: {len(merged_vocab)} entries.")
+        if merge_vocab:
+            src_vocab = self.tensorizers["src_seq_tokens"].vocab
+            trg_vocab = self.tensorizers["trg_seq_tokens"].vocab
+            tokens_not_in_src = set(trg_vocab._vocab).difference(
+                set(src_vocab._vocab)
+            )
+            merged_tokens = src_vocab._vocab.copy() + [
+                w for w in trg_vocab._vocab if w in tokens_not_in_src
+            ]  # Order stays consistent with trg vocab. No randomness from set.
+            merged_vocab = Vocabulary(
+                vocab_list=merged_tokens,
+                replacements=None,
+                unk_token=unk_token,
+                pad_token=pad_token,
+                bos_token=bos_token,
+                eos_token=eos_token,
+                mask_token=mask_token,
+            )
+            print(f"Source vocab: {len(src_vocab)} entries.")
+            print(f"Target vocab: {len(trg_vocab)} entries.")
+            print(f"Merged vocab: {len(merged_vocab)} entries.")
 
-        if self.merge_source_vocab:
-            self.tensorizers["src_seq_tokens"].vocab = merged_vocab
-            print("\tInitialized source tensorizer with merged vocab.")
+            if self.merge_source_vocab:
+                self.tensorizers["src_seq_tokens"].vocab = merged_vocab
+                print("\tInitialized source tensorizer with merged vocab.")
 
-        self.tensorizers["trg_seq_tokens"].vocab = merged_vocab
-        print("\tInitialized target tensorizer with merged vocab.")
+            self.tensorizers["trg_seq_tokens"].vocab = merged_vocab
+            print("\tInitialized target tensorizer with merged vocab.")
 
     def numberize_rows(self, rows):
         source_column = getattr(self.tensorizers["src_seq_tokens"], "text_column", None)
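
The merge above is deliberately order-preserving: source tokens keep their original indices, and only target-only tokens are appended, in target-vocab order, which is what keeps pointer indices valid across both vocabularies. A self-contained sketch of the same idea, with plain lists standing in for PyText's Vocabulary (token values here are made up for illustration):

    def merge_vocabs(src_tokens, trg_tokens):
        # Append target tokens missing from the source, preserving target order.
        # The set is used only for membership tests, so the output is deterministic.
        src_set = set(src_tokens)
        return list(src_tokens) + [w for w in trg_tokens if w not in src_set]

    src = ["<unk>", "<pad>", "play", "music"]
    trg = ["<unk>", "<pad>", "[IN:PLAY_MUSIC", "play", "]"]
    merged = merge_vocabs(src, trg)
    # merged == ["<unk>", "<pad>", "play", "music", "[IN:PLAY_MUSIC", "]"]
    # Source indices are unchanged; only the target-only tokens are appended.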
