Skip to content

Commit 8b70f1e

Browse files
authored
Merge pull request #33 from ViviHong200709/main
[BUGFIX] Limit the range of parameters in IRT and MIRT
2 parents c6216d9 + ff8a4ec commit 8b70f1e

File tree

10 files changed

+86
-18
lines changed

10 files changed

+86
-18
lines changed

EduCDM/DINA/GD/DINA.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -85,6 +85,7 @@ def __init__(self, user_num, item_num, hidden_dim, ste=False):
8585
self.dina_net = DINANet(user_num, item_num, hidden_dim)
8686

8787
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
88+
self.dina_net = self.dina_net.to(device)
8889
loss_function = nn.BCELoss()
8990

9091
trainer = torch.optim.Adam(self.dina_net.parameters(), lr)
@@ -109,10 +110,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
109110
print("[Epoch %d] LogisticLoss: %.6f" % (e, float(np.mean(losses))))
110111

111112
if test_data is not None:
112-
auc, accuracy = self.eval(test_data)
113+
auc, accuracy = self.eval(test_data, device=device)
113114
print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))
114115

115116
def eval(self, test_data, device="cpu") -> tuple:
117+
self.dina_net = self.dina_net.to(device)
116118
self.dina_net.eval()
117119
y_pred = []
118120
y_true = []

EduCDM/IRR/DINA.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ def __init__(self, user_num, item_num, knowledge_num, ste=False, zeta=0.5):
1717
self.zeta = zeta
1818

1919
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
20+
self.dina_net = self.dina_net.to(device)
2021
point_loss_function = nn.BCELoss()
2122
pair_loss_function = PairSCELoss()
2223
loss_function = HarmonicLoss(self.zeta)
@@ -32,6 +33,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
3233
user_id: torch.Tensor = user_id.to(device)
3334
item_id: torch.Tensor = item_id.to(device)
3435
knowledge: torch.Tensor = knowledge.to(device)
36+
n_samples: torch.Tensor = n_samples.to(device)
3537
predicted_pos_score: torch.Tensor = self.dina_net(user_id, item_id, knowledge)
3638
score: torch.Tensor = score.to(device)
3739
neg_score = 1 - score
@@ -40,6 +42,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
4042
predicted_neg_scores = []
4143
if neg_users:
4244
for neg_user in neg_users:
45+
neg_user: torch.Tensor = neg_user.to(device)
4346
predicted_neg_score = self.dina_net(neg_user, item_id, knowledge)
4447
predicted_neg_scores.append(predicted_neg_score)
4548

@@ -75,10 +78,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
7578
)
7679

7780
if test_data is not None:
78-
eval_data = self.eval(test_data)
81+
eval_data = self.eval(test_data, device=device)
7982
print("[Epoch %d]\n%s" % (e, eval_data))
8083

8184
def eval(self, test_data, device="cpu"):
85+
self.dina_net = self.dina_net.to(device)
8286
self.dina_net.eval()
8387
y_pred = []
8488
y_true = []
@@ -87,6 +91,7 @@ def eval(self, test_data, device="cpu"):
8791
user_id, item_id, knowledge, response = batch_data
8892
user_id: torch.Tensor = user_id.to(device)
8993
item_id: torch.Tensor = item_id.to(device)
94+
knowledge: torch.Tensor = knowledge.to(device)
9095
pred: torch.Tensor = self.dina_net(user_id, item_id, knowledge)
9196
y_pred.extend(pred.tolist())
9297
y_true.extend(response.tolist())

EduCDM/IRR/IRT.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,12 +14,13 @@
1414

1515

1616
class IRT(PointIRT):
17-
def __init__(self, user_num, item_num, knowledge_num, value_range=10, zeta=0.5):
18-
super(IRT, self).__init__(user_num, item_num, value_range=value_range)
17+
def __init__(self, user_num, item_num, knowledge_num, zeta=0.5):
18+
super(IRT, self).__init__(user_num, item_num)
1919
self.knowledge_num = knowledge_num
2020
self.zeta = zeta
2121

