This repository was archived by the owner on Apr 9, 2022. It is now read-only.

Commit cb2d2fd

Change mask dtype to bool

1 parent: 339e617
17 files changed: +48 -36 lines
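
Across these files the pattern is the same: drop the .float() cast on masks coming out of util.get_text_field_mask (which, per this change, is treated as returning a boolean tensor) and tighten type annotations from torch.Tensor / torch.LongTensor to torch.BoolTensor. A minimal before/after sketch in plain PyTorch (the toy token ids and the padding id 0 are assumptions for illustration, not AllenNLP's TextField API):

import torch

token_ids = torch.tensor([[5, 3, 9, 0, 0],
                          [7, 2, 0, 0, 0]])   # 0 = padding (assumed)

# Before this commit: the mask was cast to float immediately.
float_mask = (token_ids != 0).float()         # dtype=torch.float32

# After: the comparison result is kept as a bool tensor.
bool_mask = token_ids != 0                    # dtype=torch.bool

# Bool masks feed directly into logical ops and masked_fill.
scores = torch.randn(2, 5)
masked_scores = scores.masked_fill(~bool_mask, float("-inf"))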

allennlp_semparse/models/atis/atis_semantic_parser.py
Lines changed: 1 addition & 1 deletion

@@ -280,7 +280,7 @@ def _get_initial_state(
         linking_scores: torch.Tensor,
     ) -> GrammarBasedState:
         embedded_utterance = self._utterance_embedder(utterance)
-        utterance_mask = util.get_text_field_mask(utterance).float()
+        utterance_mask = util.get_text_field_mask(utterance)

         batch_size = embedded_utterance.size(0)
         num_entities = max([len(world.entities) for world in worlds])

allennlp_semparse/models/nlvr/nlvr_coverage_semantic_parser.py
Lines changed: 2 additions & 2 deletions

@@ -344,9 +344,9 @@ def _get_checklist_info(
         target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
         if self._penalize_non_agenda_actions:
             # All terminal actions are relevant
-            checklist_mask = torch.ones_like(target_checklist)
+            checklist_mask = torch.ones_like(target_checklist, dtype=torch.bool)
         else:
-            checklist_mask = (target_checklist != 0).float()
+            checklist_mask = target_checklist != 0
         return target_checklist, terminal_actions, checklist_mask

     def _update_metrics(
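
Both branches of this hunk now produce a bool mask. A toy sketch of the two constructions (the checklist values are made up; only the shape convention follows the parser):

import torch

# One float count per terminal action; shape (num_terminals, 1).
target_checklist = torch.tensor([[1.0], [0.0], [2.0]])

# _penalize_non_agenda_actions=True: every terminal action is relevant.
checklist_mask = torch.ones_like(target_checklist, dtype=torch.bool)

# Otherwise: only actions with a nonzero target count are relevant.
checklist_mask = target_checklist != 0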

allennlp_semparse/models/nlvr/nlvr_semantic_parser.py
Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ def forward(self):  # type: ignore
     def _get_initial_rnn_state(self, sentence: Dict[str, torch.LongTensor]):
         embedded_input = self._sentence_embedder(sentence)
         # (batch_size, sentence_length)
-        sentence_mask = util.get_text_field_mask(sentence).float()
+        sentence_mask = util.get_text_field_mask(sentence)

         batch_size = embedded_input.size(0)
allennlp_semparse/models/text2sql_parser.py
Lines changed: 5 additions & 2 deletions

@@ -143,7 +143,7 @@ def forward(
             trailing dimension.
         """
         embedded_utterance = self._utterance_embedder(tokens)
-        mask = util.get_text_field_mask(tokens).float()
+        mask = util.get_text_field_mask(tokens)
         batch_size = embedded_utterance.size(0)

         # (batch_size, num_tokens, encoder_output_dim)

@@ -227,7 +227,10 @@ def forward(
         return outputs

     def _get_initial_state(
-        self, encoder_outputs: torch.Tensor, mask: torch.Tensor, actions: List[List[ProductionRule]]
+        self,
+        encoder_outputs: torch.Tensor,
+        mask: torch.BoolTensor,
+        actions: List[List[ProductionRule]]
     ) -> GrammarBasedState:

         batch_size = encoder_outputs.size(0)
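
Worth noting: a torch.BoolTensor annotation like the one added here documents intent for readers and static checkers; PyTorch does not enforce the dtype at call time. A hypothetical helper (not part of this parser) showing how a mask annotated this way is typically consumed:

import torch

def masked_mean(values: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
    # values: (batch, seq, dim); mask: (batch, seq).
    # Bool promotes to float under multiplication, so no explicit cast is needed.
    totals = (values * mask.unsqueeze(-1)).sum(dim=1)
    counts = mask.sum(dim=1, keepdim=True).clamp(min=1)  # avoid division by zero
    return totals / counts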

allennlp_semparse/models/wikitables/wikitables_erm_semantic_parser.py
Lines changed: 1 addition & 1 deletion

@@ -395,7 +395,7 @@ def _get_checklist_info(
         terminal_actions = agenda.new_tensor(terminal_indices)
         # (max_num_terminals, 1)
         target_checklist = agenda.new_tensor(target_checklist_list, dtype=torch.float)
-        checklist_mask = (target_checklist != 0).float()
+        checklist_mask = target_checklist != 0
         return target_checklist, terminal_actions, checklist_mask

     def _get_state_cost(

allennlp_semparse/models/wikitables/wikitables_semantic_parser.py
Lines changed: 6 additions & 6 deletions

@@ -169,10 +169,10 @@ def _get_initial_rnn_and_grammar_state(
         table_text = table["text"]
         # (batch_size, question_length, embedding_dim)
         embedded_question = self._question_embedder(question)
-        question_mask = util.get_text_field_mask(question).float()
+        question_mask = util.get_text_field_mask(question)
         # (batch_size, num_entities, num_entity_tokens, embedding_dim)
         embedded_table = self._question_embedder(table_text, num_wrapping_dims=1)
-        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1).float()
+        table_mask = util.get_text_field_mask(table_text, num_wrapping_dims=1)

         batch_size, num_entities, num_entity_tokens, _ = embedded_table.size()
         num_question_tokens = embedded_question.size(1)

@@ -203,7 +203,7 @@ def _get_initial_rnn_and_grammar_state(

         neighbor_mask = util.get_text_field_mask(
             {"ignored": {"ignored": neighbor_indices + 1}}, num_wrapping_dims=1
-        ).float()
+        )

         # Encoder initialized to easily obtain a masked average.
         neighbor_encoder = TimeDistributed(

@@ -435,7 +435,7 @@ def _get_linking_probabilities(
         self,
         worlds: List[WikiTablesLanguage],
         linking_scores: torch.FloatTensor,
-        question_mask: torch.LongTensor,
+        question_mask: torch.BoolTensor,
         entity_type_dict: Dict[int, int],
     ) -> torch.FloatTensor:
         """

@@ -448,7 +448,7 @@ def _get_linking_probabilities(
         worlds : ``List[WikiTablesLanguage]``
         linking_scores : ``torch.FloatTensor``
             Has shape (batch_size, num_question_tokens, num_entities).
-        question_mask: ``torch.LongTensor``
+        question_mask: ``torch.BoolTensor``
             Has shape (batch_size, num_question_tokens).
         entity_type_dict : ``Dict[int, int]``
             This is a mapping from ((batch_index * num_entities) + entity_index) to entity type id.

@@ -514,7 +514,7 @@ def _get_linking_probabilities(
             probabilities = torch.cat(all_probabilities, dim=1)
             batch_probabilities.append(probabilities)
         batch_probabilities = torch.stack(batch_probabilities, dim=0)
-        return batch_probabilities * question_mask.unsqueeze(-1).float()
+        return batch_probabilities * question_mask.unsqueeze(-1)

     @staticmethod
     def _action_history_match(predicted: List[int], targets: torch.LongTensor) -> int:
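
The last hunk works because of PyTorch's type promotion: multiplying a float tensor by a bool mask upcasts the bool operand, so the explicit .float() cast was redundant. A quick check (shapes are illustrative):

import torch

probs = torch.rand(2, 4, 3)                  # (batch, question_tokens, entities)
mask = torch.tensor([[1, 1, 0, 0],
                     [1, 0, 0, 0]], dtype=torch.bool)

out = probs * mask.unsqueeze(-1)             # padded positions zeroed out
assert out.dtype == torch.float32            # result stays float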

allennlp_semparse/state_machines/constrained_beam_search.py
Lines changed: 2 additions & 2 deletions

@@ -37,7 +37,7 @@ class ConstrainedBeamSearch:
        A ``(batch_size, num_sequences, sequence_length)`` tensor containing the transition
        sequences that we will search in. The values in this tensor must match whatever the
        ``State`` keeps in its ``action_history`` variable (typically this is action indices).
-    allowed_sequence_mask : ``torch.Tensor``
+    allowed_sequence_mask : ``torch.BoolTensor``
        A ``(batch_size, num_sequences, sequence_length)`` tensor indicating whether each entry in
        the ``allowed_sequences`` tensor is padding. The allowed sequences could be padded both on
        the ``num_sequences`` dimension and the ``sequence_length`` dimension.

@@ -53,7 +53,7 @@ def __init__(
         self,
         beam_size: Optional[int],
         allowed_sequences: torch.Tensor,
-        allowed_sequence_mask: torch.Tensor,
+        allowed_sequence_mask: torch.BoolTensor,
         per_node_beam_size: int = None,
     ) -> None:
         self._beam_size = beam_size

allennlp_semparse/state_machines/states/checklist_statelet.py
Lines changed: 2 additions & 2 deletions

@@ -23,7 +23,7 @@ class ChecklistStatelet:
        ideally be. It is the same size as ``terminal_actions``, and it contains 1 for each corresponding
        action in the list that we want to see in the final logical form, and 0 for each corresponding
        action that we do not.
-    checklist_mask : ``torch.Tensor``
+    checklist_mask : ``torch.BoolTensor``
        Mask corresponding to ``terminal_actions``, indicating which of those actions are relevant
        for checklist computation. For example, if the parser is penalizing non-agenda terminal
        actions, all the terminal actions are relevant.

@@ -39,7 +39,7 @@ def __init__(
         self,
         terminal_actions: torch.Tensor,
         checklist_target: torch.Tensor,
-        checklist_mask: torch.Tensor,
+        checklist_mask: torch.BoolTensor,
         checklist: torch.Tensor,
         terminal_indices_dict: Dict[int, int] = None,
     ) -> None:

allennlp_semparse/state_machines/states/rnn_statelet.py
Lines changed: 2 additions & 2 deletions

@@ -40,7 +40,7 @@ class RnnStatelet:
        mask unmodified, regardless of what's in the grouping for this state. We'll use the
        ``batch_indices`` for the group to pull pieces out of these lists when we're ready to
        actually do some computation.
-    encoder_output_mask : ``List[torch.Tensor]``
+    encoder_output_mask : ``List[torch.BoolTensor]``
        A list of variables, each of shape ``(input_sequence_length,)``, containing a mask over
        question tokens for each batch instance. This is a list over batch elements, for the same
        reasons as above.

@@ -53,7 +53,7 @@ def __init__(
         previous_action_embedding: torch.Tensor,
         attended_input: torch.Tensor,
         encoder_outputs: List[torch.Tensor],
-        encoder_output_mask: List[torch.Tensor],
+        encoder_output_mask: List[torch.BoolTensor],
     ) -> None:
         self.hidden_state = hidden_state
         self.memory_cell = memory_cell

allennlp_semparse/state_machines/transition_functions/basic_transition_function.py
Lines changed: 4 additions & 1 deletion

@@ -339,7 +339,10 @@ def make_state(
         return new_states

     def attend_on_question(
-        self, query: torch.Tensor, encoder_outputs: torch.Tensor, encoder_output_mask: torch.Tensor
+        self,
+        query: torch.Tensor,
+        encoder_outputs: torch.Tensor,
+        encoder_output_mask: torch.BoolTensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Given a query (which is typically the decoder hidden state), compute an attention over the
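
For context, encoder_output_mask ends up gating an attention over the question tokens. A plain-PyTorch approximation of such a masked attention (a sketch only, not the actual AllenNLP attention module used by attend_on_question; it assumes every row has at least one unmasked position):

import torch

def attend(query: torch.Tensor,
           encoder_outputs: torch.Tensor,
           encoder_output_mask: torch.BoolTensor):
    # query: (group, dim); encoder_outputs: (group, seq, dim); mask: (group, seq).
    scores = torch.bmm(encoder_outputs, query.unsqueeze(-1)).squeeze(-1)
    scores = scores.masked_fill(~encoder_output_mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)  # padded positions get zero weight
    attended = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
    return attended, weights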
