Skip to content

Commit 249c8ce

Browse files
fix bug:cannot run on cuda device
1 parent 80d6fc8 commit 249c8ce

File tree

8 files changed

+36
-9
lines changed

8 files changed

+36
-9
lines changed

EduCDM/DINA/GD/DINA.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def __init__(self, user_num, item_num, hidden_dim, ste=False):
8585
self.dina_net = DINANet(user_num, item_num, hidden_dim)
8686

8787
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
88+
self.dina_net = self.dina_net.to(device)
8889
loss_function = nn.BCELoss()
8990

9091
trainer = torch.optim.Adam(self.dina_net.parameters(), lr)
@@ -109,10 +110,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
109110
print("[Epoch %d] LogisticLoss: %.6f" % (e, float(np.mean(losses))))
110111

111112
if test_data is not None:
112-
auc, accuracy = self.eval(test_data)
113+
auc, accuracy = self.eval(test_data, device=device)
113114
print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))
114115

115116
def eval(self, test_data, device="cpu") -> tuple:
117+
self.dina_net = self.dina_net.to(device)
116118
self.dina_net.eval()
117119
y_pred = []
118120
y_true = []

EduCDM/IRR/DINA.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def __init__(self, user_num, item_num, knowledge_num, ste=False, zeta=0.5):
1717
self.zeta = zeta
1818

1919
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
20+
self.dina_net = self.dina_net.to(device)
2021
point_loss_function = nn.BCELoss()
2122
pair_loss_function = PairSCELoss()
2223
loss_function = HarmonicLoss(self.zeta)
@@ -32,6 +33,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
3233
user_id: torch.Tensor = user_id.to(device)
3334
item_id: torch.Tensor = item_id.to(device)
3435
knowledge: torch.Tensor = knowledge.to(device)
36+
n_samples: torch.Tensor = n_samples.to(device)
3537
predicted_pos_score: torch.Tensor = self.dina_net(user_id, item_id, knowledge)
3638
score: torch.Tensor = score.to(device)
3739
neg_score = 1 - score
@@ -40,6 +42,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
4042
predicted_neg_scores = []
4143
if neg_users:
4244
for neg_user in neg_users:
45+
neg_user: torch.Tensor = neg_user.to(device)
4346
predicted_neg_score = self.dina_net(neg_user, item_id, knowledge)
4447
predicted_neg_scores.append(predicted_neg_score)
4548

@@ -75,10 +78,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
7578
)
7679

7780
if test_data is not None:
78-
eval_data = self.eval(test_data)
81+
eval_data = self.eval(test_data, device=device)
7982
print("[Epoch %d]\n%s" % (e, eval_data))
8083

8184
def eval(self, test_data, device="cpu"):
85+
self.dina_net = self.dina_net.to(device)
8286
self.dina_net.eval()
8387
y_pred = []
8488
y_true = []
@@ -87,6 +91,7 @@ def eval(self, test_data, device="cpu"):
8791
user_id, item_id, knowledge, response = batch_data
8892
user_id: torch.Tensor = user_id.to(device)
8993
item_id: torch.Tensor = item_id.to(device)
94+
knowledge: torch.Tensor = knowledge.to(device)
9095
pred: torch.Tensor = self.dina_net(user_id, item_id, knowledge)
9196
y_pred.extend(pred.tolist())
9297
y_true.extend(response.tolist())

EduCDM/IRR/IRT.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414

1515

1616
class IRT(PointIRT):
17-
def __init__(self, user_num, item_num, knowledge_num, value_range=10, zeta=0.5):
18-
super(IRT, self).__init__(user_num, item_num, value_range=value_range)
17+
def __init__(self, user_num, item_num, knowledge_num, zeta=0.5):
18+
super(IRT, self).__init__(user_num, item_num)
1919
self.knowledge_num = knowledge_num
2020
self.zeta = zeta
2121

