This is a template repository for creating Cog models that efficiently handle model weights with proper caching. It includes tools to upload model weights to Google Cloud Storage and generate download code for your `predict.py` file.
To use this template for your own model:

- Clone this repository
- Modify `predict.py` with your model's implementation
- Update `cog.yaml` with your model's dependencies
- Use `cache_manager.py` to upload and manage model weights
The repository contains:

- `predict.py`: The main model implementation file
- `cache_manager.py`: Script for uploading model weights to GCS and generating download code
- `cog.yaml`: Cog configuration file that defines your model's environment
A key feature of this template is the `cache_manager.py` script, which helps you:

- Upload model weights to Google Cloud Storage (GCS)
- Generate code for downloading those weights in your `predict.py`
- Handle both individual files and directories efficiently
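The download snippet itself is emitted by `cache_manager.py`, so the exact code will differ; as a minimal sketch of the pattern it follows, assuming a tarred `model_cache` directory served from the default CDN (the URL and helper name below are placeholders, not the generated code):

```python
import os
import tarfile
import time
import urllib.request

MODEL_CACHE = "model_cache"
# Placeholder URL -- cache_manager.py prints the real one after uploading
WEIGHTS_URL = "https://weights.replicate.delivery/default/your-model-name/model_cache.tar"


def download_weights(url: str, extract_to: str = ".") -> None:
    """Download a tar archive of weights and unpack it into `extract_to`."""
    start = time.time()
    print(f"Downloading weights from {url}")
    tar_path, _ = urllib.request.urlretrieve(url)
    with tarfile.open(tar_path) as tar:
        # Assumes the archive contains the cache directory as its top-level entry
        tar.extractall(path=extract_to)
    os.remove(tar_path)
    print(f"Downloaded and extracted weights in {time.time() - start:.1f}s")


# In Predictor.setup(): only download when the cache is missing
if not os.path.exists(MODEL_CACHE):
    download_weights(WEIGHTS_URL)
```

Treat this purely as an outline of the check-then-download flow; the snippet that `cache_manager.py` generates may use a different downloader.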
Prerequisites:

- Google Cloud SDK installed and configured (`gcloud` command)
- Permission to upload to the specified GCS bucket (default: `gs://replicate-weights/`)
- `tar` command available in your PATH
Basic usage:

```bash
python cache_manager.py --model-name your-model-name --local-dirs model_cache
```
This will:

- Find files and directories in the `model_cache` directory
- Create tar archives of each directory
- Upload both individual files and tar archives to GCS
- Generate code snippets for downloading the weights in your `predict.py`
With all options:

```bash
python cache_manager.py \
  --model-name your-model-name \
  --local-dirs model_cache weights \
  --gcs-base-path gs://replicate-weights/ \
  --cdn-base-url https://weights.replicate.delivery/default/ \
  --keep-tars
```
- `--model-name`: Required. The name of your model (used in paths)
- `--local-dirs`: Required. One or more local directories to process
- `--gcs-base-path`: Optional. Base Google Cloud Storage path
- `--cdn-base-url`: Optional. Base CDN URL
- `--keep-tars`: Optional. Keep the generated `.tar` files locally after upload
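For orientation, here is a hypothetical sketch of how these options combine into upload destinations and download URLs; the authoritative names are whatever `cache_manager.py` prints after an upload, so treat this only as a mental model:

```python
# Hypothetical path composition -- check the script's output for the real names
model_name = "your-model-name"
gcs_base = "gs://replicate-weights/"
cdn_base = "https://weights.replicate.delivery/default/"

archive = "model_cache.tar"  # a --local-dirs entry, tarred by the script
gcs_dest = f"{gcs_base}{model_name}/{archive}"
cdn_url = f"{cdn_base}{model_name}/{archive}"

print(gcs_dest)  # gs://replicate-weights/your-model-name/model_cache.tar
print(cdn_url)   # https://weights.replicate.delivery/default/your-model-name/model_cache.tar
```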
A typical workflow looks like this:

1. Develop your model locally:

   ```bash
   # Run your model once to download weights to model_cache
   cog predict -i prompt="test"
   ```

2. Upload the model weights:

   ```bash
   python cache_manager.py --model-name your-model-name --local-dirs model_cache
   ```

3. Copy the generated code snippet into your `predict.py`

4. Test that the model can download the weights:

   ```bash
   rm -rf model_cache
   cog predict -i prompt="test"
   ```
The template comes with a sample Stable Diffusion implementation in `predict.py` that demonstrates:

- Setting up the model cache directory
- Downloading weights from GCS with progress reporting
- Setting environment variables for model caching
- Random seed generation for reproducibility
- Output format and quality options
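As a rough outline of how those pieces fit together (not the shipped `predict.py`; the model id and dtype below are placeholder choices), a `setup()` following this pattern might look like:

```python
import os

MODEL_CACHE = "model_cache"
# Set cache locations before importing the ML libraries
os.environ["HF_HOME"] = MODEL_CACHE
os.environ["TORCH_HOME"] = MODEL_CACHE

import torch
from cog import BasePredictor
from diffusers import StableDiffusionPipeline


class Predictor(BasePredictor):
    def setup(self) -> None:
        # model_cache is populated either by the first local run or by the
        # generated download snippet (see above); here it is assumed to exist.
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",  # placeholder model id
            cache_dir=MODEL_CACHE,
            torch_dtype=torch.float16,
        ).to("cuda")
```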
Best practices:

1. Environment Variables: Set cache-related environment variables early

   ```python
   os.environ["HF_HOME"] = MODEL_CACHE
   os.environ["TORCH_HOME"] = MODEL_CACHE
   # etc.
   ```

2. Seed Management: Provide a seed parameter and implement random seed generation

   ```python
   if seed is None:
       seed = int.from_bytes(os.urandom(2), "big")
   print(f"Using seed: {seed}")
   ```

3. Output Formats: Support multiple output formats (webp, jpg, png) with quality controls

   ```python
   output_format: str = Input(
       description="Format of the output image",
       choices=["webp", "jpg", "png"],
       default="webp",
   )
   output_quality: int = Input(
       description="The image compression quality...",
       ge=1,
       le=100,
       default=80,
   )
   ```
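Tying the seed and output options together in `predict()` might look roughly like the sketch below; it assumes the `self.pipe` from the setup sketch above, and the parameter names and defaults are illustrative rather than the template's exact code:

```python
import os

import torch
from cog import BasePredictor, Input, Path


class Predictor(BasePredictor):
    # setup() as sketched earlier; self.pipe is a loaded StableDiffusionPipeline

    def predict(
        self,
        prompt: str = Input(description="Text prompt"),
        seed: int = Input(description="Random seed. Leave blank to randomize", default=None),
        output_format: str = Input(choices=["webp", "jpg", "png"], default="webp"),
        output_quality: int = Input(ge=1, le=100, default=80),
    ) -> Path:
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        generator = torch.Generator("cuda").manual_seed(seed)
        image = self.pipe(prompt, generator=generator).images[0]

        out_path = f"/tmp/output.{output_format}"
        # PIL infers the format from the extension; quality applies to webp/jpg only
        save_kwargs = {"quality": output_quality} if output_format != "png" else {}
        image.save(out_path, **save_kwargs)
        return Path(out_path)
```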
After setting up your model, you can push it to Replicate:

- Create a new model on Replicate
- Push your model:

  ```bash
  cog push r8.im/username/model-name
  ```
License: MIT
⭐ Star this on GitHub!

👋 Follow zsxkib on Twitter/X