From e696c090db9ed950424d9ba6a52a9ff60c1db637 Mon Sep 17 00:00:00 2001 From: Wessel Bruinsma Date: Sun, 21 Apr 2024 11:18:30 +0200 Subject: [PATCH] Don't use `groups` for transposed convs for Keras --- neuralprocesses/tensorflow/nn.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/neuralprocesses/tensorflow/nn.py b/neuralprocesses/tensorflow/nn.py index 1956a17f..da8c42d9 100644 --- a/neuralprocesses/tensorflow/nn.py +++ b/neuralprocesses/tensorflow/nn.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Optional, Union @@ -119,6 +120,16 @@ def ConvNd( else: suffix = "" + if groups > 1: + if transposed: + warnings.warn( + "Keras does not depthwise separable transposed convolutions! " + "Using non-separable convolutions for the transposed convolutions. " + "This could be a LOT more expensive." + ) + else: + additional_args["groups"] = groups + conv_layer = getattr(tf.keras.layers, f"Conv{dim}D{suffix}")( input_shape=(in_channels,) + (None,) * dim, filters=out_channels, @@ -126,7 +137,6 @@ def ConvNd( strides=stride, padding="same", dilation_rate=dilation, - groups=groups, use_bias=bias, data_format=data_format, dtype=dtype,