Skip to content

Commit

Permalink
Update speculator.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sahilsuneja1 authored Jun 7, 2024
1 parent 50e34a9 commit ed4c7d6
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions fms_extras/models/speculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(
for _ in range(n_predict)
]
)
if scale_input:
if self.scale_input:
self.ln0 = LayerNormParameterized(
emb_dim, elementwise_shift=False, elementwise_scale=False
)
Expand Down Expand Up @@ -209,7 +209,7 @@ def forward(
Has size [self.n_predict b n v] where v is vocab size.
"""
out = []
if scale_input:
if self.scale_input:
state = self.ln0(state) / (2**0.5)
for i in range(self.n_predict):
z = self.emb[i](inds[:, i : i + state.size(1)])
Expand Down

0 comments on commit ed4c7d6

Please sign in to comment.