diff --git a/vedacore/misc/checkpoint.py b/vedacore/misc/checkpoint.py index e189417..ce43496 100644 --- a/vedacore/misc/checkpoint.py +++ b/vedacore/misc/checkpoint.py @@ -166,7 +166,9 @@ def optimizer_to_cpu(state_dict): for key, val in state_dict.items(): tmp = dict() for k, v in val.items(): - tmp[k] = v.cpu() + if isinstance(v, torch.Tensor): + v = v.cpu() + tmp[k] = v state_dict_cpu[key] = tmp return state_dict_cpu