@@ -265,43 +265,6 @@ def __init__(self, config, **kwargs):
         # self.config = config
         self.transformer = OPTStack(self.config)
 
-    # def forward(
-    #     self,
-    #     **data,
-    # ):
-    #     input_ids = data.get("input_ids", None)
-    #     # attention_mask = data.get("attention_mask", None)
-    #     # position_ids = data.get("position_ids", None)
-    #     labels = data.get("labels", None)
-    #     use_cache = data.get("use_cache", None)
-    #     output_attentions = data.get("output_attentions", None)
-    #     output_hidden_states = data.get("output_hidden_states", True)
-    #
-    #     transformer_outputs = self.transformer(
-    #         input_ids,
-    #         attention_mask=None,
-    #         position_ids=None,
-    #         use_cache=use_cache,
-    #         output_attentions=output_attentions,
-    #         output_hidden_states=output_hidden_states,
-    #     )
-    #     hidden_states = transformer_outputs
-    #
-    #     lm_logits = self.lm_head(hidden_states)
-    #
-    #     return_data = {"logits": lm_logits}
-    #     if labels is not None:
-    #         # Shift so that tokens < n predict n
-    #         shift_logits = lm_logits[..., :-1, :].contiguous()
-    #         shift_labels = labels[..., 1:].contiguous()
-    #         loss_fct = nn.CrossEntropyLoss()
-    #         loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
-    #                         shift_labels.view(-1))
-    #         return_data["loss"] = loss
-    #
-    #     return return_data
-
-
     def load_weights(self, checkpoint_path):
         checkpoint = torch.load(checkpoint_path,
                                 map_location=torch.device("cpu"))