-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add support of mode and remove channels #3024
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#pragma once | ||
|
||
/* Should be kept in-sync with Python ImageReadMode enum */ | ||
// Integer selector taken by decode_image, decodeJPEG and decodePNG in place of | ||
// the former `channels` argument; it names the colour layout of the output. | ||
using ImageReadMode = int64_t; | ||
// Default for all three decoders: no colour conversion is requested and the | ||
// output keeps the channel count reported by the image itself. | ||
#define IMAGE_READ_MODE_UNCHANGED 0 | ||
// Single-channel grayscale output (handled by both the JPEG and PNG decoders). | ||
#define IMAGE_READ_MODE_GRAY 1 | ||
// Grayscale plus alpha, 2 channels. Handled by decodePNG; in the decodeJPEG | ||
// code shown in this change it falls into the "Provided mode not supported" branch. | ||
#define IMAGE_READ_MODE_GRAY_ALPHA 2 | ||
// Three-channel RGB output (handled by both the JPEG and PNG decoders). | ||
#define IMAGE_READ_MODE_RGB 3 | ||
// RGB plus alpha, 4 channels. Handled by decodePNG; in the decodeJPEG code | ||
// shown in this change it falls into the "Provided mode not supported" branch. | ||
#define IMAGE_READ_MODE_RGB_ALPHA 4 | ||
Review comment: nit: add a newline at the end of the file.
Reply: I'll add it along with the other proposed corrections to minimize CI runs.
Review comment on lines +5 to +9: nit: this is fine as is, but I wonder if a more modern pattern now is to use [remainder of the comment was lost when the page was extracted] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
#pragma once | ||
|
||
#include "readjpeg_cpu.h" | ||
#include "readpng_cpu.h" | ||
#include <torch/torch.h> | ||
#include "image_read_mode.h" | ||
|
||
C10_EXPORT torch::Tensor decode_image( | ||
const torch::Tensor& data, | ||
int64_t channels = 0); | ||
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,9 @@ | ||
#include "readjpeg_cpu.h" | ||
|
||
#include <ATen/ATen.h> | ||
#include <string> | ||
|
||
#if !JPEG_FOUND | ||
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) { | ||
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) { | ||
TORCH_CHECK( | ||
false, "decodeJPEG: torchvision not compiled with libjpeg support"); | ||
} | ||
|
@@ -69,16 +68,13 @@ static void torch_jpeg_set_source_mgr( | |
src->pub.next_input_byte = src->data; | ||
} | ||
|
||
torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) { | ||
torch::Tensor decodeJPEG(const torch::Tensor& data, ImageReadMode mode) { | ||
// Check that the input tensor dtype is uint8 | ||
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); | ||
// Check that the input tensor is 1-dimensional | ||
TORCH_CHECK( | ||
data.dim() == 1 && data.numel() > 0, | ||
"Expected a non empty 1-dimensional tensor"); | ||
TORCH_CHECK( | ||
channels == 0 || channels == 1 || channels == 3, | ||
"Number of channels not supported"); | ||
|
||
struct jpeg_decompress_struct cinfo; | ||
struct torch_jpeg_error_mgr jerr; | ||
|
@@ -102,30 +98,33 @@ torch::Tensor decodeJPEG(const torch::Tensor& data, int64_t channels) { | |
// read info from header. | ||
jpeg_read_header(&cinfo, TRUE); | ||
|
||
int current_channels = cinfo.num_components; | ||
int channels = cinfo.num_components; | ||
|
||
if (channels > 0 && channels != current_channels) { | ||
switch (channels) { | ||
case 1: // Gray | ||
cinfo.out_color_space = JCS_GRAYSCALE; | ||
if (mode != IMAGE_READ_MODE_UNCHANGED) { | ||
switch (mode) { | ||
case IMAGE_READ_MODE_GRAY: | ||
if (cinfo.jpeg_color_space != JCS_GRAYSCALE) { | ||
cinfo.out_color_space = JCS_GRAYSCALE; | ||
channels = 1; | ||
} | ||
break; | ||
case 3: // RGB | ||
cinfo.out_color_space = JCS_RGB; | ||
case IMAGE_READ_MODE_RGB: | ||
if (cinfo.jpeg_color_space != JCS_RGB) { | ||
cinfo.out_color_space = JCS_RGB; | ||
channels = 3; | ||
} | ||
break; | ||
/* | ||
* Libjpeg does not support converting from CMYK to grayscale etc. There | ||
* is a way to do this but it involves converting it manually to RGB: | ||
* https://github.com/tensorflow/tensorflow/blob/86871065265b04e0db8ca360c046421efb2bdeb4/tensorflow/core/lib/jpeg/jpeg_mem.cc#L284-L313 | ||
* | ||
*/ | ||
default: | ||
jpeg_destroy_decompress(&cinfo); | ||
TORCH_CHECK(false, "Invalid number of output channels."); | ||
TORCH_CHECK(false, "Provided mode not supported"); | ||
Review comment: Note that if an unsupported conversion operation is requested, for instance CMYK to RGB, this check is not going to trigger. Instead we will get a [remainder of the comment was lost when the page was extracted]
Review comment: Maybe it would be good to be explicit here and mention that it's not supported because the input is JPEG? Otherwise it could be confusing to the user, wdyt? |
||
} | ||
|
||
jpeg_calc_output_dimensions(&cinfo); | ||
} else { | ||
channels = current_channels; | ||
} | ||
|
||
jpeg_start_decompress(&cinfo); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
#pragma once | ||
|
||
#include <torch/torch.h> | ||
#include "image_read_mode.h" | ||
|
||
C10_EXPORT torch::Tensor decodeJPEG( | ||
const torch::Tensor& data, | ||
int64_t channels = 0); | ||
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,22 @@ | ||
#include "readpng_cpu.h" | ||
|
||
// Comment | ||
#include <ATen/ATen.h> | ||
#include <string> | ||
|
||
#if !PNG_FOUND | ||
torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { | ||
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) { | ||
TORCH_CHECK(false, "decodePNG: torchvision not compiled with libPNG support"); | ||
} | ||
#else | ||
#include <png.h> | ||
#include <setjmp.h> | ||
|
||
torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { | ||
torch::Tensor decodePNG(const torch::Tensor& data, ImageReadMode mode) { | ||
// Check that the input tensor dtype is uint8 | ||
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); | ||
// Check that the input tensor is 1-dimensional | ||
TORCH_CHECK( | ||
data.dim() == 1 && data.numel() > 0, | ||
"Expected a non empty 1-dimensional tensor"); | ||
TORCH_CHECK( | ||
channels >= 0 && channels <= 4, "Number of channels not supported"); | ||
|
||
auto png_ptr = | ||
png_create_read_struct(PNG_LIBPNG_VER_STRING, nullptr, nullptr, nullptr); | ||
|
@@ -74,75 +70,85 @@ torch::Tensor decodePNG(const torch::Tensor& data, int64_t channels) { | |
TORCH_CHECK(retval == 1, "Could read image metadata from content.") | ||
} | ||
|
||
int current_channels = png_get_channels(png_ptr, info_ptr); | ||
int channels = png_get_channels(png_ptr, info_ptr); | ||
|
||
if (channels > 0) { | ||
if (mode != IMAGE_READ_MODE_UNCHANGED) { | ||
// TODO: consider supporting PNG_INFO_tRNS | ||
bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0; | ||
bool has_color = (color_type & PNG_COLOR_MASK_COLOR) != 0; | ||
bool has_alpha = (color_type & PNG_COLOR_MASK_ALPHA) != 0; | ||
|
||
switch (channels) { | ||
case 1: // Gray | ||
if (is_palette) { | ||
png_set_palette_to_rgb(png_ptr); | ||
has_alpha = true; | ||
} | ||
|
||
if (has_alpha) { | ||
png_set_strip_alpha(png_ptr); | ||
} | ||
|
||
if (has_color) { | ||
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); | ||
switch (mode) { | ||
case IMAGE_READ_MODE_GRAY: | ||
if (color_type != PNG_COLOR_TYPE_GRAY) { | ||
if (is_palette) { | ||
png_set_palette_to_rgb(png_ptr); | ||
has_alpha = true; | ||
} | ||
|
||
if (has_alpha) { | ||
png_set_strip_alpha(png_ptr); | ||
} | ||
|
||
if (has_color) { | ||
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); | ||
} | ||
channels = 1; | ||
} | ||
break; | ||
case 2: // Gray + Alpha | ||
if (is_palette) { | ||
png_set_palette_to_rgb(png_ptr); | ||
has_alpha = true; | ||
} | ||
|
||
if (!has_alpha) { | ||
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); | ||
} | ||
|
||
if (has_color) { | ||
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); | ||
case IMAGE_READ_MODE_GRAY_ALPHA: | ||
if (color_type != PNG_COLOR_TYPE_GRAY_ALPHA) { | ||
if (is_palette) { | ||
png_set_palette_to_rgb(png_ptr); | ||
has_alpha = true; | ||
} | ||
|
||
if (!has_alpha) { | ||
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); | ||
} | ||
|
||
if (has_color) { | ||
png_set_rgb_to_gray(png_ptr, 1, 0.2989, 0.587); | ||
} | ||
channels = 2; | ||
} | ||
break; | ||
case 3: | ||
if (is_palette) { | ||
png_set_palette_to_rgb(png_ptr); | ||
has_alpha = true; | ||
} else if (!has_color) { | ||
png_set_gray_to_rgb(png_ptr); | ||
} | ||
|
||
if (has_alpha) { | ||
png_set_strip_alpha(png_ptr); | ||
case IMAGE_READ_MODE_RGB: | ||
if (color_type != PNG_COLOR_TYPE_RGB) { | ||
if (is_palette) { | ||
png_set_palette_to_rgb(png_ptr); | ||
has_alpha = true; | ||
} else if (!has_color) { | ||
png_set_gray_to_rgb(png_ptr); | ||
} | ||
|
||
if (has_alpha) { | ||
png_set_strip_alpha(png_ptr); | ||
} | ||
channels = 3; | ||
} | ||
break; | ||
case 4: | ||
if (is_palette) { | ||
png_set_palette_to_rgb(png_ptr); | ||
has_alpha = true; | ||
} else if (!has_color) { | ||
png_set_gray_to_rgb(png_ptr); | ||
} | ||
|
||
if (!has_alpha) { | ||
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); | ||
case IMAGE_READ_MODE_RGB_ALPHA: | ||
if (color_type != PNG_COLOR_TYPE_RGB_ALPHA) { | ||
if (is_palette) { | ||
png_set_palette_to_rgb(png_ptr); | ||
has_alpha = true; | ||
} else if (!has_color) { | ||
png_set_gray_to_rgb(png_ptr); | ||
} | ||
|
||
if (!has_alpha) { | ||
png_set_add_alpha(png_ptr, (1 << bit_depth) - 1, PNG_FILLER_AFTER); | ||
} | ||
channels = 4; | ||
} | ||
break; | ||
default: | ||
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); | ||
TORCH_CHECK(false, "Invalid number of output channels."); | ||
TORCH_CHECK(false, "Provided mode not supported"); | ||
Review comment: Same comment here about the error message: it would be better to specify that this is not supported for PNG, maybe? |
||
} | ||
|
||
png_read_update_info(png_ptr, info_ptr); | ||
} else { | ||
channels = current_channels; | ||
} | ||
|
||
auto tensor = | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,8 @@ | ||
#pragma once | ||
|
||
// Comment | ||
#include <torch/torch.h> | ||
#include <string> | ||
#include "image_read_mode.h" | ||
|
||
C10_EXPORT torch::Tensor decodePNG( | ||
const torch::Tensor& data, | ||
int64_t channels = 0); | ||
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); |
Uh oh!
There was an error while loading. Please reload this page.