Skip to content

Commit 5f23e70

Browse files
HugeEnginefacebook-github-bot
authored andcommitted
Support RecMetrics new options in Pyper config (#2987)
Summary: Fix the issue in tower QPS metric update issue in P1806989639 and discussed in [link](https://fb.workplace.com/groups/527654686243695/posts/989373686738457/?comment_id=990338486641977&reply_comment_id=990584749950684), and expose the new TorchMetric fusion [feature](https://fb.workplace.com/groups/527654686243695/permalink/989373686738457) to Pyper users (default is disabled) Differential Revision: D75120928
1 parent 4b5b5e2 commit 5f23e70

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

torchrec/metrics/tower_qps.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,14 @@ def update(
227227
RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION,
228228
]:
229229
if not isinstance(labels, torch.Tensor):
230-
raise RecMetricException(
231-
"Fused computation only support where 'labels' is a tensor"
232-
)
230+
try:
231+
labels = torch.stack(
232+
[labels[task.name] for task in self._tasks]
233+
)
234+
except Exception as e:
235+
raise RecMetricException(
236+
f"Failed to convert labels to tensor for fused computation: {e}"
237+
)
233238
labels = labels.view(-1, self._batch_size)
234239
if self._should_validate_update:
235240
# Set the default value to be all True. When weights is None, it's considered
@@ -241,9 +246,14 @@ def update(
241246
)
242247
if weights is not None:
243248
if not isinstance(weights, torch.Tensor):
244-
raise RecMetricException(
245-
"Fused computation only support where 'weights' is a tensor"
246-
)
249+
try:
250+
weights = torch.stack(
251+
[weights[task.name] for task in self._tasks]
252+
)
253+
except Exception as e:
254+
raise RecMetricException(
255+
f"Failed to convert weights to tensor for fused computation: {e}"
256+
)
247257
has_valid_weights = torch.gt(
248258
torch.count_nonzero(
249259
weights.view(-1, self._batch_size), dim=-1

0 commit comments

Comments
 (0)