Relatively simple JAX implementation of CLIP.

Usage:
    model, params = clip_jaxtorch.clip.load('ViT-B/32')
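For context, CLIP scores an image against a set of text prompts by L2-normalizing both embeddings and taking their scaled dot product (a cosine similarity). A softmax over the prompts then gives zero-shot class probabilities. The following is a minimal sketch of that scoring step in plain JAX. It is independent of clip_jaxtorch: it does not use the `model` or `params` returned by `load`. The function names `clip_logits` and `zero_shot_probs`, the fixed `logit_scale=100.0`, and the random stand-in embeddings are illustrative assumptions, not this library's API.

```python
import jax
import jax.numpy as jnp


def clip_logits(image_features, text_features, logit_scale=100.0):
    """Scaled cosine-similarity logits between image and text embeddings.

    image_features: [n_images, d], text_features: [n_texts, d].
    Hypothetical helper; logit_scale is a fixed float here (an assumption),
    whereas CLIP learns this temperature during training.
    """
    # L2-normalize each embedding so the dot product is a cosine similarity.
    image_features = image_features / jnp.linalg.norm(image_features, axis=-1, keepdims=True)
    text_features = text_features / jnp.linalg.norm(text_features, axis=-1, keepdims=True)
    # [n_images, n_texts] matrix of scaled similarities.
    return logit_scale * image_features @ text_features.T


def zero_shot_probs(image_features, text_features, logit_scale=100.0):
    """Hypothetical helper: softmax over the text (prompt) axis for each image."""
    logits = clip_logits(image_features, text_features, logit_scale)
    return jax.nn.softmax(logits, axis=-1)


if __name__ == "__main__":
    # Random stand-ins for encoder outputs; 512 is an arbitrary example width.
    key_img, key_txt = jax.random.split(jax.random.PRNGKey(0))
    img = jax.random.normal(key_img, (2, 512))
    txt = jax.random.normal(key_txt, (3, 512))
    probs = zero_shot_probs(img, txt)
    print(probs.shape)  # (2, 3): one probability row per image
```

Normalizing before the matrix product keeps the logits bounded by the scale factor, so the softmax temperature, not the embedding magnitudes, controls how peaked the probabilities are.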