Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit f8b8d89

Browse files
committed
Okay
1 parent 2da3b5c commit f8b8d89

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

float8_experimental/float8_tensor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,17 @@ def __repr__(self):
101101

102102
def __tensor_flatten__(self):
103103
ctx = {
104-
"_scale": self._scale,
104+
# "_scale": self._scale,
105105
"_orig_dtype": self._orig_dtype,
106106
}
107107
# return ("_data", "_scale"), (self._orig_dtype)
108-
return ["_data"], ctx
108+
return ["_data", "_scale"], ctx
109109

110110
@staticmethod
111111
def __tensor_unflatten__(inner_tensors: Dict, metadata):
112-
assert len(inner_tensors) == 1
113-
# return Float8Tensor(tensors["_data"], tensors["_scale"], metadatas[0])
114-
return Float8Tensor(inner_tensors["_data"], metadata["_scale"], metadata["_orig_dtype"])
112+
assert len(inner_tensors) == 2
113+
return Float8Tensor(inner_tensors["_data"], inner_tensors["_scale"], metadata["_orig_dtype"])
114+
# return Float8Tensor(inner_tensors["_data"], metadata["_scale"], metadata["_orig_dtype"])
115115

116116
def to_original_precision(self):
117117
return FromFloat8ConstrFunc.apply(self)

0 commit comments

Comments
 (0)