Add convert_image_dtype to functionals #2078
Conversation
Thanks for the quick PR!
I have a few comments. Also, I would propose following the TensorFlow implementation a bit more closely, as there are a few cases that need to be taken into account.
Let me know what you think.
@fmassa If I understand correctly, the conversion boils down to something like this:

```python
image = image / scale_factor(image.dtype)
if saturate:
    image = torch.clamp(image, 0.0, 1.0)
image = image * scale_factor(dtype)
```

Am I missing something here, or is it just that simple?

Edit: I think I understand the problem. In theory my approach works, if it wasn't for the limited precision of floating point tensors:

```python
import torch


def convert(x, dtype):
    return x.mul(torch.iinfo(dtype).max).to(dtype)


x = torch.tensor(1.0, dtype=torch.float)
for dtype in (torch.short, torch.int, torch.long):
    print(convert(x, dtype))
```

I will handle this, but it will probably take some time.
I think this will take some more work and a few decisions. I dug into the TensorFlow implementation and there are several cases that need to be handled differently. I'll go through them one by one.
This is the simplest one, as it is basically just a cast since the intervals are the same. One caveat though: even with […]
In […]

As you can see, the last interval is significantly smaller than the others. In general the last interval is given by […]. IMO we should aim for something like this: […]

We could achieve this with something like `x.mul(c + 1.0).clamp(max=c)`. A quick benchmark against the current approach:

```python
import timeit

import torch

x = torch.ones((1, 3, 256, 256))
dtype = torch.uint8
c = float(torch.iinfo(dtype).max)


def theirs(x):
    return x.mul(c + 0.5).to(dtype)


def ours(x):
    return x.mul(c + 1.0).clamp(max=c).to(dtype)


number = 10000
their_time = timeit.timeit(lambda: theirs(x), number=number)
print(f"their time: {their_time / number * 1e6:.2f} µs")
our_time = timeit.timeit(lambda: ours(x), number=number)
print(f"our time: {our_time / number * 1e6:.2f} µs")
rel_diff = our_time / their_time - 1.0
print(f"rel. diff.: {rel_diff:+.1%}")
```
Mileage may vary for different systems or runs. While this is a significant relative increase, I think the absolute difference on top of a baseline of about 40 µs is probably acceptable. Thoughts?
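To make the difference concrete, here is a small check of how the two scalings distribute a uniform ramp over the `uint8` buckets (this snippet is added for illustration and was not part of the original benchmark):

```python
import torch

a = torch.linspace(0, 1, 10001)
c = float(torch.iinfo(torch.uint8).max)

# "theirs": mul(c + 0.5) leaves the last bucket roughly half as big as the others
print(a.mul(c + 0.5).to(torch.uint8).long().bincount()[-3:])
# "ours": mul(c + 1) followed by a clamp keeps the buckets evenly sized
print(a.mul(c + 1.0).clamp(max=c).to(torch.uint8).long().bincount()[-3:])
```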
They cast to […]
Hi Philip,

About your points:

I'm not sure if we should pay a (fairly large) runtime penalty for the saturation check. We should pretty much never encounter any value larger than 3.4028234663852886e+38 for an image (and if we do, it is probably an error on the user's side).
This is a fair point, and it seems like the TF implementation is suboptimal. Instead of `floor(min(image * (c + 1), c))`, why not do `floor(image * (c + 1 - eps))`, where `eps` is say 0.001?

```python
a = torch.linspace(0, 1, 10001)
print(a.mul(127.999).floor().int().bincount())
```

yields

```
tensor([79, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78,
        78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78,
        78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78,
        78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79,
        78, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78,
        78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78,
        78, 78, 78, 79, 78, 78, 78, 78, 78, 78, 78, 79, 78, 78, 78, 78, 78, 78,
        78, 79])
```

while

```python
print(a.mul(3.999).floor().int().bincount())
```

gives

```
tensor([2501, 2501, 2500, 2499])
```
Also, as a general note, I think it might be better to either move the functions that are currently defined inside the main function out of it, or to inline their code in the main function. They are very short anyway and are only called once, so there is no point in having them as separate functions (plus we pay the overhead of re-defining them on every call, and it makes torchscript support harder as well). My preference would be to inline the helper functions in the main code.
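For illustration, a minimal sketch of the two variants; the helper name and body here are made up, not the ones from the PR:

```python
import torch


def convert_with_helper(image, dtype=torch.uint8):
    # nested helper: re-defined on every call and harder for torchscript
    def scale_factor(dtype):
        return float(torch.iinfo(dtype).max)

    return image.mul(scale_factor(dtype)).to(dtype)


def convert_inlined(image, dtype=torch.uint8):
    # same logic with the single call site inlined
    return image.mul(float(torch.iinfo(dtype).max)).to(dtype)
```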
Maybe I still got the […]
Fair point. While experimenting with it I've encountered another problem (the same holds for my approach):

If we for example want to convert an image from `torch.float` to `torch.int32`, the scaling runs into floating point precision issues:

```python
import torch

c = float(torch.iinfo(torch.int32).max)
eps = 1e-3
image = torch.tensor(1.0, dtype=torch.float)

scaled_images = (
    image * (c + 1 - eps),
    image * (c + 0.5),
    image * (c + 1) - 64,
    image * (c + 1) - 65,
)
print("\n".join([str(image.to(torch.int32)) for image in scaled_images]))
```
For our example we have to subtract at least 65 before the scaled value can be safely cast to `torch.int32`.
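The reason is the spacing of `float32` values near `2 ** 31`; a small check, added here for illustration and not part of the original comment:

```python
import torch

# 2 ** 31 - 64 is exactly halfway between two representable float32 values and
# rounds up to 2 ** 31, which overflows int32; 2 ** 31 - 65 rounds down instead.
print(torch.tensor(2 ** 31 - 64, dtype=torch.float32).item())  # 2147483648.0
print(torch.tensor(2 ** 31 - 65, dtype=torch.float32).item())  # 2147483520.0
```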
Agreed. I'll keep them separate until the last commit to help me keep a better overview.
It all depends on what we mean by […]
I think this definition makes sense, and I'm not sure we would want to clamp float values to be within 0-1 inside this function. Those are good points, and that's probably why the TF implementation has so many conditionals -- to make the implementation fast when possible.
I don't know how or if that works for them. I've converted the TF `saturate_cast` to PyTorch:

```python
import torch


def saturate_cast(value: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    def info(dtype):
        if dtype.is_floating_point:
            return torch.finfo(dtype)
        else:
            return torch.iinfo(dtype)

    input_info = info(value.dtype)
    output_info = info(dtype)

    if input_info.min < output_info.min:
        value = torch.max(value, torch.tensor(output_info.min, dtype=value.dtype))
    if input_info.max > output_info.max:
        value = torch.min(value, torch.tensor(output_info.max, dtype=value.dtype))

    return value.to(dtype)


image = torch.tensor(1.0, dtype=torch.float32)
dtype = torch.int32
scale = torch.iinfo(dtype).max + 0.5

scaled = image * scale
print(scaled)
print(saturate_cast(scaled, dtype))
```
I expected this much, since it does not handle the problem I addressed above. I don't have the capability to set up TF myself to check how it behaves there.
Good point! I just tried the above snippet with TF (using Colab), and got the same results as in PyTorch:

```python
import tensorflow as tf

a = tf.fill([1], 2147483647.5, tf.float32)
print(tf.dtypes.saturate_cast(a, dtype=tf.int32))
```

which gives […]

I'm not sure what's the best approach we should follow here.
I'll work something out and get back to you.
I've played with it and I don't think this can be handled in an easy or concise way. With a little effort I can safeguard the upper limit, but with that the lower limit is no longer mapped to 0.
Either I'm missing your point, or I think this assumption is incorrect. This problem applies to every conversion of floating point tensors to int tensors with the same or a higher number of bits. So without further handling, the conversion from e.g. `torch.float32` to `torch.int32` would be broken. I'm not sure how to move forward on this.

Edit: I've found a way to handle both the upper and lower bounds. Let me know what you think:

```python
import torch
import itertools

float_dtypes = (torch.float32, torch.float64)
int_dtypes = (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)

int_nextpow2 = {
    torch.float32: 23,
    torch.float64: 52,
    torch.uint8: 8,
    torch.int8: 7,
    torch.int16: 15,
    torch.int32: 31,
    torch.int64: 63,
}


def float_to_int(x, dtype):
    max = torch.iinfo(dtype).max
    m = int_nextpow2[x.dtype]
    n = int_nextpow2[dtype]
    if m >= n:
        return (x * max).to(dtype)
    else:
        c = 2 ** (n - (m + 1))
        return torch.max((x * max - c).to(dtype) + c - 1, torch.zeros(1, dtype=dtype))


for float_dtype, int_dtype in itertools.product(float_dtypes, int_dtypes):
    x = torch.tensor((0.0, 1.0), dtype=float_dtype)
    y = float_to_int(x, int_dtype)
    actual = tuple(y.tolist())
    desired = (0, torch.iinfo(int_dtype).max)
    if actual != desired:
        print(
            f"Conversion from {float_dtype} to {int_dtype} did not work as "
            f"expected: {actual} != {desired}"
        )
```

The idea is to check if the float dtype has enough precision to represent the maximum value of the int dtype exactly (`int_nextpow2` holds the number of significand bits of the float dtypes and the number of value bits of the int dtypes). If not, the scaled values are shifted down by an offset before the cast and shifted back up afterwards, so the upper bound no longer overflows; the lower bound is then clamped back to 0.
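To make the `m >= n` check more tangible, here is a small illustration (added for this write-up, not part of the original comment) of whether the `int32` maximum survives a round trip through each float dtype:

```python
import torch

# float32 (24 significand bits) cannot hold the int32 maximum exactly,
# while float64 (53 significand bits) can.
int_max = torch.iinfo(torch.int32).max
for float_dtype in (torch.float32, torch.float64):
    round_trip = int(torch.tensor(int_max, dtype=float_dtype).item())
    print(float_dtype, round_trip == int_max)
```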
Hi @pmeier

The issue with the last solution you proposed is that we get back to the original behavior we were trying to fix, which is that now the last bucket is much smaller than the others again:

```python
a = torch.linspace(0, 1, 10001)
r = float_to_int(a, torch.uint8).bincount()
print(r)
```

gives us

```
tensor([40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39,
        40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39,
        39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39,
        39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39,
        39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39,
        39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39,
        39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40,
        39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40,
        39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39,
        40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39,
        40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39,
        39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39,
        39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39,
        39, 39, 40, 39, 39, 39, 39, 40, 39, 39, 39, 40, 39, 39, 39, 39, 40, 39,
        39, 39, 39, 1])
```

Maybe there is an easy fix for this though (like passing […]).

**Proposal to move forward**

In order to move forward, I would propose that we only allow float -> integer conversion if the dtype allows for the correct behavior, and raise an error if this is not the case. So we would only allow converting the combinations for which this can be done correctly.

Thoughts?
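A rough sketch of what such a guard could look like; the function body, dtype checks, and error message here are illustrative and not the final torchvision implementation:

```python
import torch


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    if not image.is_floating_point():
        raise NotImplementedError("integer inputs are omitted from this sketch")
    if dtype.is_floating_point:
        return image.to(dtype)
    # float32 has 24 and float64 has 53 significand bits; if the target integer
    # needs more bits than that, the conversion cannot be performed safely.
    unsafe = (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
        image.dtype == torch.float64 and dtype == torch.int64
    )
    if unsafe:
        raise RuntimeError(f"The cast from {image.dtype} to {dtype} cannot be performed safely.")
    c = float(torch.iinfo(dtype).max)
    # multiply by max + 1 for evenly sized buckets, then clamp so that 1.0 maps to max
    return image.mul(c + 1.0).clamp(max=c).to(dtype)
```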
Good catch! Seems I was so focused on fixing this that I forgot about that.

I think that is reasonable. Do you want me to completely disable this, or add a […]?
I would say to completely disable this for now, and raise an error (with a good error message) if the user tries to do this. We can then see how many users complain about it in the future. Also, one thing I noticed in the TF […]
@pmeier do you think you would have some time to work on this sometime this week? Otherwise I can build on top of it and get it merged.
@fmassa Sorry for the hold-up. I'm fully booked until Friday. If you need this earlier, feel free to build on top of it. Otherwise I'll work on it on Friday and should get it done, provided I don't stumble upon another issue that needs discussing.
Sounds good, thanks for the heads up! This can wait until Friday, thanks a lot!
@fmassa Maybe we can discuss this before I work on it further: the last missing conversion is the one between integer dtypes.

The conversion of a black pixel, i.e. 0, is trivial […]. The first part […]. Thus, if we convert from a higher number of bits to a lower one […].

Is this something you want to address further or simply leave it as is? In […]
I think it's fine if we don't map exactly 255 to 2147483647 (or 32767 for `int16`).
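For context, a small sketch (added for illustration, not from the thread) of one simple way to upscale `uint8` to `int32` with a power-of-two factor, under which 255 does not land on the `int32` maximum:

```python
import torch

# Multiplying by 2 ** (31 - 8) keeps black at 0, but white (255) becomes
# 2139095040 rather than the int32 maximum 2147483647.
image = torch.tensor([0, 255], dtype=torch.uint8)
print(image.to(torch.int32) * 2 ** (31 - 8))
```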
nit: but not in this PR, just something to keep in mind.
Looks great, thanks a lot @pmeier!
As a follow-up PR, could you add tests for torchscript support as well?
Could you point me to an example of how to do that?
@pmeier it will basically be another line in the test that checks the scripted version of the function; see vision/test/test_functional_tensor.py, lines 16 to 20 in c2e8a00.
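Roughly, such a check could look like the following; the import path and the test body here are assumptions for illustration, not the existing test:

```python
import torch
from torchvision.transforms import functional as F  # assumed location of convert_image_dtype


def test_convert_image_dtype_scriptable():
    script_fn = torch.jit.script(F.convert_image_dtype)  # assumes the function is scriptable
    image = torch.rand(3, 16, 16)
    expected = F.convert_image_dtype(image, torch.uint8)
    actual = script_fn(image, torch.uint8)
    assert torch.equal(expected, actual)
```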
This adds a `convert_image_dtype` function as discussed in #2060 (comment). The idea behind this function is to first convert the image into the interval [0.0, 1.0] and afterwards into the value interval of the given dtype.
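In code, the two-step idea reads roughly as follows; this is a conceptual sketch only, and the actual PR refines the scaling as discussed in the conversation above:

```python
import torch


def convert_image_dtype_sketch(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Step 1: bring the image into the interval [0.0, 1.0]
    if image.is_floating_point():
        normalized = image
    else:
        normalized = image.to(torch.float64) / torch.iinfo(image.dtype).max
    # Step 2: rescale into the value interval of the target dtype
    if dtype.is_floating_point:
        return normalized.to(dtype)
    # naive scaling; the discussion above is about doing this exactly and evenly
    return normalized.mul(torch.iinfo(dtype).max).round().to(dtype)


image = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
print(convert_image_dtype_sketch(image, torch.float32).dtype)  # torch.float32
```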