This repository was archived by the owner on Apr 9, 2022. It is now read-only.
@@ -336,9 +336,9 @@ def _get_checklist_info(
         target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
         if self._penalize_non_agenda_actions:
             # All terminal actions are relevant
-            checklist_mask = torch.ones_like(target_checklist)
+            checklist_mask = torch.ones_like(target_checklist, dtype=torch.bool)
         else:
-            checklist_mask = (target_checklist != 0).float()
+            checklist_mask = target_checklist != 0
         return target_checklist, terminal_actions, checklist_mask

     def _update_metrics(
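Note: a minimal sketch (not repository code, with a made-up target_checklist) of what the two new mask expressions produce. Both branches now yield a torch.bool tensor, and PyTorch's type promotion lets a boolean mask multiply float tensors directly, so the old .float() casts are unnecessary:

    import torch

    target_checklist = torch.tensor([[1.0], [0.0], [1.0]])

    # Penalizing non-agenda actions: every terminal action is relevant.
    mask_all = torch.ones_like(target_checklist, dtype=torch.bool)

    # Otherwise, only actions that actually appear on the agenda are relevant.
    mask_agenda = target_checklist != 0

    assert mask_all.dtype == mask_agenda.dtype == torch.bool
    print(target_checklist * mask_agenda)  # tensor([[1.], [0.], [1.]]) -- promoted to float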
allennlp_semparse/models/text2sql_parser.py (4 additions & 1 deletion)

@@ -227,7 +227,10 @@ def forward(
         return outputs

     def _get_initial_state(
-        self, encoder_outputs: torch.Tensor, mask: torch.Tensor, actions: List[List[ProductionRule]]
+        self,
+        encoder_outputs: torch.Tensor,
+        mask: torch.BoolTensor,
+        actions: List[List[ProductionRule]]
     ) -> GrammarBasedState:

         batch_size = encoder_outputs.size(0)
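For context, a sketch (assuming AllenNLP 1.0 or later, with made-up token ids): the mask passed into _get_initial_state typically comes from allennlp.nn.util.get_text_field_mask, which returns a torch.BoolTensor, hence the updated annotation:

    import torch
    from allennlp.nn import util

    token_ids = torch.tensor([[5, 7, 9, 0], [3, 0, 0, 0]])  # 0 is padding
    mask = util.get_text_field_mask({"tokens": {"tokens": token_ids}})
    print(mask.dtype)  # torch.bool
    print(mask)        # [[True, True, True, False], [True, False, False, False]]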
@@ -395,7 +395,7 @@ def _get_checklist_info(
         terminal_actions = agenda.new_tensor(terminal_indices)
         # (max_num_terminals, 1)
         target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
-        checklist_mask = (target_checklist != 0).float()
+        checklist_mask = target_checklist != 0
         return target_checklist, terminal_actions, checklist_mask

     def _get_state_cost(
@@ -203,7 +203,7 @@ def _get_initial_rnn_and_grammar_state(

         neighbor_mask = util.get_text_field_mask(
             {"ignored": {"ignored": neighbor_indices + 1}}, num_wrapping_dims=1
-        ).float()
+        )

         # Encoder initialized to easily obtain a masked average.
         neighbor_encoder = TimeDistributed(
@@ -435,7 +435,7 @@ def _get_linking_probabilities(
         self,
         worlds: List[WikiTablesLanguage],
         linking_scores: torch.FloatTensor,
-        question_mask: torch.LongTensor,
+        question_mask: torch.BoolTensor,
         entity_type_dict: Dict[int, int],
     ) -> torch.FloatTensor:
         """
@@ -448,7 +448,7 @@ def _get_linking_probabilities(
         worlds : ``List[WikiTablesLanguage]``
         linking_scores : ``torch.FloatTensor``
             Has shape (batch_size, num_question_tokens, num_entities).
-        question_mask: ``torch.LongTensor``
+        question_mask: ``torch.BoolTensor``
             Has shape (batch_size, num_question_tokens).
         entity_type_dict : ``Dict[int, int]``
             This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.
@@ -514,7 +514,7 @@ def _get_linking_probabilities(
             probabilities = torch.cat(all_probabilities, dim=1)
             batch_probabilities.append(probabilities)
         batch_probabilities = torch.stack(batch_probabilities, dim=0)
-        return batch_probabilities * question_mask.unsqueeze(-1).float()
+        return batch_probabilities * question_mask.unsqueeze(-1)

     @staticmethod
     def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
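A small sanity check (illustrative shapes only, not repository code) that the final masking still works without the .float() cast, because a boolean question mask broadcasts against the float probabilities and is promoted during multiplication:

    import torch

    batch_probabilities = torch.rand(2, 2, 6)                    # (batch_size, num_question_tokens, num_entities)
    question_mask = torch.tensor([[True, True], [True, False]])  # (batch_size, num_question_tokens)

    masked = batch_probabilities * question_mask.unsqueeze(-1)   # bool mask is promoted to float
    assert masked.dtype == torch.float32
    assert masked[1, 1].abs().sum() == 0                         # the padded question token is zeroed out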
allennlp_semparse/state_machines/constrained_beam_search.py (2 additions & 2 deletions)

@@ -37,7 +37,7 @@ class ConstrainedBeamSearch:
         A ``(batch_size, num_sequences, sequence_length)`` tensor containing the transition
         sequences that we will search in. The values in this tensor must match whatever the
         ``State`` keeps in its ``action_history`` variable (typically this is action indices).
-    allowed_sequence_mask : ``torch.Tensor``
+    allowed_sequence_mask : ``torch.BoolTensor``
         A ``(batch_size, num_sequences, sequence_length)`` tensor indicating whether each entry in
         the ``allowed_sequences`` tensor is padding. The allowed sequences could be padded both on
         the ``num_sequences`` dimension and the ``sequence_length`` dimension.
@@ -53,7 +53,7 @@ def __init__(
         self,
         beam_size: Optional[int],
         allowed_sequences: torch.Tensor,
-        allowed_sequence_mask: torch.Tensor,
+        allowed_sequence_mask: torch.BoolTensor,
         per_node_beam_size: int = None,
     ) -> None:
         self._beam_size = beam_size
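A minimal construction sketch with hypothetical action indices (running an actual search would additionally need an initial state and a transition function); the mask is now built directly as a boolean tensor:

    import torch
    from allennlp_semparse.state_machines.constrained_beam_search import ConstrainedBeamSearch

    # One batch instance with two allowed action sequences, padded on the sequence_length dimension.
    allowed_sequences = torch.tensor([[[2, 3, 4, 0], [2, 5, 0, 0]]])
    allowed_sequence_mask = torch.tensor(
        [[[True, True, True, False], [True, True, False, False]]]
    )

    search = ConstrainedBeamSearch(
        beam_size=None,  # no beam cutoff; keep everything that stays inside the allowed sequences
        allowed_sequences=allowed_sequences,
        allowed_sequence_mask=allowed_sequence_mask,
    )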
allennlp_semparse/state_machines/states/checklist_statelet.py (2 additions & 2 deletions)

@@ -23,7 +23,7 @@ class ChecklistStatelet:
         ideally be. It is the same size as ``terminal_actions``, and it contains 1 for each corresponding
         action in the list that we want to see in the final logical form, and 0 for each corresponding
         action that we do not.
-    checklist_mask : ``torch.Tensor``
+    checklist_mask : ``torch.BoolTensor``
         Mask corresponding to ``terminal_actions``, indicating which of those actions are relevant
         for checklist computation. For example, if the parser is penalizing non-agenda terminal
         actions, all the terminal actions are relevant.
@@ -39,7 +39,7 @@ def __init__(
         self,
         terminal_actions: torch.Tensor,
         checklist_target: torch.Tensor,
-        checklist_mask: torch.Tensor,
+        checklist_mask: torch.BoolTensor,
         checklist: torch.Tensor,
         terminal_indices_dict: Dict[int, int] = None,
     ) -> None:
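A minimal construction sketch (the values mirror the small NLVR test further down; not repository code). With the new annotation, checklist_mask is produced directly as a boolean tensor:

    import torch
    from allennlp_semparse.state_machines.states.checklist_statelet import ChecklistStatelet

    terminal_actions = torch.tensor([[0], [2], [4]])
    checklist_target = torch.tensor([[1.0], [0.0], [1.0]])  # 1 for agenda actions, 0 otherwise
    checklist_mask = checklist_target != 0                   # torch.bool
    checklist = torch.zeros_like(checklist_target)           # nothing has been produced yet

    statelet = ChecklistStatelet(
        terminal_actions=terminal_actions,
        checklist_target=checklist_target,
        checklist_mask=checklist_mask,
        checklist=checklist,
    )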
allennlp_semparse/state_machines/states/rnn_statelet.py (2 additions & 2 deletions)

@@ -40,7 +40,7 @@ class RnnStatelet:
         mask unmodified, regardless of what's in the grouping for this state. We'll use the
         ``batch_indices`` for the group to pull pieces out of these lists when we're ready to
         actually do some computation.
-    encoder_output_mask : ``List[torch.Tensor]``
+    encoder_output_mask : ``List[torch.BoolTensor]``
         A list of variables, each of shape ``(input_sequence_length,)``, containing a mask over
         question tokens for each batch instance. This is a list over batch elements, for the same
         reasons as above.
@@ -53,7 +53,7 @@ def __init__(
         previous_action_embedding: torch.Tensor,
         attended_input: torch.Tensor,
         encoder_outputs: List[torch.Tensor],
-        encoder_output_mask: List[torch.Tensor],
+        encoder_output_mask: List[torch.BoolTensor],
     ) -> None:
         self.hidden_state = hidden_state
         self.memory_cell = memory_cell
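For illustration (made-up shapes, not repository code): the per-instance lists that RnnStatelet stores are typically just the batched encoder outputs and the boolean mask split along the batch dimension:

    import torch

    encoder_outputs = torch.rand(2, 3, 4)  # (batch_size, seq_len, encoder_dim)
    encoder_output_mask = torch.tensor([[True, True, True], [True, True, False]])

    # One entry per batch instance, so grouped states can later mix instances freely.
    outputs_list = [encoder_outputs[i] for i in range(encoder_outputs.size(0))]
    mask_list = [encoder_output_mask[i] for i in range(encoder_output_mask.size(0))]
    assert mask_list[1].dtype == torch.bool and mask_list[1].shape == (3,)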
@@ -339,7 +339,10 @@ def make_state(
         return new_states

     def attend_on_question(
-        self, query: torch.Tensor, encoder_outputs: torch.Tensor, encoder_output_mask: torch.Tensor
+        self,
+        query: torch.Tensor,
+        encoder_outputs: torch.Tensor,
+        encoder_output_mask: torch.BoolTensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Given a query (which is typically the decoder hidden state), compute an attention over the
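Not the transition function's actual implementation (it delegates to its configured attention module), but a bare dot-product sketch of how a boolean encoder_output_mask is typically applied when attending over the question:

    import torch

    def attend_on_question_sketch(query, encoder_outputs, encoder_output_mask):
        # query: (group_size, dim); encoder_outputs: (group_size, seq_len, dim);
        # encoder_output_mask: (group_size, seq_len), boolean.
        scores = torch.bmm(encoder_outputs, query.unsqueeze(-1)).squeeze(-1)
        scores = scores.masked_fill(~encoder_output_mask, float("-inf"))  # ignore padded tokens
        weights = torch.softmax(scores, dim=-1)
        attended = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return attended, weights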
allennlp_semparse/state_machines/util.py (2 additions & 2 deletions)

@@ -6,7 +6,7 @@

 def construct_prefix_tree(
     targets: Union[torch.Tensor, List[List[List[int]]]],
-    target_mask: Optional[Union[torch.Tensor, List[List[List[int]]]]] = None,
+    target_mask: Optional[Union[torch.BoolTensor, List[List[List[int]]]]] = None,
 ) -> List[Dict[Tuple[int, ...], Set[int]]]:
     """
     Takes a list of valid target action sequences and creates a mapping from all possible
@@ -33,7 +33,7 @@ def construct_prefix_tree(
         targets = targets.detach().cpu().numpy().tolist()
     if target_mask is not None:
         if not isinstance(target_mask, list):
-            target_mask = target_mask.detach().cpu().numpy().tolist()
+            target_mask = target_mask.detach().long().cpu().numpy().tolist()
     else:
         target_mask = [None for _ in targets]
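A small usage sketch; the expected mapping is worked out by hand from these toy sequences (the util_test below exercises the same function with a padded second instance):

    import torch
    from allennlp_semparse.state_machines import util

    targets = torch.Tensor([[[2, 3, 4], [1, 3, 4], [1, 2, 4]]])
    target_mask = torch.tensor([[[True, True, True], [True, True, True], [True, True, True]]])

    prefix_tree = util.construct_prefix_tree(targets, target_mask)
    # prefix_tree[0] maps every observed prefix to the set of actions allowed next:
    assert prefix_tree[0][()] == {1, 2}
    assert prefix_tree[0][(1,)] == {2, 3}
    assert prefix_tree[0][(2, 3)] == {4}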
tests/models/nlvr/nlvr_coverage_semantic_parser_test.py (1 addition & 1 deletion)

@@ -46,7 +46,7 @@ def test_get_checklist_info(self):
         target_checklist, terminal_actions, checklist_mask = checklist_info
         assert_almost_equal(target_checklist.data.numpy(), [[1], [0], [1]])
         assert_almost_equal(terminal_actions.data.numpy(), [[0], [2], [4]])
-        assert_almost_equal(checklist_mask.data.numpy(), [[1], [1], [1]])
+        assert_almost_equal(checklist_mask.long().data.numpy(), [[1], [1], [1]])

     def test_initialize_weights_from_archive(self):
         original_model_parameters = self.model.named_parameters()
@@ -57,7 +57,7 @@ def test_get_linking_probabilities(self):
             [[0, 1, 8, 10, 10, 4], [3, 2, -1, -2, 1, -6]],
         ]
         linking_scores = torch.FloatTensor(linking_scores)
-        question_mask = torch.LongTensor([[1, 1], [1, 0]])
+        question_mask = torch.tensor([[True, True], [True, False]])
         _, entity_type_dict = self.model._get_type_vector(worlds, num_entities, linking_scores)

         # (batch_size, num_question_tokens, num_entities)
tests/state_machines/constrained_beam_search_test.py (7 additions & 7 deletions)

@@ -27,15 +27,15 @@ def test_search(self):
                 ]
             ]
         )
-        mask = torch.Tensor(
+        mask = torch.tensor(
             [
                 [
-                    [1, 1, 1, 1, 1, 1, 1],
-                    [1, 1, 1, 1, 0, 0, 0],
-                    [1, 1, 1, 1, 0, 0, 0],
-                    [1, 1, 1, 1, 1, 1, 0],
-                    [1, 1, 1, 1, 1, 1, 0],
-                    [1, 1, 1, 1, 1, 0, 0],
+                    [True, True, True, True, True, True, True],
+                    [True, True, True, True, False, False, False],
+                    [True, True, True, True, False, False, False],
+                    [True, True, True, True, True, True, False],
+                    [True, True, True, True, True, True, False],
+                    [True, True, True, True, True, False, False],
                 ]
             ]
         )
@@ -19,8 +19,11 @@ def setup_method(self):
         self.targets = torch.Tensor(
             [[[2, 3, 4], [1, 3, 4], [1, 2, 4]], [[3, 4, 0], [2, 3, 4], [0, 0, 0]]]
         )
-        self.target_mask = torch.Tensor(
-            [[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 0], [1, 1, 1], [0, 0, 0]]]
+        self.target_mask = torch.tensor(
+            [
+                [[True, True, True], [True, True, True], [True, True, True]],
+                [[True, True, False], [True, True, True], [False, False, False]],
+            ]
         )

         self.supervision = (self.targets, self.target_mask)
@@ -59,7 +59,7 @@ def setup_method(self):
         self.encoder_outputs = torch.FloatTensor(
             [[[1, 2], [3, 4], [5, 6]], [[10, 11], [12, 13], [14, 15]]]
         )
-        self.encoder_output_mask = torch.FloatTensor([[1, 1, 1], [1, 1, 0]])
+        self.encoder_output_mask = torch.tensor([[True, True, True], [True, True, False]])
         self.possible_actions = [
             [
                 ("e -> f", False, None),
tests/state_machines/util_test.py (5 additions & 2 deletions)

@@ -10,8 +10,11 @@ def test_create_allowed_transitions(self):
         targets = torch.Tensor(
             [[[2, 3, 4], [1, 3, 4], [1, 2, 4]], [[3, 4, 0], [2, 3, 4], [0, 0, 0]]]
         )
-        target_mask = torch.Tensor(
-            [[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 0], [1, 1, 1], [0, 0, 0]]]
+        target_mask = torch.tensor(
+            [
+                [[True, True, True], [True, True, True], [True, True, True]],
+                [[True, True, False], [True, True, True], [False, False, False]],
+            ]
         )
         prefix_tree = util.construct_prefix_tree(targets, target_mask)