The code currently loads checkpoints with `torch.load(PATH)` and no `map_location`. Tensors saved from a GPU are then restored to the device they were saved from (typically `cuda:0`), which can cause a device mismatch when the intended device is not GPU 0.
Using `model.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))` instead (see [modeling_medclip.py](https://github.com/RyanWangZf/MedCLIP/blob/main/medclip/modeling_medclip.py) at lines 55, 66, ..., 147 and many other instances) would allow training and reproducing results on GPU IDs other than 0.
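
For reference, here is a minimal sketch of the pattern I have in mind. It is not taken from the MedCLIP code: the `load_checkpoint` helper, the `nn.Linear` stand-in model, and the temporary file name are all hypothetical, and `weights_only=True` assumes a PyTorch version that supports that argument (1.13 or later).

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_checkpoint(model: nn.Module, path: str, device: torch.device) -> nn.Module:
    """Hypothetical helper: load a state dict onto `device`, wherever it was saved."""
    # map_location remaps every stored tensor to `device` while deserializing,
    # so a checkpoint written on cuda:0 can be read on cuda:1 or on the CPU.
    # weights_only=True limits unpickling to tensors and plain containers.
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model.to(device)


# Pick a non-default GPU when one exists, otherwise fall back to the CPU.
device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cpu")

# Stand-in model used only for illustration.
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.bin")
    torch.save(model.state_dict(), path)
    restored = load_checkpoint(nn.Linear(4, 2), path, device)
```

Routing every `torch.load` call through one helper like this would keep the device choice in a single place, but a direct edit at each call site would work just as well.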
Let me know if you want me to open a PR for this change, or if you will address it directly.
Thank you.