2222
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
23+
self.irt_net = self.irt_net.to(device)
2324
point_loss_function = nn.BCELoss()
2425
pair_loss_function = PairSCELoss()
2526
loss_function = HarmonicLoss(self.zeta)
@@ -34,6 +35,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
3435
user_id, item_id, _, score, n_samples, *neg_users = batch_data
3536
user_id: torch.Tensor = user_id.to(device)
3637
item_id: torch.Tensor = item_id.to(device)
38+
n_samples: torch.Tensor = n_samples.to(device)
3739
predicted_pos_score: torch.Tensor = self.irt_net(user_id, item_id)
3840
score: torch.Tensor = score.to(device)
3941
neg_score = 1 - score
@@ -42,6 +44,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
4244
predicted_neg_scores = []
4345
if neg_users:
4446
for neg_user in neg_users:
47+
neg_user: torch.Tensor = neg_user.to(device)
4548
predicted_neg_score = self.irt_net(neg_user, item_id)
4649
predicted_neg_scores.append(predicted_neg_score)
4750

@@ -77,10 +80,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
7780
)
7881

7982
if test_data is not None:
80-
eval_data = self.eval(test_data)
83+
eval_data = self.eval(test_data, device=device)
8184
print("[Epoch %d]\n%s" % (e, eval_data))
8285

8386
def eval(self, test_data, device="cpu"):
87+
self.irt_net = self.irt_net.to(device)
8488
self.irt_net.eval()
8589
y_pred = []
8690
y_true = []

EduCDM/IRR/MIRT.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(self, user_num, item_num, knowledge_num, latent_dim=None, zeta=0.5)
2222
self.zeta = zeta
2323

2424
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
25+
self.irt_net = self.irt_net.to(device)
2526
point_loss_function = nn.BCELoss()
2627
pair_loss_function = PairSCELoss()
2728
loss_function = HarmonicLoss(self.zeta)
@@ -36,6 +37,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
3637
user_id, item_id, _, score, n_samples, *neg_users = batch_data
3738
user_id: torch.Tensor = user_id.to(device)
3839
item_id: torch.Tensor = item_id.to(device)
40+
n_samples: torch.Tensor = n_samples.to(device)
3941
predicted_pos_score: torch.Tensor = self.irt_net(user_id, item_id)
4042
score: torch.Tensor = score.to(device)
4143
neg_score = 1 - score
@@ -44,6 +46,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
4446
predicted_neg_scores = []
4547
if neg_users:
4648
for neg_user in neg_users:
49+
neg_user: torch.Tensor = neg_user.to(device)
4750
predicted_neg_score = self.irt_net(neg_user, item_id)
4851
predicted_neg_scores.append(predicted_neg_score)
4952

@@ -79,10 +82,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
7982
)
8083

8184
if test_data is not None:
82-
eval_data = self.eval(test_data)
85+
eval_data = self.eval(test_data, device=device)
8386
print("[Epoch %d]\n%s" % (e, eval_data))
8487

8588
def eval(self, test_data, device="cpu"):
89+
self.irt_net = self.irt_net.to(device)
8690
self.irt_net.eval()
8791
y_pred = []
8892
y_true = []

EduCDM/IRR/NCDM.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def __init__(self, user_num, item_num, knowledge_num, zeta=0.5):
1717
self.zeta = zeta
1818

