Skip to content

Commit

Permalink
fix gumbel temp arg (#1438)
Browse files Browse the repository at this point in the history
Summary:
Fix #2897

Pull Request resolved: fairinternal/fairseq-py#1438

Reviewed By: myleott

Differential Revision: D24992106

Pulled By: alexeib

fbshipit-source-id: 0cb15c2e865c3e8f7950e8f5e6c54c5000637af2
  • Loading branch information
alexeib authored and facebook-github-bot committed Nov 16, 2020
1 parent 73748c8 commit e2e735d
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion fairseq/modules/gumbel_vector_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ def block(input_dim, output_dim):
nn.init.normal_(self.weight_proj.weight, mean=0, std=1)
nn.init.zeros_(self.weight_proj.bias)

assert len(temp) == 3, temp
if isinstance(temp, str):
import ast
temp = ast.literal_eval(temp)
assert len(temp) == 3, f"{temp}, {len(temp)}"

self.max_temp, self.min_temp, self.temp_decay = temp
self.curr_temp = self.max_temp
Expand Down

0 comments on commit e2e735d

Please sign in to comment.