2222
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
23+
self.irt_net = self.irt_net.to(device)
2324
point_loss_function = nn.BCELoss()
2425
pair_loss_function = PairSCELoss()
2526
loss_function = HarmonicLoss(self.zeta)
@@ -34,6 +35,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
3435
user_id, item_id, _, score, n_samples, *neg_users = batch_data
3536
user_id: torch.Tensor = user_id.to(device)
3637
item_id: torch.Tensor = item_id.to(device)
38+
n_samples: torch.Tensor = n_samples.to(device)
3739
predicted_pos_score: torch.Tensor = self.irt_net(user_id, item_id)
3840
score: torch.Tensor = score.to(device)
3941
neg_score = 1 - score
@@ -42,6 +44,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
4244
predicted_neg_scores = []
4345
if neg_users:
4446
for neg_user in neg_users:
47+
neg_user: torch.Tensor = neg_user.to(device)
4548
predicted_neg_score = self.irt_net(neg_user, item_id)
4649
predicted_neg_scores.append(predicted_neg_score)
4750

@@ -77,10 +80,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
7780
)
7881

7982
if test_data is not None:
80-
eval_data = self.eval(test_data)
83+
eval_data = self.eval(test_data, device=device)
8184
print("[Epoch %d]\n%s" % (e, eval_data))
8285

8386
def eval(self, test_data, device="cpu"):
87+
self.irt_net = self.irt_net.to(device)
8488
self.irt_net.eval()
8589
y_pred = []
8690
y_true = []

EduCDM/IRR/MIRT.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ def __init__(self, user_num, item_num, knowledge_num, latent_dim=None, zeta=0.5)
2222
self.zeta = zeta
2323

2424
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
25+
self.irt_net = self.irt_net.to(device)
2526
point_loss_function = nn.BCELoss()
2627
pair_loss_function = PairSCELoss()
2728
loss_function = HarmonicLoss(self.zeta)
@@ -36,6 +37,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
3637
user_id, item_id, _, score, n_samples, *neg_users = batch_data
3738
user_id: torch.Tensor = user_id.to(device)
3839
item_id: torch.Tensor = item_id.to(device)
40+
n_samples: torch.Tensor = n_samples.to(device)
3941
predicted_pos_score: torch.Tensor = self.irt_net(user_id, item_id)
4042
score: torch.Tensor = score.to(device)
4143
neg_score = 1 - score
@@ -44,6 +46,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
4446
predicted_neg_scores = []
4547
if neg_users:
4648
for neg_user in neg_users:
49+
neg_user: torch.Tensor = neg_user.to(device)
4750
predicted_neg_score = self.irt_net(neg_user, item_id)
4851
predicted_neg_scores.append(predicted_neg_score)
4952

@@ -79,10 +82,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
7982
)
8083

8184
if test_data is not None:
82-
eval_data = self.eval(test_data)
85+
eval_data = self.eval(test_data, device=device)
8386
print("[Epoch %d]\n%s" % (e, eval_data))
8487

8588
def eval(self, test_data, device="cpu"):
89+
self.irt_net = self.irt_net.to(device)
8690
self.irt_net.eval()
8791
y_pred = []
8892
y_true = []

EduCDM/IRR/NCDM.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ def __init__(self, user_num, item_num, knowledge_num, zeta=0.5):
1717
self.zeta = zeta
1818

