Skip to content

Commit

Permalink
Feat/support weights passed from dataframe
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Dec 17, 2024
1 parent 58f98a8 commit f9263d6
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ def __init__(
self.seq_len = torch.tensor(
dataframe["tensor"].map(len).values, dtype=torch.long
)
if "weights" in dataframe.columns:
self.weights = torch.tensor(dataframe["weights"].values, dtype=torch.float)
else:
self.weights = torch.ones(len(dataframe), dtype=torch.float)
length = len(dataframe)
batch_num, remainder = divmod(length, max(1, batch_size))
self.batch_num = batch_num + 1 if remainder > 0 else batch_num
Expand All @@ -256,6 +260,7 @@ def __init__(
self.t_train[start_index:end_index].to(device),
self.y_train[start_index:end_index].to(device),
seq_lens.to(device),
self.weights[start_index:end_index].to(device),
)

def __getitem__(self, idx):
Expand Down

0 comments on commit f9263d6

Please sign in to comment.