@@ -184,43 +184,91 @@ def testConvolutionalWeightsCA(self, clustering_centroids, pulling_indices,
184184 self ._check_pull_values (clustering_algo , pulling_indices , expected_output )
185185
186186 @parameterized .parameters (
187- ([[0. , 1 , 2 ], [3 , 4 , 5 ]],
188- [[[[0 ], [0 ]], [[0 ], [1 ]]],
189- [[[0 ], [2 ]], [[1 ], [0 ]]]],
190- [[[[0 ], [0 ]], [[0 ], [0 ]]],
191- [[[0 ], [0 ]], [[1 ], [1 ]]]]))
187+ ("channels_last" ,
188+ [[1 , 2 ], [3 , 4 ], [5 , 6 ]], # 3 channels and 2 cluster per channel
189+ # pulling indices has shape (2, 2, 1, 3)
190+ [[[[0 , 1 , 0 ]], [[0 , 1 , 1 ]]], [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]],
191+ [[[[1 , 4 , 5 ]], [[1 , 4 , 6 ]]], [[[2 , 3 , 6 ]], [[1 , 4 , 5 ]]]]),
192+ ("channels_first" ,
193+ [[1 , 2 ], [3 , 4 ], [4 , 5 ], [6 , 7 ]], # 4 channels and 2 clusters per channel
194+ # pulling indices has shape (1, 4, 2, 2)
195+ [[[[0 , 1 ], [1 , 1 ]], [[0 , 0 ], [0 , 1 ]],
196+ [[1 , 0 ], [0 , 0 ]], [[1 , 1 ], [0 , 0 ]]]],
197+ [[[[1 , 2 ], [2 , 2 ]], [[3 , 3 ], [3 , 4 ]],
198+ [[5 , 4 ], [4 , 4 ]], [[7 , 7 ], [6 , 6 ]]]])
199+ )
192200 def testConvolutionalWeightsPerChannelCA (self ,
201+ data_format ,
193202 clustering_centroids ,
194203 pulling_indices ,
195204 expected_output ):
196- """Verifies that PerChannelCA works as expected."""
205+ """Verifies that get_clustered_weight function works as expected."""
197206 clustering_centroids = tf .Variable (clustering_centroids , dtype = tf .float32 )
198- clustering_algo = clustering_registry .PerChannelCA (
199- clustering_centroids , GradientAggregation .SUM
207+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
208+ clustering_centroids , GradientAggregation .SUM , data_format
200209 )
210+ # Note that clustered_weights has the same shape as pulling_indices,
211+ # because they are defined inside of the check function.
201212 self ._check_pull_values (clustering_algo , pulling_indices , expected_output )
202213
214+ @parameterized .parameters (
215+ ("channels_last" ,
216+ [[1 , 2 ], [3 , 4 ], [5 , 6 ]], # 3 channels and 2 cluster per channel
217+ # weight has shape (2, 2, 1, 3)
218+ [[[[1.1 , 3.2 , 5.2 ]], [[2.0 , 4.1 , 5.2 ]]],
219+ [[[2.1 , 2. , 6.1 ]], [[1. , 5. , 5. ]]]],
220+ # expected pulling indices
221+ [[[[0 , 0 , 0 ]], [[1 , 1 , 0 ]]], [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]]),
222+ ("channels_first" ,
223+ # 4 channels and 2 clusters per channel
224+ [[1 , 2 ], [3 , 4 ], [4 , 5 ], [6 , 7 ]],
225+ # weight has shape (1, 4, 2, 2)
226+ [[[[0.1 , 1.5 ], [2.0 , 1.1 ]], [[0. , 3.5 ], [4.4 , 4. ]],
227+ [[4.1 , 4.2 ], [5.3 , 6. ]], [[7. , 7.1 ], [6.1 , 5.8 ]]]],
228+ # expected pulling indices
229+ [[[[0 , 0 ], [1 , 0 ]], [[0 , 0 ], [1 , 1 ]],
230+ [[0 , 0 ], [1 , 1 ]], [[1 , 1 ], [0 , 0 ]]]])
231+ )
232+ def testConvolutionalPullingIndicesPerChannelCA (self ,
233+ data_format ,
234+ clustering_centroids ,
235+ weight ,
236+ expected_output ):
237+ """Verifies that get_pulling_indices function works as expected."""
238+ clustering_centroids = tf .Variable (clustering_centroids , dtype = tf .float32 )
239+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
240+ clustering_centroids , GradientAggregation .SUM , data_format
241+ )
242+ weight = tf .convert_to_tensor (weight )
243+ pulling_indices = clustering_algo .get_pulling_indices (weight )
244+
245+ # check that pulling_indices has the same shape as weight
246+ self .assertEqual (pulling_indices .shape , weight .shape )
247+ self .assertAllEqual (pulling_indices , expected_output )
248+
203249 @parameterized .parameters (
204250 (GradientAggregation .AVG ,
205- [[[[0 ], [0 ]], [[0 ], [1 ]]],
206- [[[0 ], [2 ]], [[1 ], [0 ]]]], [[1 , 1 , 0 ], [1 , 1 , 1 ]]),
251+ [[[[0 , 0 , 0 ]], [[1 , 1 , 0 ]]],
252+ [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]],
253+ [[1 , 1 ], [1 , 1 ], [1 , 1 ]]),
207254 (GradientAggregation .SUM ,
208- [[[[0 ], [0 ]], [[0 ], [1 ]]],
209- [[[0 ], [2 ]], [[1 ], [0 ]]]], [[3 , 1 , 0 ], [2 , 1 , 1 ]])
255+ [[[[0 , 0 , 0 ]], [[1 , 1 , 0 ]]],
256+ [[[1 , 0 , 1 ]], [[0 , 1 , 0 ]]]],
257+ [[2 , 2 ], [2 , 2 ], [3 , 1 ]])
210258 )
211- def testConvolutionalPerChannelCAGrad (self ,
259+ def testConvolutionalPerChannelCAGradChannelsLast (self ,
212260 cluster_gradient_aggregation ,
213261 pulling_indices ,
214262 expected_grad_centroids ):
215- """Verifies that the gradients of convolutional layer work as expected ."""
263+ """Verifies that the gradients of convolutional layer works ."""
216264
217- clustering_centroids = tf .Variable ([[0. , 1 , 2 ], [3 , 4 , 5 ]],
265+ clustering_centroids = tf .Variable ([[1 , 2 ], [3 , 4 ], [ 5 , 6 ]],
218266 dtype = tf .float32 )
219- weight = tf .constant ([[[[0 .1 , 3.0 ]], [[0.2 , 0.1 ]]],
220- [[[0 .1 , 3.0 ]], [[0.2 , 0.1 ]]]])
267+ weight = tf .constant ([[[[1 .1 , 3.2 , 5.2 ]], [[2.0 , 4.1 , 5.2 ]]],
268+ [[[2 .1 , 2. , 6.1 ]], [[1. , 5. , 5. ]]]])
221269
222- clustering_algo = clustering_registry .PerChannelCA (
223- clustering_centroids , cluster_gradient_aggregation
270+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
271+ clustering_centroids , cluster_gradient_aggregation , "channels_last"
224272 )
225273 self ._check_gradients_clustered_weight (
226274 clustering_algo ,
@@ -229,6 +277,37 @@ def testConvolutionalPerChannelCAGrad(self,
229277 expected_grad_centroids ,
230278 )
231279
280+ @parameterized .parameters (
281+ (GradientAggregation .AVG ,
282+ [[[[0 , 0 ], [1 , 0 ]], [[0 , 0 ], [1 , 1 ]],
283+ [[0 , 0 ], [1 , 1 ]], [[1 , 1 ], [0 , 0 ]]]],
284+ [[1 , 1 ], [1 , 1 ], [1 , 1 ], [1 , 1 ]]),
285+ (GradientAggregation .SUM ,
286+ [[[[0 , 0 ], [1 , 0 ]], [[0 , 0 ], [1 , 1 ]],
287+ [[0 , 0 ], [1 , 1 ]], [[1 , 1 ], [0 , 0 ]]]],
288+ [[3 , 1 ], [2 , 2 ], [2 , 2 ], [2 , 2 ]])
289+ )
290+ def testConvolutionalPerChannelCAGradChannelsFirst (self ,
291+ cluster_gradient_aggregation ,
292+ pulling_indices ,
293+ expected_grad_centroids ):
294+ """Verifies that the gradients of convolutional layer works."""
295+
296+ clustering_centroids = tf .Variable ([[1 , 2 ], [3 , 4 ], [4 , 5 ], [6 , 7 ]],
297+ dtype = tf .float32 )
298+ weight = tf .constant ([[[[0.1 , 1.5 ], [2.0 , 1.1 ]],
299+ [[0. , 3.5 ], [4.4 , 4. ]], [[4.1 , 4.2 ], [5.3 , 6. ]],
300+ [[7. , 7.1 ], [6.1 , 5.8 ]]]])
301+
302+ clustering_algo = clustering_registry .ClusteringAlgorithmPerChannel (
303+ clustering_centroids , cluster_gradient_aggregation , "channels_first"
304+ )
305+ self ._check_gradients_clustered_weight (
306+ clustering_algo ,
307+ weight ,
308+ pulling_indices ,
309+ expected_grad_centroids ,
310+ )
232311
233312class CustomLayer (layers .Layer ):
234313 """A custom non-clusterable layer class."""
0 commit comments