1919
def train(self, train_data, test_data=None, epoch=10, device="cpu", lr=0.002, silence=False) -> ...:
20+
self.ncdm_net = self.ncdm_net.to(device)
2021
point_loss_function = nn.BCELoss()
2122
pair_loss_function = PairSCELoss()
2223
loss_function = HarmonicLoss(self.zeta)
@@ -32,6 +33,7 @@ def train(self, train_data, test_data=None, epoch=10, device="cpu", lr=0.002, si
3233
user_id: torch.Tensor = user_id.to(device)
3334
item_id: torch.Tensor = item_id.to(device)
3435
knowledge: torch.Tensor = knowledge.to(device)
36+
n_samples: torch.Tensor = n_samples.to(device)
3537
predicted_pos_score: torch.Tensor = self.ncdm_net(user_id, item_id, knowledge)
3638
score: torch.Tensor = score.to(device)
3739
neg_score = 1 - score
@@ -40,6 +42,7 @@ def train(self, train_data, test_data=None, epoch=10, device="cpu", lr=0.002, si
4042
predicted_neg_scores = []
4143
if neg_users:
4244
for neg_user in neg_users:
45+
neg_user: torch.Tensor = neg_user.to(device)
4346
predicted_neg_score = self.ncdm_net(neg_user, item_id, knowledge)
4447
predicted_neg_scores.append(predicted_neg_score)
4548

@@ -75,10 +78,11 @@ def train(self, train_data, test_data=None, epoch=10, device="cpu", lr=0.002, si
7578
)
7679

7780
if test_data is not None:
78-
eval_data = self.eval(test_data)
81+
eval_data = self.eval(test_data, device=device)
7982
print("[Epoch %d]\n%s" % (e, eval_data))
8083

8184
def eval(self, test_data, device="cpu"):
85+
self.ncdm_net = self.ncdm_net.to(device)
8286
self.ncdm_net.eval()
8387
y_pred = []
8488
y_true = []
@@ -87,6 +91,7 @@ def eval(self, test_data, device="cpu"):
8791
user_id, item_id, knowledge, response = batch_data
8892
user_id: torch.Tensor = user_id.to(device)
8993
item_id: torch.Tensor = item_id.to(device)
94+
knowledge: torch.Tensor = knowledge.to(device)
9095
pred: torch.Tensor = self.ncdm_net(user_id, item_id, knowledge)
9196
y_pred.extend(pred.tolist())
9297
y_true.extend(response.tolist())

EduCDM/IRT/GD/IRT.py

Lines changed: 17 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -6,13 +6,14 @@
66
import torch
77
from EduCDM import CDM
88
from torch import nn
9+
import torch.nn.functional as F
910
from tqdm import tqdm
1011
from ..irt import irt3pl
1112
from sklearn.metrics import roc_auc_score, accuracy_score
1213

1314

1415
class IRTNet(nn.Module):
15-
def __init__(self, user_num, item_num, value_range, irf_kwargs=None):
16+
def __init__(self, user_num, item_num, value_range, a_range, irf_kwargs=None):
1617
super(IRTNet, self).__init__()
1718
self.user_num = user_num
1819
self.item_num = item_num
@@ -22,16 +23,23 @@ def __init__(self, user_num, item_num, value_range, irf_kwargs=None):
2223
self.b = nn.Embedding(self.item_num, 1)
2324
self.c = nn.Embedding(self.item_num, 1)
2425
self.value_range = value_range
26+
self.a_range = a_range
2527

2628
def forward(self, user, item):
    """Evaluate the item response function for paired user/item indices.

    The raw embeddings are looked up and constrained before being handed
    to ``self.irf``:

    * ``c`` is always passed through a sigmoid.
    * ``theta`` and ``b`` are rescaled with
      ``value_range * (sigmoid(x) - 0.5)`` when ``self.value_range`` is
      not ``None``; otherwise the raw embedding values are used.
    * ``a`` becomes ``a_range * sigmoid(a)`` when ``self.a_range`` is not
      ``None``; otherwise ``softplus(a)`` keeps it positive without an
      upper bound.

    :param user: index tensor fed to the ``theta`` embedding
    :param item: index tensor fed to the ``a``, ``b`` and ``c`` embeddings
    :return: the result of ``self.irf(theta, a, b, c, **self.irf_kwargs)``
    :raises ValueError: if ``theta``, ``a`` or ``b`` contains NaN
    """
    theta = torch.squeeze(self.theta(user), dim=-1)
    a = torch.squeeze(self.a(item), dim=-1)
    b = torch.squeeze(self.b(item), dim=-1)
    c = torch.sigmoid(torch.squeeze(self.c(item), dim=-1))

    if self.value_range is not None:
        theta = self.value_range * (torch.sigmoid(theta) - 0.5)
        b = self.value_range * (torch.sigmoid(b) - 0.5)

    if self.a_range is not None:
        a = self.a_range * torch.sigmoid(a)
    else:
        a = F.softplus(a)

    # isnan().any() is simply False for an empty batch, whereas the former
    # torch.max(x != x) raised RuntimeError on a zero-element tensor.
    if torch.isnan(theta).any() or torch.isnan(a).any() or torch.isnan(b).any():  # pragma: no cover
        raise ValueError('ValueError:theta,a,b may contains nan!  The value_range or a_range is too large.')
    return self.irf(theta, a, b, c, **self.irf_kwargs)
3644

3745
@classmethod
@@ -40,11 +48,12 @@ def irf(cls, theta, a, b, c, **kwargs):
4048

4149

4250
class IRT(CDM):
43-
def __init__(self, user_num, item_num, value_range=10):
51+
def __init__(self, user_num, item_num, value_range=None, a_range=None):
4452
super(IRT, self).__init__()
45-
self.irt_net = IRTNet(user_num, item_num, value_range)
53+
self.irt_net = IRTNet(user_num, item_num, value_range, a_range)
4654

4755
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
56+
self.irt_net = self.irt_net.to(device)
4857
loss_function = nn.BCELoss()
4958

5059
trainer = torch.optim.Adam(self.irt_net.parameters(), lr)
@@ -68,10 +77,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
6877
print("[Epoch %d] LogisticLoss: %.6f" % (e, float(np.mean(losses))))
6978

7079
if test_data is not None:
71-
auc, accuracy = self.eval(test_data)
80+
auc, accuracy = self.eval(test_data, device=device)
7281
print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))
7382

