66import torch
77from EduCDM import CDM
88from torch import nn
9+ import torch .nn .functional as F
910from tqdm import tqdm
1011from ..irt import irt3pl
1112from sklearn .metrics import roc_auc_score , accuracy_score
1213
1314
1415class IRTNet (nn .Module ):
15- def __init__ (self , user_num , item_num , value_range , irf_kwargs = None ):
16+ def __init__ (self , user_num , item_num , value_range , a_range , irf_kwargs = None ):
1617 super (IRTNet , self ).__init__ ()
1718 self .user_num = user_num
1819 self .item_num = item_num
@@ -22,16 +23,23 @@ def __init__(self, user_num, item_num, value_range, irf_kwargs=None):
2223 self .b = nn .Embedding (self .item_num , 1 )
2324 self .c = nn .Embedding (self .item_num , 1 )
2425 self .value_range = value_range
26+ self .a_range = a_range
2527
2628 def forward (self , user , item ):
2729 theta = torch .squeeze (self .theta (user ), dim = - 1 )
28- theta = self .value_range * (torch .sigmoid (theta ) - 0.5 )
2930 a = torch .squeeze (self .a (item ), dim = - 1 )
30- a = torch .sigmoid (a )
3131 b = torch .squeeze (self .b (item ), dim = - 1 )
32- b = self .value_range * (torch .sigmoid (b ) - 0.5 )
3332 c = torch .squeeze (self .c (item ), dim = - 1 )
3433 c = torch .sigmoid (c )
34+ if self .value_range is not None :
35+ theta = self .value_range * (torch .sigmoid (theta ) - 0.5 )
36+ b = self .value_range * (torch .sigmoid (b ) - 0.5 )
37+ if self .a_range is not None :
38+ a = self .a_range * torch .sigmoid (a )
39+ else :
40+ a = F .softplus (a )
41+ if torch .max (theta != theta ) or torch .max (a != a ) or torch .max (b != b ): # pragma: no cover
42+ raise ValueError ('ValueError:theta,a,b may contains nan! The value_range or a_range is too large.' )
3543 return self .irf (theta , a , b , c , ** self .irf_kwargs )
3644
3745 @classmethod
@@ -40,11 +48,12 @@ def irf(cls, theta, a, b, c, **kwargs):
4048
4149
4250class IRT (CDM ):
43- def __init__ (self , user_num , item_num , value_range = 10 ):
51+ def __init__ (self , user_num , item_num , value_range = None , a_range = None ):
4452 super (IRT , self ).__init__ ()
45- self .irt_net = IRTNet (user_num , item_num , value_range )
53+ self .irt_net = IRTNet (user_num , item_num , value_range , a_range )
4654
4755 def train (self , train_data , test_data = None , * , epoch : int , device = "cpu" , lr = 0.001 ) -> ...:
56+ self .irt_net = self .irt_net .to (device )
4857 loss_function = nn .BCELoss ()
4958
5059 trainer = torch .optim .Adam (self .irt_net .parameters (), lr )
@@ -68,10 +77,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
6877 print ("[Epoch %d] LogisticLoss: %.6f" % (e , float (np .mean (losses ))))
6978
7079 if test_data is not None :
71- auc , accuracy = self .eval (test_data )
80+ auc , accuracy = self .eval (test_data , device = device )
7281 print ("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e , auc , accuracy ))
7382
7483 def eval (self , test_data , device = "cpu" ) -> tuple :
84+ self .irt_net = self .irt_net .to (device )
7585 self .irt_net .eval ()
7686 y_pred = []
7787 y_true = []
0 commit comments