1919
def train(self, train_data, test_data=None, epoch=10, device="cpu", lr=0.002, silence=False) -> ...:
20+
self.ncdm_net = self.ncdm_net.to(device)
2021
point_loss_function = nn.BCELoss()
2122
pair_loss_function = PairSCELoss()
2223
loss_function = HarmonicLoss(self.zeta)
@@ -32,6 +33,7 @@ def train(self, train_data, test_data=None, epoch=10, device="cpu", lr=0.002, si
3233
user_id: torch.Tensor = user_id.to(device)
3334
item_id: torch.Tensor = item_id.to(device)
3435
knowledge: torch.Tensor = knowledge.to(device)
36+
n_samples: torch.Tensor = n_samples.to(device)
3537
predicted_pos_score: torch.Tensor = self.ncdm_net(user_id, item_id, knowledge)
3638
score: torch.Tensor = score.to(device)
3739
neg_score = 1 - score
@@ -40,6 +42,7 @@ def train(self, train_data, test_data=None, epoch=10, device="cpu", lr=0.002, si
4042
predicted_neg_scores = []
4143
if neg_users:
4244
for neg_user in neg_users:
45+
neg_user: torch.Tensor = neg_user.to(device)
4346
predicted_neg_score = self.ncdm_net(neg_user, item_id, knowledge)
4447
predicted_neg_scores.append(predicted_neg_score)
4548

@@ -75,10 +78,11 @@ def train(self, train_data, test_data=None, epoch=10, device="cpu", lr=0.002, si
7578
)
7679

7780
if test_data is not None:
78-
eval_data = self.eval(test_data)
81+
eval_data = self.eval(test_data, device=device)
7982
print("[Epoch %d]\n%s" % (e, eval_data))
8083

8184
def eval(self, test_data, device="cpu"):
85+
self.ncdm_net = self.ncdm_net.to(device)
8286
self.ncdm_net.eval()
8387
y_pred = []
8488
y_true = []
@@ -87,6 +91,7 @@ def eval(self, test_data, device="cpu"):
8791
user_id, item_id, knowledge, response = batch_data
8892
user_id: torch.Tensor = user_id.to(device)
8993
item_id: torch.Tensor = item_id.to(device)
94+
knowledge: torch.Tensor = knowledge.to(device)
9095
pred: torch.Tensor = self.ncdm_net(user_id, item_id, knowledge)
9196
y_pred.extend(pred.tolist())
9297
y_true.extend(response.tolist())

EduCDM/IRT/GD/IRT.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(self, user_num, item_num, value_range=None, a_range=None):
5353
self.irt_net = IRTNet(user_num, item_num, value_range, a_range)
5454

5555
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
56+
self.irt_net = self.irt_net.to(device)
5657
loss_function = nn.BCELoss()
5758

5859
trainer = torch.optim.Adam(self.irt_net.parameters(), lr)
@@ -76,10 +77,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
7677
print("[Epoch %d] LogisticLoss: %.6f" % (e, float(np.mean(losses))))
7778

7879
if test_data is not None:
79-
auc, accuracy = self.eval(test_data)
80+
auc, accuracy = self.eval(test_data, device=device)
8081
print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))
8182

8283
def eval(self, test_data, device="cpu") -> tuple:
84+
self.irt_net = self.irt_net.to(device)
8385
self.irt_net.eval()
8486
y_pred = []
8587
y_true = []

EduCDM/MCD/MCD.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def __init__(self, user_num, item_num, latent_dim):
3636
self.mf_net = MFNet(user_num, item_num, latent_dim)
3737

3838
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
39+
self.mf_net = self.mf_net.to(device)
40+
3941
loss_function = nn.BCELoss()
4042

4143
trainer = torch.optim.Adam(self.mf_net.parameters(), lr)
@@ -63,6 +65,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
6365
print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))
6466

6567
def eval(self, test_data, device="cpu") -> tuple:
68+
self.mf_net = self.mf_net.to(device)
6669
self.mf_net.eval()
6770
y_pred = []
6871
y_true = []

EduCDM/MIRT/MIRT.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(self, user_num, item_num, latent_dim, a_range=None):
7575
self.irt_net = MIRTNet(user_num, item_num, latent_dim, a_range)
7676

7777
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
78+
self.irt_net = self.irt_net.to(device)
7879
loss_function = nn.BCELoss()
7980

8081
trainer = torch.optim.Adam(self.irt_net.parameters(), lr)
@@ -98,10 +99,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
9899
print("[Epoch %d] LogisticLoss: %.6f" % (e, float(np.mean(losses))))
99100

100101
if test_data is not None:
101-
auc, accuracy = self.eval(test_data)
102+
auc, accuracy = self.eval(test_data,device=device)
102103
print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))
103104

104105
def eval(self, test_data, device="cpu") -> tuple:
106+
self.irt_net = self.irt_net.to(device)
105107
self.irt_net.eval()
106108
y_pred = []
107109
y_true = []

0 commit comments

Comments
 (0)