7483
def eval(self, test_data, device="cpu") -> tuple:
84+
self.irt_net = self.irt_net.to(device)
7585
self.irt_net.eval()
7686
y_pred = []
7787
y_true = []

EduCDM/MCD/MCD.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ def __init__(self, user_num, item_num, latent_dim):
3636
self.mf_net = MFNet(user_num, item_num, latent_dim)
3737

3838
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
39+
self.mf_net = self.mf_net.to(device)
3940
loss_function = nn.BCELoss()
4041

4142
trainer = torch.optim.Adam(self.mf_net.parameters(), lr)
@@ -63,6 +64,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
6364
print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))
6465

6566
def eval(self, test_data, device="cpu") -> tuple:
67+
self.mf_net = self.mf_net.to(device)
6668
self.mf_net.eval()
6769
y_pred = []
6870
y_true = []

EduCDM/MIRT/MIRT.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import torch
88
from EduCDM import CDM
99
from torch import nn
10+
import torch.nn.functional as F
1011
from tqdm import tqdm
1112
from sklearn.metrics import roc_auc_score, accuracy_score
1213

@@ -41,19 +42,26 @@ def irt2pl(theta, a, b, *, F=np):
4142

4243

4344
class MIRTNet(nn.Module):
44-
def __init__(self, user_num, item_num, latent_dim, irf_kwargs=None):
45+
def __init__(self, user_num, item_num, latent_dim, a_range, irf_kwargs=None):
4546
super(MIRTNet, self).__init__()
4647
self.user_num = user_num
4748
self.item_num = item_num
4849
self.irf_kwargs = irf_kwargs if irf_kwargs is not None else {}
4950
self.theta = nn.Embedding(self.user_num, latent_dim)
5051
self.a = nn.Embedding(self.item_num, latent_dim)
5152
self.b = nn.Embedding(self.item_num, 1)
53+
self.a_range = a_range
5254

5355
def forward(self, user, item):
    """Evaluate the item response function for paired user/item indices.

    ``theta`` and ``b`` are used as looked up from their embeddings.
    ``a`` becomes ``a_range * sigmoid(a)`` when ``self.a_range`` is not
    ``None``; otherwise ``softplus(a)`` keeps it positive without an upper
    bound.

    :param user: index tensor fed to the ``theta`` embedding
    :param item: index tensor fed to the ``a`` and ``b`` embeddings
    :return: the result of ``self.irf(theta, a, b, **self.irf_kwargs)``
    :raises ValueError: if ``theta``, ``a`` or ``b`` contains NaN
    """
    theta = torch.squeeze(self.theta(user), dim=-1)
    a = torch.squeeze(self.a(item), dim=-1)
    b = torch.squeeze(self.b(item), dim=-1)

    if self.a_range is not None:
        a = self.a_range * torch.sigmoid(a)
    else:
        a = F.softplus(a)

    # isnan().any() is simply False for an empty batch, whereas the former
    # torch.max(x != x) raised RuntimeError on a zero-element tensor.
    if torch.isnan(theta).any() or torch.isnan(a).any() or torch.isnan(b).any():  # pragma: no cover
        raise ValueError('ValueError:theta,a,b may contains nan!  The a_range is too large.')
    return self.irf(theta, a, b, **self.irf_kwargs)
