Skip to content

A Lighting Pytorch Framework for Recommendation Models, Easy-to-use and Easy-to-extend.

License

Notifications You must be signed in to change notification settings

Lyons-T/torch-rechub

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

39 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Torch-RecHub

中文Wiki站

查看最新研发进度,认领感兴趣的研发任务,学习rechub模型复现心得,加入rechub共建者团队等

点击链接

安装

#稳定版 
pip install torch-rechub

#最新版
1. git clone https://github.com/datawhalechina/torch-rechub.git
2. cd torch-rechub
3. python setup.py install

核心定位

易用易拓展,聚焦复现业界实用的推荐模型,以及泛生态化的推荐场景

主要特性

  • scikit-learn风格易用的API(fit、predict),即插即用

  • 模型训练与模型定义解耦,易拓展,可针对不同类型的模型设置不同的训练机制

  • 接受pandas的DataFrame、Dict数据输入,上手成本低

  • 高度模块化,支持常见Layer,容易调用组装成新模型

    • LR、MLP、FM、FFM、CIN

    • target-attention、self-attention、transformer

  • 支持常见排序模型

    • WideDeep、DeepFM、DIN、DCN、xDeepFM等
  • 支持常见召回模型

    • DSSM、YoutubeDNN、YoutubeDSSM、FacebookEBR、MIND等
  • 丰富的多任务学习支持

    • SharedBottom、ESMM、MMOE、PLE、AITM等模型

    • GradNorm、UWL、MetaBanlance等动态loss加权机制

  • 聚焦更生态化的推荐场景

    • 冷启动

    • 延迟反馈

    • 去偏
  • 支持丰富的训练机制

    • 对比学习

    • 蒸馏学习

  • 第三方高性能开源Trainer支持(Pytorch Lighting)

  • 更多模型正在开发中

快速使用

单任务排序

from torch_rechub.models.ranking import WideDeep, DeepFM, DIN
from torch_rechub.trainers import CTRTrainer
from torch_rechub.basic.utils import DataGenerator

dg = DataGenerator(x, y)
train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader()

model = DeepFM(deep_features=deep_features, fm_features=fm_features, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})

ctr_trainer = CTRTrainer(model)
ctr_trainer.fit(train_dataloader, val_dataloader)
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)

多任务排序

from torch_rechub.models.multi_task import SharedBottom, ESMM, MMOE, PLE, AITM
from torch_rechub.trainers import MTLTrainer

model = MMOE(features, task_types, n_expert=3, expert_params={"dims": [64,32,16]}, tower_params_list=[{"dims": [8]}, {"dims": [8]}])

ctr_trainer = MTLTrainer(model)
ctr_trainer.fit(train_dataloader, val_dataloader)
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)

About

A Lighting Pytorch Framework for Recommendation Models, Easy-to-use and Easy-to-extend.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 75.9%
  • Jupyter Notebook 24.1%