Skip to content

Commit

Permalink
Bug fix: in sine forward function, sum last dim.
Browse files Browse the repository at this point in the history
  • Loading branch information
bokang-ugent committed Jul 11, 2022
1 parent 03a9eba commit 6d2a45e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
16 changes: 8 additions & 8 deletions examples/matching/run_ml_sine.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def get_movielens_data(data_path, load_cache=False, seq_max_len=50):
mode=2,
neg_ratio=3,
min_item=0)
x_train = gen_model_input(df_train, user_profile, user_col, item_profile, item_col, seq_max_len=seq_max_len, padding='post', truncating='post')
x_train = gen_model_input(df_train, user_profile, user_col, item_profile, item_col, seq_max_len=seq_max_len, padding='pre', truncating='pre')
y_train = np.array([0] * df_train.shape[0]) #label=0 means the first pred value is positive sample
x_test = gen_model_input(df_test, user_profile, user_col, item_profile, item_col, seq_max_len=seq_max_len, padding='post', truncating='post')
x_test = gen_model_input(df_test, user_profile, user_col, item_profile, item_col, seq_max_len=seq_max_len, padding='pre', truncating='pre')
np.save("./data/ml-1m/saved/data_cache.npy", np.array((x_train, y_train, x_test), dtype=object))

user_features, item_features, history_features, neg_item_features = ["user_id"], ["movie_id"], ["hist_movie_id"], ["neg_items"]
Expand Down Expand Up @@ -98,20 +98,20 @@ def main(dataset_path, model_name, epoch, learning_rate, batch_size, weight_deca
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_path', default="./data/ml-1m/ml-1m_sample.csv")
parser.add_argument('--model_name', default='sine')
parser.add_argument('--epoch', type=int, default=3)
parser.add_argument('--learning_rate', type=float, default=1e-4)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--weight_decay', type=float, default=1e-6)
parser.add_argument('--device', default='cuda:0')
parser.add_argument('--save_dir', default='./data/ml-1m/saved/')
parser.add_argument('--seed', type=int, default=2022)

parser.add_argument('--embedding_dim', type=int, default=128)
parser.add_argument('--hidden_dim', type=int, default=512)
parser.add_argument('--num_concept', type=int, default=50)
parser.add_argument('--num_intention', type=int, default=4)
parser.add_argument('--num_concept', type=int, default=10)
parser.add_argument('--num_intention', type=int, default=2)
parser.add_argument('--temperature', type=int, default=0.1)
parser.add_argument('--seq_max_len', type=int, default=20)
parser.add_argument('--seq_max_len', type=int, default=50)

args = parser.parse_args()
main(args.dataset_path, args.model_name, args.epoch, args.learning_rate, args.batch_size, args.weight_decay, args.device,
Expand Down
2 changes: 1 addition & 1 deletion torch_rechub/models/matching/sine.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def forward(self, x):
if self.mode == "item":
return item_embedding

y = torch.mul(user_embedding, item_embedding).sum(dim=1)
y = torch.mul(user_embedding, item_embedding).sum(dim=-1)

# # compute covariance regularizer
# M = torch.cov(self.concept_embedding.weight, correction=0)
Expand Down

0 comments on commit 6d2a45e

Please sign in to comment.