Skip to content

Commit ba331d6

Browse files
authored
Fix to avoid RuntimeError (#2138)
1 parent 0fb371d commit ba331d6

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

litgpt/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ def build_rope_cache(
872872
theta = theta / factor
873873

874874
# Create position indices `[0, 1, ..., seq_len - 1]`
875-
seq_idx = torch.arange(seq_len, device=device) / condense_ratio
875+
seq_idx = torch.arange(seq_len, device=device).float() / condense_ratio
876876

877877
# Calculate the product of position index and $\theta_i$
878878
idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)

0 commit comments

Comments
 (0)