Cog Template Repository

This is a template repository for creating Cog models that handle model weights efficiently through caching. It includes tooling to upload model weights to Google Cloud Storage and to generate the corresponding download code for your predict.py file.


Getting Started

To use this template for your own model:

  1. Clone this repository
  2. Modify predict.py with your model's implementation
  3. Update cog.yaml with your model's dependencies
  4. Use cache_manager.py to upload and manage model weights

Repository Structure

  • predict.py: The main model implementation file
  • cache_manager.py: Script for uploading model weights to GCS and generating download code
  • cog.yaml: Cog configuration file that defines your model's environment

Managing Model Weights with cache_manager.py

A key feature of this template is the cache_manager.py script, which helps you:

  1. Upload model weights to Google Cloud Storage (GCS)
  2. Generate code for downloading those weights in your predict.py
  3. Handle both individual files and directories efficiently

Prerequisites for Using cache_manager.py

  • Google Cloud SDK installed and configured (gcloud command)
  • Permission to upload to the specified GCS bucket (default: gs://replicate-weights/)
  • tar command available in your PATH

Basic Usage

python cache_manager.py --model-name your-model-name --local-dirs model_cache

This will:

  1. Find files and directories in the model_cache directory
  2. Create tar archives of each directory
  3. Upload both individual files and tar archives to GCS
  4. Generate code snippets for downloading the weights in your predict.py (a rough sketch of such a snippet follows below)
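
The exact snippet cache_manager.py prints depends on what it uploaded, but pasted into predict.py it conceptually boils down to a small helper that fetches each archive from the CDN and unpacks it into model_cache before the model loads. A rough illustrative sketch (the URL and the download_weights helper below are placeholders, not the script's actual output):

import os
import tarfile
import time
import urllib.request

MODEL_CACHE = "model_cache"
# Placeholder URL; use the CDN URLs that cache_manager.py prints for your model
WEIGHTS_URL = "https://weights.replicate.delivery/default/your-model-name/model_cache.tar"

def download_weights(url: str, dest_dir: str) -> None:
    """Fetch a tar archive of weights and unpack it into dest_dir."""
    start = time.time()
    os.makedirs(dest_dir, exist_ok=True)
    tar_path = os.path.join("/tmp", os.path.basename(url))
    urllib.request.urlretrieve(url, tar_path)  # download the archive
    with tarfile.open(tar_path) as tar:
        tar.extractall(dest_dir)  # assumes archive members are relative paths
    os.remove(tar_path)  # remove the archive once extracted
    print(f"Downloaded weights in {time.time() - start:.1f}s")

if not os.path.exists(MODEL_CACHE):
    download_weights(WEIGHTS_URL, MODEL_CACHE)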

Advanced Usage

python cache_manager.py \
    --model-name your-model-name \
    --local-dirs model_cache weights \
    --gcs-base-path gs://replicate-weights/ \
    --cdn-base-url https://weights.replicate.delivery/default/ \
    --keep-tars

Parameters

  • --model-name: Required. The name of your model (used in paths)
  • --local-dirs: Required. One or more local directories to process
  • --gcs-base-path: Optional. Base Google Cloud Storage path (default: gs://replicate-weights/)
  • --cdn-base-url: Optional. Base CDN URL
  • --keep-tars: Optional. Keep the generated .tar files locally after upload

Workflow Example

  1. Develop your model locally:

    # Run your model once to download weights to model_cache
    cog predict -i prompt="test"
  2. Upload model weights:

    python cache_manager.py --model-name your-model-name --local-dirs model_cache
  3. Copy the generated code snippet into your predict.py

  4. Test that the model can download weights:

    rm -rf model_cache
    cog predict -i prompt="test"

Example Implementation

The template comes with a sample Stable Diffusion implementation in predict.py that demonstrates (a trimmed sketch of the pattern follows this list):

  • Setting up the model cache directory
  • Downloading weights from GCS with progress reporting
  • Setting environment variables for model caching
  • Random seed generation for reproducibility
  • Output format and quality options
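
The bundled predict.py is the full reference; the sketch below only captures the overall shape of that pattern. The model ID, cache path, and pipeline details are illustrative assumptions rather than the template's exact code:

import os

from cog import BasePredictor, Input, Path
import torch

MODEL_CACHE = "model_cache"
# Point frameworks at the local cache before anything tries to fetch weights itself
os.environ["HF_HOME"] = MODEL_CACHE
os.environ["TORCH_HOME"] = MODEL_CACHE

class Predictor(BasePredictor):
    def setup(self):
        # If MODEL_CACHE is missing, download it first (using the snippet
        # generated by cache_manager.py), then load the pipeline from disk
        from diffusers import StableDiffusionPipeline
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",  # illustrative model ID
            cache_dir=MODEL_CACHE,
            torch_dtype=torch.float16,
        ).to("cuda")

    def predict(
        self,
        prompt: str = Input(description="Text prompt"),
        seed: int = Input(description="Random seed. Leave blank to randomize", default=None),
    ) -> Path:
        if seed is None:
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        generator = torch.Generator("cuda").manual_seed(seed)
        image = self.pipe(prompt, generator=generator).images[0]
        out_path = Path("/tmp/output.png")
        image.save(str(out_path))
        return out_path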

Best Practices

  • Environment Variables: Set cache-related environment variables early

    os.environ["HF_HOME"] = MODEL_CACHE
    os.environ["TORCH_HOME"] = MODEL_CACHE
    # etc.
  • Seed Management: Provide a seed parameter and implement random seed generation

    if seed is None:
        seed = int.from_bytes(os.urandom(2), "big")
    print(f"Using seed: {seed}")
  • Output Formats: Support multiple output formats (webp, jpg, png) with quality controls (see the saving sketch after this list)

    output_format: str = Input(
        description="Format of the output image",
        choices=["webp", "jpg", "png"],
        default="webp"
    )
    output_quality: int = Input(
        description="The image compression quality...",
        ge=1, le=100, default=80
    )
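
Applying these two parameters when writing the output is straightforward with Pillow. A small sketch (save_output is a hypothetical helper name; it assumes the model produces a PIL image):

from PIL import Image

def save_output(image: Image.Image, output_format: str, output_quality: int) -> str:
    """Write a PIL image using the user-selected format and quality."""
    if output_format == "jpg":
        image = image.convert("RGB")  # JPEG cannot store an alpha channel
    # PNG is lossless, so the quality setting only applies to webp/jpg
    save_params = {} if output_format == "png" else {"quality": output_quality}
    output_path = f"/tmp/output.{output_format}"
    image.save(output_path, **save_params)
    return output_path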

Deploying to Replicate

After setting up your model, you can push it to Replicate:

  1. Create a new model on Replicate
  2. Push your model:
    cog push r8.im/username/model-name

License

MIT



⭐ Star this on GitHub!

👋 Follow zsxkib on Twitter/X
