Skip to content

Commit

Permalink
matching parameters counts with constructor of optimStateFp8
Browse files Browse the repository at this point in the history
  • Loading branch information
MirMustafaAli committed Nov 12, 2024
1 parent 366743c commit 38951ae
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchao/prototype/low_bit_optim/subclass_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class OptimStateFp8(TorchAOBaseTensor):
tensor_attrs = ["codes", "scale"]

@staticmethod
def __new__(cls, codes: Tensor, scale: Tensor):
def __new__(cls, codes: Tensor, scale: Tensor, k: Optional[Tensor] =None):
return Tensor._make_wrapper_subclass(cls, codes.shape, device=codes.device)

def __init__(self, codes: Tensor, scale: Tensor, k: Optional[Tensor] =None):
Expand Down

0 comments on commit 38951ae

Please sign in to comment.