File tree 1 file changed +16
-6
lines changed
1 file changed +16
-6
lines changed Original file line number Diff line number Diff line change @@ -227,9 +227,14 @@ def update(
227
227
RecComputeMode .FUSED_TASKS_AND_STATES_COMPUTATION ,
228
228
]:
229
229
if not isinstance (labels , torch .Tensor ):
230
- raise RecMetricException (
231
- "Fused computation only support where 'labels' is a tensor"
232
- )
230
+ try :
231
+ labels = torch .stack (
232
+ [labels [task .name ] for task in self ._tasks ]
233
+ )
234
+ except Exception as e :
235
+ raise RecMetricException (
236
+ f"Failed to convert labels to tensor for fused computation: { e } "
237
+ )
233
238
labels = labels .view (- 1 , self ._batch_size )
234
239
if self ._should_validate_update :
235
240
# Set the default value to be all True. When weights is None, it's considered
@@ -241,9 +246,14 @@ def update(
241
246
)
242
247
if weights is not None :
243
248
if not isinstance (weights , torch .Tensor ):
244
- raise RecMetricException (
245
- "Fused computation only support where 'weights' is a tensor"
246
- )
249
+ try :
250
+ weights = torch .stack (
251
+ [weights [task .name ] for task in self ._tasks ]
252
+ )
253
+ except Exception as e :
254
+ raise RecMetricException (
255
+ f"Failed to convert weights to tensor for fused computation: { e } "
256
+ )
247
257
has_valid_weights = torch .gt (
248
258
torch .count_nonzero (
249
259
weights .view (- 1 , self ._batch_size ), dim = - 1
You can’t perform that action at this time.
0 commit comments