diff --git a/modules/dml/hijack/torch.py b/modules/dml/hijack/torch.py index ff5ee4df8e5..3eb6c2a019c 100644 --- a/modules/dml/hijack/torch.py +++ b/modules/dml/hijack/torch.py @@ -49,3 +49,17 @@ def pow_(self: torch.Tensor, *args, **kwargs): return _pow_(self.cpu(), *args, **kwargs).to(self.device) return _pow_(self, *args, **kwargs) torch.Tensor.pow_ = pow_ + + +_load = torch.load +def load(f, map_location, *args, **kwargs): + if type(map_location) in (str, torch.device,): + device = torch.device(map_location) + if device.type == "privateuseone": + data = _load(f, *args, map_location="cpu", **kwargs) + for k in data: + for weight in data[k]: + data[k][weight] = data[k][weight].to(device) + return data + return _load(f, *args, map_location=map_location, **kwargs) +torch.load = load