forked from AUTOMATIC1111/stable-diffusion-webui
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add option for float32 sampling with float16 UNet
This also handles type casting so that ROCm and MPS torch devices work correctly without --no-half. One cast is required for deepbooru in deepbooru_model.py, some explicit casting is required for img2img and inpainting. depth_model can't be converted to float16 or it won't work correctly on some systems (it's known to have issues on MPS) so in sd_models.py model.depth_model is removed for model.half().
- Loading branch information
Showing
8 changed files
with
82 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import importlib | ||
|
||
class CondFunc: | ||
def __new__(cls, orig_func, sub_func, cond_func): | ||
self = super(CondFunc, cls).__new__(cls) | ||
if isinstance(orig_func, str): | ||
func_path = orig_func.split('.') | ||
for i in range(len(func_path)-2, -1, -1): | ||
try: | ||
resolved_obj = importlib.import_module('.'.join(func_path[:i])) | ||
break | ||
except ImportError: | ||
pass | ||
for attr_name in func_path[i:-1]: | ||
resolved_obj = getattr(resolved_obj, attr_name) | ||
orig_func = getattr(resolved_obj, func_path[-1]) | ||
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) | ||
self.__init__(orig_func, sub_func, cond_func) | ||
return lambda *args, **kwargs: self(*args, **kwargs) | ||
def __init__(self, orig_func, sub_func, cond_func): | ||
self.__orig_func = orig_func | ||
self.__sub_func = sub_func | ||
self.__cond_func = cond_func | ||
def __call__(self, *args, **kwargs): | ||
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): | ||
return self.__sub_func(self.__orig_func, *args, **kwargs) | ||
else: | ||
return self.__orig_func(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters