-
借助工具下载 GraFITi 官方代码至
./GraFITi/ -
安装 tsdm
- 下载 tsdm 官方代码 至
./tsdm/ - 创建 conda 虚拟环境,注意
python=3.11 - 用
./GraFITi/tsdm替换./tsdm/src/tsdm - 将
./tsdm/src/tsdm/viz/_config.py中的USE_TEX: Final[bool] = matplotlib.checkdep_usetex(True)改为USE_TEX: Final[bool] = False - 进入
./tsdm/目录,执行pip install -e .
- 下载 tsdm 官方代码 至
-
修改
./GraFITi/train_grafiti.py-
创建模型存储目录
if not os.path.exists('saved_models/'): os.makedirs('saved_models/')
-
修改优化器配置
OPTIMIZER_CONFIG = { "lr": ARGS.learn_rate, "betas": ARGS.betas, "weight_decay": ARGS.weight_decay, }
-
如果需要,添加
tqdm打印进度条
-
-
进入
./GraFITi/目录,运行如下命令运行官方示例,如果提示缺包自行安装即可
python train_grafiti.py --epochs 200 --learn-rate 0.001 --batch-size 128 --attn-head 1 --latent-dim 128 --nlayers 4 --dataset physionet2012 --fold 0 -ct 36 -ft 12-
下载本项目
-
创建 conda 虚拟环境,注意
python=3.11 -
进入
tsdm-main目录,执行pip install -e . -
进入
./GraFITi/目录,运行如下命令运行官方示例,如果提示缺包自行安装即可python train_grafiti.py --epochs 200 --learn-rate 0.001 --batch-size 128 --attn-head 1 --latent-dim 128 --nlayers 4 --dataset physionet2012 --fold 0 -ct 36 -ft 12
