-
Notifications
You must be signed in to change notification settings - Fork 6.1k
[Utils] Add deprecate function and move testing_utils under utils #659
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6db4338
8447d88
5cde36b
2d80a08
e1af3ab
04909f8
e51a656
3b2d8af
dec5930
812fb8c
7a8eb4a
e6187f4
43f984c
82c3abb
99e8480
a9d9ba1
b6ae0d0
4db2c92
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,15 +16,14 @@ | |
# and https://github.com/hojonathanho/diffusion | ||
|
||
import math | ||
import warnings | ||
from dataclasses import dataclass | ||
from typing import Optional, Tuple, Union | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from ..configuration_utils import ConfigMixin, register_to_config | ||
from ..utils import BaseOutput | ||
from ..utils import BaseOutput, deprecate | ||
from .scheduling_utils import SchedulerMixin | ||
|
||
|
||
|
@@ -122,12 +121,12 @@ def __init__( | |
steps_offset: int = 0, | ||
**kwargs, | ||
): | ||
if "tensor_format" in kwargs: | ||
warnings.warn( | ||
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." | ||
"If you're running your code in PyTorch, you can safely remove this argument.", | ||
DeprecationWarning, | ||
) | ||
deprecate( | ||
"tensor_format", | ||
"0.5.0", | ||
"If you're running your code in PyTorch, you can safely remove this argument.", | ||
take_from=kwargs, | ||
) | ||
|
||
if trained_betas is not None: | ||
self.betas = torch.from_numpy(trained_betas) | ||
|
@@ -175,17 +174,10 @@ def set_timesteps(self, num_inference_steps: int, **kwargs): | |
num_inference_steps (`int`): | ||
the number of diffusion steps used when generating samples with a pre-trained model. | ||
""" | ||
|
||
offset = self.config.steps_offset | ||
|
||
if "offset" in kwargs: | ||
warnings.warn( | ||
"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." | ||
" Please pass `steps_offset` to `__init__` instead.", | ||
DeprecationWarning, | ||
) | ||
|
||
offset = kwargs["offset"] | ||
deprecated_offset = deprecate( | ||
"offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. bump as we would have already missed this version |
||
) | ||
offset = deprecated_offset or self.config.steps_offset | ||
|
||
self.num_inference_steps = num_inference_steps | ||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,13 +15,13 @@ | |
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim | ||
|
||
import math | ||
import warnings | ||
from typing import Optional, Tuple, Union | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from ..configuration_utils import ConfigMixin, register_to_config | ||
from ..utils import deprecate | ||
from .scheduling_utils import SchedulerMixin, SchedulerOutput | ||
|
||
|
||
|
@@ -102,12 +102,12 @@ def __init__( | |
steps_offset: int = 0, | ||
**kwargs, | ||
): | ||
if "tensor_format" in kwargs: | ||
warnings.warn( | ||
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." | ||
"If you're running your code in PyTorch, you can safely remove this argument.", | ||
DeprecationWarning, | ||
) | ||
deprecate( | ||
"tensor_format", | ||
"0.5.0", | ||
"If you're running your code in PyTorch, you can safely remove this argument.", | ||
take_from=kwargs, | ||
) | ||
|
||
if trained_betas is not None: | ||
self.betas = torch.from_numpy(trained_betas) | ||
|
@@ -155,16 +155,10 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor | |
num_inference_steps (`int`): | ||
the number of diffusion steps used when generating samples with a pre-trained model. | ||
""" | ||
|
||
offset = self.config.steps_offset | ||
|
||
if "offset" in kwargs: | ||
warnings.warn( | ||
"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." | ||
" Please pass `steps_offset` to `__init__` instead." | ||
) | ||
|
||
offset = kwargs["offset"] | ||
deprecated_offset = deprecate( | ||
"offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. bump as we would have already missed this version |
||
) | ||
offset = deprecated_offset or self.config.steps_offset | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great naming here, the intention is very clear! |
||
|
||
self.num_inference_steps = num_inference_steps | ||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a big change and I like to keep it for a bit until we force people to not pass
steps_offset
anymore.