@@ -17,6 +17,7 @@ def __init__(self, user_num, item_num, knowledge_num, ste=False, zeta=0.5):
1717 self .zeta = zeta
1818
1919 def train (self , train_data , test_data = None , * , epoch : int , device = "cpu" , lr = 0.001 ) -> ...:
20+ self .dina_net = self .dina_net .to (device )
2021 point_loss_function = nn .BCELoss ()
2122 pair_loss_function = PairSCELoss ()
2223 loss_function = HarmonicLoss (self .zeta )
@@ -32,6 +33,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
3233 user_id : torch .Tensor = user_id .to (device )
3334 item_id : torch .Tensor = item_id .to (device )
3435 knowledge : torch .Tensor = knowledge .to (device )
36+ n_samples : torch .Tensor = n_samples .to (device )
3537 predicted_pos_score : torch .Tensor = self .dina_net (user_id , item_id , knowledge )
3638 score : torch .Tensor = score .to (device )
3739 neg_score = 1 - score
@@ -40,6 +42,7 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
4042 predicted_neg_scores = []
4143 if neg_users :
4244 for neg_user in neg_users :
45+ neg_user : torch .Tensor = neg_user .to (device )
4346 predicted_neg_score = self .dina_net (neg_user , item_id , knowledge )
4447 predicted_neg_scores .append (predicted_neg_score )
4548
@@ -75,10 +78,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
7578 )
7679
7780 if test_data is not None :
78- eval_data = self .eval (test_data )
81+ eval_data = self .eval (test_data , device = device )
7982 print ("[Epoch %d]\n %s" % (e , eval_data ))
8083
8184 def eval (self , test_data , device = "cpu" ):
85+ self .dina_net = self .dina_net .to (device )
8286 self .dina_net .eval ()
8387 y_pred = []
8488 y_true = []
@@ -87,6 +91,7 @@ def eval(self, test_data, device="cpu"):
8791 user_id , item_id , knowledge , response = batch_data
8892 user_id : torch .Tensor = user_id .to (device )
8993 item_id : torch .Tensor = item_id .to (device )
94+ knowledge : torch .Tensor = knowledge .to (device )
9095 pred : torch .Tensor = self .dina_net (user_id , item_id , knowledge )
9196 y_pred .extend (pred .tolist ())
9297 y_true .extend (response .tolist ())
0 commit comments