5866

5967
@classmethod
@@ -62,11 +70,12 @@ def irf(cls, theta, a, b, **kwargs):
6270

6371

6472
class MIRT(CDM):
65-
def __init__(self, user_num, item_num, latent_dim):
73+
def __init__(self, user_num, item_num, latent_dim, a_range=None):
6674
super(MIRT, self).__init__()
67-
self.irt_net = MIRTNet(user_num, item_num, latent_dim)
75+
self.irt_net = MIRTNet(user_num, item_num, latent_dim, a_range)
6876

6977
def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
78+
self.irt_net = self.irt_net.to(device)
7079
loss_function = nn.BCELoss()
7180

7281
trainer = torch.optim.Adam(self.irt_net.parameters(), lr)
@@ -90,10 +99,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
9099
print("[Epoch %d] LogisticLoss: %.6f" % (e, float(np.mean(losses))))
91100

92101
if test_data is not None:
93-
auc, accuracy = self.eval(test_data)
102+
auc, accuracy = self.eval(test_data, device=device)
94103
print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))
95104

96105
def eval(self, test_data, device="cpu") -> tuple:
106+
self.irt_net = self.irt_net.to(device)
97107
self.irt_net.eval()
98108
y_pred = []
99109
y_true = []

tests/irt/gd/test_gdirt.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
# 2021/4/23 @ tongshiwei
33

44
from EduCDM import GDIRT
5+
import pytest
56

67

78
def test_train(data, conf, tmp_path):
@@ -11,3 +12,15 @@ def test_train(data, conf, tmp_path):
1112
filepath = tmp_path / "mcd.params"
1213
cdm.save(filepath)
1314
cdm.load(filepath)
15+
16+
17+
def test_exception(data, conf, tmp_path):
    """Training must fail with ``ValueError`` once parameters are NaN.

    The earlier form wrapped the whole run in ``try/except ValueError`` and
    printed the exception class, so it passed whether or not the error was
    raised. Here the parameters are poisoned explicitly so the NaN guard is
    reached deterministically and its absence fails the test.
    """
    import torch  # local import: the visible part of this module does not import torch

    user_num, item_num = conf
    cdm = GDIRT(user_num, item_num, value_range=10, a_range=100)

    # NOTE(review): assumes train() runs irt_net's forward pass on the first
    # batch without catching ValueError - confirm against IRT.train.
    with torch.no_grad():
        for param in cdm.irt_net.parameters():
            param.fill_(float("nan"))

    with pytest.raises(ValueError):
        cdm.train(data, test_data=data, epoch=2)

tests/mirt/test_mirt.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
# 2021/4/23 @ tongshiwei
33

44
from EduCDM import MIRT
5+
import pytest
56

67

78
def test_train(data, conf, tmp_path):
@@ -11,3 +12,15 @@ def test_train(data, conf, tmp_path):
1112
filepath = tmp_path / "mcd.params"
1213
cdm.save(filepath)
1314
cdm.load(filepath)
15+
16+
17+
def test_exception(data, conf, tmp_path):
    """Training must fail with ``ValueError`` once parameters are NaN.

    The earlier form wrapped the whole run in ``try/except ValueError`` and
    printed the exception class, so it passed whether or not the error was
    raised. Here the parameters are poisoned explicitly so the NaN guard is
    reached deterministically and its absence fails the test.
    """
    import torch  # local import: the visible part of this module does not import torch

    user_num, item_num = conf
    cdm = MIRT(user_num, item_num, 10, a_range=100)

    # NOTE(review): assumes train() runs irt_net's forward pass on the first
    # batch without catching ValueError - confirm against MIRT.train.
    with torch.no_grad():
        for param in cdm.irt_net.parameters():
            param.fill_(float("nan"))

    with pytest.raises(ValueError):
        cdm.train(data, test_data=data, epoch=2)

0 commit comments

Comments
 (0)