Skip to content

Commit 9ece6a8

Browse files
committed
update code
1 parent 82c0711 commit 9ece6a8

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@ def _rotary_embedding_fwd(position_ids: torch.Tensor,
4040
class DlinferRotaryEmbeddingImpl(RotaryEmbeddingImpl, nn.Module):
4141
"""Base rotary embedding."""
4242

43-
def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0):
43+
def __init__(self, dim: int, base: float = 10000.0, scaling_factor: float = 1.0):
4444
super().__init__()
4545
self.scaling_factor = scaling_factor
4646
self.dim = dim
4747
self.base = base
4848
# yapf: disable
49-
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device='cuda').float() / self.dim))
49+
inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2, dtype=torch.float, device='cuda') / self.dim))
5050
# yapf: enable
5151
self.register_buffer('inv_freq', inv_freq, persistent=False)
5252

0 commit comments

Comments
 (0)