@@ -59,7 +59,12 @@ def process_tensor(tensor):
5959 "info" : info ,
6060 }
6161 else :
62- return {"type" : "big_int_tensor" , "data" : tensor .clone (), "info" : info }
62+ return {
63+ "type" : "big_int_tensor_by_range" ,
64+ "min_val" : tensor .min ().item (),
65+ "max_val" : tensor .max ().item (),
66+ "info" : info ,
67+ }
6368 elif tensor .numel () < 1024 :
6469 return {"type" : "small_tensor" , "data" : tensor .clone (), "info" : info }
6570 else :
@@ -73,16 +78,25 @@ def process_tensor(tensor):
7378 processed_inputs = {"type" : "unknown" , "value" : example_inputs }
7479
def handle_named_tensors(tensor):
    """Build the metadata record stored for one named tensor.

    The record always carries ``"info"`` (the result of ``tensor_info``) and
    a ``"type"`` tag. What else it carries depends on the dtype and on
    whether the tensor has fewer than 1024 elements:

    * signed-integer dtype, small  -> ``"small_int_tensor"`` with a cloned
      copy of the values under ``"data"``;
    * signed-integer dtype, large  -> ``"big_int_tensor_by_range"`` with only
      the observed ``"min_val"`` / ``"max_val"`` as Python scalars;
    * any other dtype, small       -> ``"small_tensor"`` with a cloned copy
      under ``"data"``;
    * any other dtype, large       -> ``"random_tensor"`` with ``"data"``
      set to ``None``.
    """
    meta = tensor_info(tensor)
    is_small = tensor.numel() < 1024
    is_signed_int = tensor.dtype in [
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    ]

    if is_signed_int and is_small:
        return {
            "info": meta,
            "data": tensor.clone(),
            "type": "small_int_tensor",
        }

    if is_signed_int:
        # Large integer tensors are not copied; only their value range is kept.
        lowest = tensor.min().item()
        highest = tensor.max().item()
        return {
            "info": meta,
            "min_val": lowest,
            "max_val": highest,
            "type": "big_int_tensor_by_range",
        }

    if is_small:
        return {"info": meta, "data": tensor.clone(), "type": "small_tensor"}

    return {"info": meta, "data": None, "type": "random_tensor"}
86100
87101 processed_weights = {
88102 key : handle_named_tensors (tensor ) for key , tensor in state_dict .items ()
@@ -114,46 +128,46 @@ def format_data(data):
114128 return "None"
115129 elif isinstance (data , torch .Tensor ):
116130 if data .dtype .is_floating_point :
117- return "[{}]" .format (", " .join (f"{ x :.6f} " for x in data .tolist ()))
131+ return "[{}]" .format (
132+ ", " .join (f"{ x :.6f} " for x in data .flatten ().tolist ())
133+ )
118134 else :
119- return "[{}]" .format (", " .join (f"{ x } " for x in data .tolist ()))
135+ return "[{}]" .format (", " .join (f"{ x } " for x in data .flatten (). tolist ()))
120136 else :
121137 return repr (data )
122138
def process_tensor_info(tensor_info, name_prefix="example_input"):
    """Render one tensor record as the source lines of a metadata class.

    Args:
        tensor_info: Mapping describing a tensor. Keys read here are
            ``"type"``, ``"info"`` (itself a mapping with optional
            ``"dtype"``, ``"shape"``, ``"device"``, ``"mean"``, ``"std"``),
            ``"name"``, and either ``"min_val"``/``"max_val"`` or ``"data"``.
        name_prefix: Prefix used when building the generated class name.

    Returns:
        A list of strings: a ``class <prefix>_tensor_meta_<name>:`` header,
        one tab-indented ``attr = value`` line per attribute, and a trailing
        empty string that separates consecutive classes.

    Raises:
        KeyError: if the record is tagged ``"big_int_tensor_by_range"`` but
            lacks ``"min_val"`` or ``"max_val"``.
    """
    tensor_type = tensor_info.get("type")
    info = tensor_info.get("info", {})
    dtype = info.get("dtype", "torch.float")
    shape = info.get("shape", [])
    device = info.get("device", "cpu")
    mean = info.get("mean", 0.0)
    std = info.get("std", 1.0)
    name = tensor_info.get("name", "")
    # No padding inside the literals: the class name must be one identifier
    # and the quoted dtype/device values must not gain stray blanks.
    uid = f"{name_prefix}_tensor_meta_{name}"

    lines = [
        f"class {uid}:",
        f'\tname = "{name}"',
        f"\tshape = {shape}",
        f'\tdtype = "{dtype}"',
        f'\tdevice = "{device}"',
        f"\tmean = {get_limited_precision_float_str(mean)}",
        f"\tstd = {get_limited_precision_float_str(std)}",
    ]

    if tensor_type == "big_int_tensor_by_range":
        # Range-only records have no payload; emit the bounds instead.
        lines.append(f"\tmin_val = {tensor_info['min_val']}")
        lines.append(f"\tmax_val = {tensor_info['max_val']}")
    elif "data" in tensor_info:
        payload = tensor_info["data"]
        if isinstance(payload, torch.Tensor):
            # Flatten so the emitted literal is a single flat list.
            payload = payload.flatten()
        lines.append(f"\tdata = {format_data(payload)}")

    lines.append("")
    return lines
157171
158172 input_infos = converted ["input_info" ]
159173 if isinstance (input_infos , dict ):
@@ -202,7 +216,16 @@ def convert_meta_classes_to_tensors(file_path):
202216 }
203217 data_value = None
204218 data_type = getattr (torch , attrs .get ("dtype" , "torch.float" ).split ("." )[- 1 ])
205- if attrs .get ("data" ) is not None :
219+ shape = attrs .get ("shape" , [])
220+
221+ if "min_val" in attrs and "max_val" in attrs :
222+ min_val = attrs ["min_val" ]
223+ max_val = attrs ["max_val" ]
224+ # torch.randint's upper bound is exclusive, so add 1
225+ data_value = torch .randint (
226+ min_val , max_val + 1 , size = shape , dtype = data_type
227+ )
228+ elif attrs .get ("data" ) is not None :
206229 if isinstance (attrs .get ("data" ), str ):
207230 raise ValueError ("Unimplemented" )
208231 else :
0 commit comments