From 7b8737aacb241c20e4ddfba881ed78cfacb4521b Mon Sep 17 00:00:00 2001
From: Joey Ballentine <34788790+joeyballentine@users.noreply.github.com>
Date: Wed, 23 Aug 2023 09:23:53 -0400
Subject: [PATCH] Add "Separate Alpha" checkbox to upscale nodes (#2127)

* Add separate alpha option for pytorch upscale

* Add separate alpha for ncnn and onnx

* fix linting

* pr suggestion

* change condition

* reuse alpha thing

* update description
---
 .../nodes/impl/upscale/convenient_upscale.py | 50 ++++++++++++-------
 .../ncnn/processing/upscale_image.py         | 31 ++++++++++--
 .../ncnn/utility/interpolate_models.py       |  2 +-
 .../onnx/processing/upscale_image.py         | 21 +++++++-
 .../onnx/utility/interpolate_models.py       |  2 +-
 .../pytorch/processing/upscale_image.py      | 30 ++++++++++-
 6 files changed, 112 insertions(+), 24 deletions(-)

diff --git a/backend/src/nodes/impl/upscale/convenient_upscale.py b/backend/src/nodes/impl/upscale/convenient_upscale.py
index 8d472af5b..7df176b58 100644
--- a/backend/src/nodes/impl/upscale/convenient_upscale.py
+++ b/backend/src/nodes/impl/upscale/convenient_upscale.py
@@ -20,11 +20,20 @@ def with_black_and_white_backgrounds(img: np.ndarray) -> Tuple[np.ndarray, np.nd
     return black, white
 
 
+def denoise_and_flatten_alpha(img: np.ndarray) -> np.ndarray:
+    alpha_min = np.min(img, axis=2)
+    alpha_max = np.max(img, axis=2)
+    alpha_mean = np.mean(img, axis=2)
+    alpha = alpha_max * alpha_mean + alpha_min * (1 - alpha_mean)
+    return alpha
+
+
 def convenient_upscale(
     img: np.ndarray,
     model_in_nc: int,
     model_out_nc: int,
     upscale: ImageOp,
+    separate_alpha: bool = False,
 ) -> np.ndarray:
     """
     Upscales the given image in an intuitive/convenient way.
@@ -56,23 +65,30 @@ def convenient_upscale(
             unique_alpha = np.full(rgb.shape[:-1], unique[0], np.float32)
             return np.dstack((rgb, unique_alpha))
 
-        # Transparency hack (white/black background difference alpha)
-        black, white = with_black_and_white_backgrounds(img)
-        black_up = as_target_channels(
-            upscale(as_target_channels(black, model_in_nc, True)), 3, True
-        )
-        white_up = as_target_channels(
-            upscale(as_target_channels(white, model_in_nc, True)), 3, True
-        )
-
-        # Interpolate between the alpha values to get a more defined alpha
-        alpha_candidates = 1 - (white_up - black_up)  # type: ignore
-        alpha_min = np.min(alpha_candidates, axis=2)
-        alpha_max = np.max(alpha_candidates, axis=2)
-        alpha_mean = np.mean(alpha_candidates, axis=2)
-        alpha = alpha_max * alpha_mean + alpha_min * (1 - alpha_mean)
-
-        return np.dstack((black_up, alpha))
+        if separate_alpha:
+            # Upscale the RGB channels and alpha channel separately
+            rgb = as_target_channels(
+                upscale(as_target_channels(img[:, :, :3], model_in_nc, True)), 3, True
+            )
+            alpha = denoise_and_flatten_alpha(
+                upscale(as_target_channels(img[:, :, 3], model_in_nc, True))
+            )
+            return np.dstack((rgb, alpha))
+        else:
+            # Transparency hack (white/black background difference alpha)
+            black, white = with_black_and_white_backgrounds(img)
+            black_up = as_target_channels(
+                upscale(as_target_channels(black, model_in_nc, True)), 3, True
+            )
+            white_up = as_target_channels(
+                upscale(as_target_channels(white, model_in_nc, True)), 3, True
+            )
+
+            # Interpolate between the alpha values to get a more defined alpha
+            alpha_candidates = 1 - (white_up - black_up)  # type: ignore
+            alpha = denoise_and_flatten_alpha(alpha_candidates)
+
+            return np.dstack((black_up, alpha))
 
     return as_target_channels(
         upscale(as_target_channels(img, model_in_nc, True)), in_img_c, True
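For context, the alpha math above is easy to reproduce in isolation. The following is a minimal NumPy sketch (separate from the patch itself) of the default, non-separate path: `fake_upscale` is a hypothetical 2x nearest-neighbor stand-in for the model, and the black/white compositing is an assumption, since `with_black_and_white_backgrounds` is not shown in this hunk.

```python
# Standalone sketch, not part of this patch: mirrors denoise_and_flatten_alpha
# and the black/white background trick used by convenient_upscale.
import numpy as np


def denoise_and_flatten_alpha(img: np.ndarray) -> np.ndarray:
    # Blend the per-channel min and max alpha candidates, weighted by their
    # mean: mostly-opaque pixels follow the max, mostly-transparent the min.
    alpha_min = np.min(img, axis=2)
    alpha_max = np.max(img, axis=2)
    alpha_mean = np.mean(img, axis=2)
    return alpha_max * alpha_mean + alpha_min * (1 - alpha_mean)


def fake_upscale(img: np.ndarray) -> np.ndarray:
    # Hypothetical 2x nearest-neighbor "model" standing in for real inference.
    return img.repeat(2, axis=0).repeat(2, axis=1)


rgba = np.random.rand(8, 8, 4).astype(np.float32)
rgb, a = rgba[:, :, :3], rgba[:, :, 3:]

# Composite onto black and white backgrounds, upscale both, and recover alpha
# from their difference (assumed compositing: black = rgb*a, white = rgb*a + (1-a)).
black_up = fake_upscale(rgb * a)
white_up = fake_upscale(rgb * a + (1 - a))
alpha = denoise_and_flatten_alpha(1 - (white_up - black_up))
result = np.dstack((black_up, alpha))
print(result.shape)  # (16, 16, 4)
```

The patch's only change to this path is that the min/max/mean blend, previously inlined, is now shared through `denoise_and_flatten_alpha` so the new separate-alpha path can reuse it.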
diff --git a/backend/src/packages/chaiNNer_ncnn/ncnn/processing/upscale_image.py b/backend/src/packages/chaiNNer_ncnn/ncnn/processing/upscale_image.py
index 8703b68c0..46d1eddf1 100644
--- a/backend/src/packages/chaiNNer_ncnn/ncnn/processing/upscale_image.py
+++ b/backend/src/packages/chaiNNer_ncnn/ncnn/processing/upscale_image.py
@@ -15,6 +15,7 @@ use_gpu = False
 
 from sanic.log import logger
 
+from nodes.groups import Condition, if_group
 from nodes.impl.ncnn.auto_split import ncnn_auto_split
 from nodes.impl.ncnn.model import NcnnModelWrapper
 from nodes.impl.ncnn.session import get_ncnn_net
@@ -25,7 +26,12 @@
 )
 from nodes.impl.upscale.convenient_upscale import convenient_upscale
 from nodes.impl.upscale.tiler import MaxTileSize
-from nodes.properties.inputs import ImageInput, NcnnModelInput, TileSizeDropdown
+from nodes.properties.inputs import (
+    BoolInput,
+    ImageInput,
+    NcnnModelInput,
+    TileSizeDropdown,
+)
 from nodes.properties.outputs import ImageOutput
 from nodes.utils.exec_options import get_execution_options
 from nodes.utils.utils import get_h_w_c
@@ -138,6 +144,25 @@ def estimate_cpu():
            "Generally it's recommended to use the largest tile size possible for best performance (with the ideal scenario being no tiling at all), but depending on the model and image size, this may not be possible.",
            "If you are having issues with the automatic mode, you can manually select a tile size. On certain machines, a very small tile size such as 256 or 128 might be required for it to work at all.",
        ),
+        if_group(
+            Condition.type(1, "Image { channels: 4 } ")
+            & (
+                Condition.type(0, "NcnnNetwork { inputChannels: 1, outputChannels: 1 }")
+                | Condition.type(
+                    0, "NcnnNetwork { inputChannels: 3, outputChannels: 3 }"
+                )
+            )
+        )(
+            BoolInput("Separate Alpha", default=False).with_docs(
+                "Upscale alpha separately from color. Enabling this option will cause the alpha of"
+                " the upscaled image to be less noisy and more accurate to the alpha of the original"
+                " image, but the image may suffer from dark borders near transparency edges"
+                " (transition from fully transparent to fully opaque).",
+                "Whether enabling this option will improve the upscaled image depends on the original"
+                " image. We generally recommend this option for images with smooth transitions between"
+                " transparent and opaque regions.",
+            )
+        ),
     ],
     outputs=[
         ImageOutput(image_type="""convenientUpscale(Input0, Input1)"""),
@@ -145,7 +170,7 @@ def estimate_cpu():
     limited_to_8bpc=True,
 )
 def upscale_image_node(
-    img: np.ndarray, model: NcnnModelWrapper, tile_size: TileSize
+    img: np.ndarray, model: NcnnModelWrapper, tile_size: TileSize, separate_alpha: bool
 ) -> np.ndarray:
     def upscale(i: np.ndarray) -> np.ndarray:
         ic = get_h_w_c(i)[2]
@@ -166,4 +191,4 @@ def upscale(i: np.ndarray) -> np.ndarray:
             i = cv2.cvtColor(i, cv2.COLOR_RGBA2BGRA)
         return i
 
-    return convenient_upscale(img, model.in_nc, model.out_nc, upscale)
+    return convenient_upscale(img, model.in_nc, model.out_nc, upscale, separate_alpha)
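For comparison, the path the new checkbox enables is sketched below (again outside the node framework), with the same hypothetical 2x nearest-neighbor `fake_upscale` standing in for the NCNN model: color and alpha are upscaled in two independent passes and recombined. In the actual node, the upscaled alpha is additionally passed through `denoise_and_flatten_alpha` before being stacked back onto the color channels.

```python
# Standalone sketch, not part of this patch: the "Separate Alpha" path.
import numpy as np


def fake_upscale(img: np.ndarray) -> np.ndarray:
    # Hypothetical 2x nearest-neighbor stand-in for the real model.
    return img.repeat(2, axis=0).repeat(2, axis=1)


rgba = np.random.rand(8, 8, 4).astype(np.float32)

# Color and alpha never see each other: the model upscales the RGB planes and
# the alpha plane (treated as a grayscale image) in two independent passes.
rgb_up = fake_upscale(rgba[:, :, :3])
alpha_up = fake_upscale(rgba[:, :, 3])
result = np.dstack((rgb_up, alpha_up))
print(result.shape)  # (16, 16, 4)
```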
diff --git a/backend/src/packages/chaiNNer_ncnn/ncnn/utility/interpolate_models.py b/backend/src/packages/chaiNNer_ncnn/ncnn/utility/interpolate_models.py
index 115f47871..c54c849b1 100644
--- a/backend/src/packages/chaiNNer_ncnn/ncnn/utility/interpolate_models.py
+++ b/backend/src/packages/chaiNNer_ncnn/ncnn/utility/interpolate_models.py
@@ -15,7 +15,7 @@ def check_will_upscale(interp: NcnnModelWrapper):
     fake_img = np.ones((3, 3, 3), dtype=np.float32, order="F")
-    result = upscale_image_node(fake_img, interp, NO_TILING)
+    result = upscale_image_node(fake_img, interp, NO_TILING, False)
 
     mean_color = np.mean(result)
     del result
diff --git a/backend/src/packages/chaiNNer_onnx/onnx/processing/upscale_image.py b/backend/src/packages/chaiNNer_onnx/onnx/processing/upscale_image.py
index 4a99b610c..5c677f39b 100644
--- a/backend/src/packages/chaiNNer_onnx/onnx/processing/upscale_image.py
+++ b/backend/src/packages/chaiNNer_onnx/onnx/processing/upscale_image.py
@@ -6,6 +6,7 @@
 import onnxruntime as ort
 from sanic.log import logger
 
+from nodes.groups import Condition, if_group
 from nodes.impl.onnx.auto_split import onnx_auto_split
 from nodes.impl.onnx.model import OnnxModel
 from nodes.impl.onnx.session import get_onnx_session
@@ -17,7 +18,12 @@
 )
 from nodes.impl.upscale.convenient_upscale import convenient_upscale
 from nodes.impl.upscale.tiler import ExactTileSize
-from nodes.properties.inputs import ImageInput, OnnxGenericModelInput, TileSizeDropdown
+from nodes.properties.inputs import (
+    BoolInput,
+    ImageInput,
+    OnnxGenericModelInput,
+    TileSizeDropdown,
+)
 from nodes.properties.outputs import ImageOutput
 from nodes.utils.exec_options import get_execution_options
 from nodes.utils.utils import get_h_w_c
@@ -70,6 +76,17 @@ def estimate():
            "ONNX upscaling does not support an automatic mode, meaning you may need to"
            " manually select a tile size for it to work.",
        ),
+        if_group(Condition.type(1, "Image { channels: 4 } "))(
+            BoolInput("Separate Alpha", default=False).with_docs(
+                "Upscale alpha separately from color. Enabling this option will cause the alpha of"
+                " the upscaled image to be less noisy and more accurate to the alpha of the original"
+                " image, but the image may suffer from dark borders near transparency edges"
+                " (transition from fully transparent to fully opaque).",
+                "Whether enabling this option will improve the upscaled image depends on the original"
+                " image. We generally recommend this option for images with smooth transitions between"
+                " transparent and opaque regions.",
+            )
+        ),
     ],
     outputs=[ImageOutput("Image")],
     name="Upscale Image",
@@ -79,6 +96,7 @@ def upscale_image_node(
     img: np.ndarray,
     model: OnnxModel,
     tile_size: TileSize,
+    separate_alpha: bool,
 ) -> np.ndarray:
     """Upscales an image with a pretrained model"""
     session = get_onnx_session(model, get_execution_options())
@@ -101,4 +119,5 @@
         in_nc,
         out_nc,
         lambda i: upscale(i, session, tile_size, change_shape, exact_size),
+        separate_alpha,
     )
diff --git a/backend/src/packages/chaiNNer_onnx/onnx/utility/interpolate_models.py b/backend/src/packages/chaiNNer_onnx/onnx/utility/interpolate_models.py
index a296e4db5..5cefa7e6e 100644
--- a/backend/src/packages/chaiNNer_onnx/onnx/utility/interpolate_models.py
+++ b/backend/src/packages/chaiNNer_onnx/onnx/utility/interpolate_models.py
@@ -49,7 +49,7 @@ def perform_interp(
 def check_will_upscale(model: OnnxModel):
     fake_img = np.ones((3, 3, 3), dtype=np.float32, order="F")
-    result = upscale_image_node(fake_img, model, NO_TILING)
+    result = upscale_image_node(fake_img, model, NO_TILING, False)
 
     mean_color = np.mean(result)
     del result
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py
index 7a89a9fbf..5f437f828 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py
@@ -18,7 +18,12 @@
 )
 from nodes.impl.upscale.convenient_upscale import convenient_upscale
 from nodes.impl.upscale.tiler import MaxTileSize
-from nodes.properties.inputs import ImageInput, SrModelInput, TileSizeDropdown
+from nodes.properties.inputs import (
+    BoolInput,
+    ImageInput,
+    SrModelInput,
+    TileSizeDropdown,
+)
 from nodes.properties.outputs import ImageOutput
 from nodes.utils.exec_options import ExecutionOptions, get_execution_options
 from nodes.utils.utils import get_h_w_c
@@ -107,6 +112,27 @@ def estimate():
                hint=True,
            )
        ),
+        if_group(
+            Condition.type(1, "Image { channels: 4 } ")
+            & (
+                Condition.type(
+                    0, "PyTorchModel { inputChannels: 1, outputChannels: 1 }"
+                )
+                | Condition.type(
+                    0, "PyTorchModel { inputChannels: 3, outputChannels: 3 }"
+                )
+            )
+        )(
+            BoolInput("Separate Alpha", default=False).with_docs(
+                "Upscale alpha separately from color. Enabling this option will cause the alpha of"
+                " the upscaled image to be less noisy and more accurate to the alpha of the original"
+                " image, but the image may suffer from dark borders near transparency edges"
+                " (transition from fully transparent to fully opaque).",
+                "Whether enabling this option will improve the upscaled image depends on the original"
+                " image. We generally recommend this option for images with smooth transitions between"
+                " transparent and opaque regions.",
+            )
+        ),
     ],
     outputs=[
         ImageOutput(
@@ -119,6 +145,7 @@ def upscale_image_node(
     img: np.ndarray,
     model: PyTorchSRModel,
     tile_size: TileSize,
+    separate_alpha: bool,
 ) -> np.ndarray:
     """Upscales an image with a pretrained model"""
 
@@ -141,4 +168,5 @@
         in_nc,
         out_nc,
         lambda i: upscale(i, model, tile_size, exec_options),
+        separate_alpha,
     )
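With the patch applied, all three nodes simply forward the new flag to `convenient_upscale`, and the interpolation checks pass `False` explicitly. A rough usage sketch of the extended signature follows, assuming the chaiNNer backend (`backend/src`) is on the import path and substituting a toy 2x op for real model inference; it is illustrative only, not how the nodes are invoked in practice.

```python
# Usage sketch only: assumes backend/src is importable; nearest_2x is a toy
# stand-in for the lambda each node builds around its real inference call.
import numpy as np

from nodes.impl.upscale.convenient_upscale import convenient_upscale


def nearest_2x(i: np.ndarray) -> np.ndarray:
    # Toy 2x nearest-neighbor "model" (handles both 2D and 3D arrays).
    return i.repeat(2, axis=0).repeat(2, axis=1)


img = np.random.rand(32, 32, 4).astype(np.float32)  # RGBA input, 3-channel "model"

# Default path (what check_will_upscale now requests with an explicit False).
combined = convenient_upscale(img, 3, 3, nearest_2x, separate_alpha=False)

# Path enabled by the new "Separate Alpha" checkbox.
separate = convenient_upscale(img, 3, 3, nearest_2x, separate_alpha=True)

print(combined.shape, separate.shape)  # expected: (64, 64, 4) (64, 64, 4)
```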