Skip to content

Commit

Permalink
conditionally convert model outputs containing session id for NDCG co…
Browse files Browse the repository at this point in the history
…mputation (pytorch#1678)

Summary:
Pull Request resolved: pytorch#1678

This diff converts model outputs only when they contain session_id lists and the model uses the NDCG metric. The conversion is done at the metric level so that model owners do not have to apply the change in each of their models and can instead rely on the metric library to handle it for them.

RecMetrics will now conditionally convert a list of session id strings to a tensor of session ids. The conversion happens only when both conditions hold: 1) NDCG, and therefore the session_key, is specified in the config, and 2) the session id value is a List of strings.

By default, session_key is "session_id"; if the model uses a different key for its session id list, that key must be specified in the config.

Reviewed By: howei

Differential Revision: D53327408

fbshipit-source-id: ecad32a7f375c9f0f9fca176de63416f2cc6bd23
  • Loading branch information
iamzainhuda authored and facebook-github-bot committed Feb 5, 2024
1 parent 73ad8b8 commit f8f6f61
Showing 1 changed file with 37 additions and 2 deletions.
39 changes: 37 additions & 2 deletions torchrec/metrics/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,24 @@
from torchrec.metrics.rec_metric import RecTaskInfo


def session_ids_to_tensor(session_ids: List[str]) -> torch.Tensor:
    """
    Convert a list of session id strings into a tensor of integer group ids
    so the model output can be consumed by the metric computation.

    Each run of consecutive, equal session ids is assigned one integer id.
    Ids start at 1 and increase by 1 whenever the session id differs from the
    previous element. Only adjacent elements are compared, so a session id that
    reappears after a different one starts a new group.

    Args:
        session_ids: session id strings, one per example, in batch order.

    Returns:
        A 1-D int64 tensor with the same length as ``session_ids``. An empty
        input yields an empty tensor.
    """
    group_ids: List[int] = []
    curr_id = 0

    for idx, session in enumerate(session_ids):
        # A new group starts at the first element and at every change of id.
        if idx == 0 or session != session_ids[idx - 1]:
            curr_id += 1
        group_ids.append(curr_id)

    # Explicit dtype keeps the empty case int64 (torch.tensor([]) would
    # otherwise default to a float tensor).
    return torch.tensor(group_ids, dtype=torch.int64)


def is_empty_signals(
labels: torch.Tensor,
predictions: torch.Tensor,
Expand Down Expand Up @@ -63,10 +81,20 @@ def parse_model_outputs(


def parse_required_inputs(
    model_out: Dict[str, torch.Tensor],
    required_inputs_list: List[str],
    ndcg_transform_input: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Collect the required input features from the model output.

    Args:
        model_out: model output, keyed by feature name.
        required_inputs_list: names of the features to extract from
            ``model_out``.
        ndcg_transform_input: when True, any requested feature whose value is
            a ``list`` is first converted with ``session_ids_to_tensor``.
            Values that are not lists are used as-is.

    Returns:
        A dict mapping each requested feature name to its squeezed tensor.

    Raises:
        KeyError: if a requested feature is not present in ``model_out``.
    """
    required_inputs: Dict[str, torch.Tensor] = {}
    for feature in required_inputs_list:
        value = model_out[feature]
        # Convert only features named in the config, and only when they are
        # raw lists. The result is kept local so the caller's model_out is not
        # modified as a side effect.
        if ndcg_transform_input and isinstance(value, list):
            # pyre-ignore[6]
            value = session_ids_to_tensor(value)
        squeezed = value.squeeze()
        # Narrow the type for the checker; squeeze() is expected to return a tensor.
        assert isinstance(squeezed, torch.Tensor)
        required_inputs[feature] = squeezed
    return required_inputs
Expand All @@ -86,6 +114,8 @@ def parse_task_model_outputs(
all_predictions: Dict[str, torch.Tensor] = {}
all_weights: Dict[str, torch.Tensor] = {}
all_required_inputs: Dict[str, torch.Tensor] = {}
# Convert session_ids to tensor if NDCG metric
ndcg_transform_input = False
for task in tasks:
labels, predictions, weights = parse_model_outputs(
task.label_name, task.prediction_name, task.weight_name, model_out
Expand All @@ -99,7 +129,12 @@ def parse_task_model_outputs(
if torch.numel(labels) > 0:
all_labels[task.name] = labels

if task.name and task.name.startswith("ndcg"):
ndcg_transform_input = True

if required_inputs_list is not None:
all_required_inputs = parse_required_inputs(model_out, required_inputs_list)
all_required_inputs = parse_required_inputs(
model_out, required_inputs_list, ndcg_transform_input
)

return all_labels, all_predictions, all_weights, all_required_inputs

0 comments on commit f8f6f61

Please sign in to comment.