Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- [Simulator](simulator.md)
- [Interactive scenario editor](scene-editor.md)
- [Visualizer](visualizer.md)
- [Export model to ONNX](export-onnx.md)

# Data

Expand Down
81 changes: 81 additions & 0 deletions docs/src/export-onnx.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Exporting PufferDrive Models to ONNX

PufferDrive provides a utility script to export trained PyTorch models to the ONNX format. This is useful for deployment, inference optimization, or using the model in environments that support ONNX Runtime.

## Usage

The export script is located at `scripts/export_onnx.py`. You can run it from the root of the repository.

### Basic Usage

To export a model, point the script at a trained checkpoint:

```bash
python scripts/export_onnx.py --checkpoint path/to/your/checkpoint.pt
```

This will create an `.onnx` file in the same directory as the checkpoint, with the same name (e.g., `checkpoint.onnx`).

### Specifying Output Path

You can specify a custom output path for the ONNX file:

```bash
python scripts/export_onnx.py \
--checkpoint experiments/my_experiment/model_000100.pt \
--output models/my_model.onnx
```

### Specifying Environment

If you are using a specific environment configuration, you can specify it with `--env`:

```bash
python scripts/export_onnx.py --env puffer_drive --checkpoint ...
```

### ONNX Opset Version

You can specify the ONNX opset version (default is 18):

```bash
python scripts/export_onnx.py --opset 17 ...
```

## Arguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--env` | str | `puffer_drive` | The environment name to load configuration for. |
| `--checkpoint` | str | (required/default example path) | Path to the PyTorch `.pt` checkpoint file. |
| `--output` | str | `None` (derived from checkpoint) | Path where the `.onnx` file will be saved. |
| `--opset` | int | `18` | ONNX opset version to use for export. |

## Verification

The script automatically verifies the exported ONNX model by running a forward pass on both the PyTorch model and the ONNX model with dummy inputs and comparing the outputs. It checks for:
- Logits
- Value
- LSTM hidden states (if applicable)

If verification passes, it will print match confirmations. If there are mismatches, it will raise an error or print a mismatch warning.

# Exporting Model Weights to .bin

You can also export the model weights to a binary format (`.bin`) which can be loaded by the C backend of PufferDrive. This is done using `scripts/export_model_bin.py`.

## Usage

```bash
python scripts/export_model_bin.py --checkpoint path/to/your/checkpoint.pt
```

## Arguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--env` | str | `puffer_drive` | The environment name to load configuration for. |
| `--checkpoint` | str | (required) | Path to the PyTorch `.pt` checkpoint file. |
| `--output` | str | `pufferlib/resources/drive/model_puffer_drive_000100.bin` | Output path for the binary weights file. |

This script flattens all model parameters into a single contiguous binary file.
134 changes: 134 additions & 0 deletions scripts/export_model_bin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import argparse
import os
import torch
import importlib
import numpy as np

import pufferlib.utils
import pufferlib.vector
import pufferlib.models

from pufferlib.ocean.torch import Drive


def load_config(env_name, config_dir=None):
    """Load the layered INI configuration for *env_name*.

    Minimal config loader based on pufferl.py: scans every ``config/**/*.ini``
    file (each read on top of ``config/default.ini`` so defaults are
    inherited) until one whose ``[base].env_name`` lists *env_name* is found,
    then returns every section as a nested dict of parsed values.

    Args:
        env_name: Environment name to look up (e.g. ``"puffer_drive"``).
        config_dir: Optional root directory containing ``config/``; defaults
            to the installed ``pufferlib`` package directory.

    Returns:
        dict mapping section name -> {key: parsed value}.

    Raises:
        ValueError: if no config file mentions *env_name*.
    """
    import ast
    import configparser
    import glob
    from collections import defaultdict

    if config_dir is None:
        puffer_dir = os.path.dirname(os.path.realpath(pufferlib.__file__))
    else:
        puffer_dir = config_dir

    puffer_config_dir = os.path.join(puffer_dir, "config/**/*.ini")
    puffer_default_config = os.path.join(puffer_dir, "config/default.ini")

    # Use a sentinel for the matched parser instead of a `found` flag plus
    # a loop-leaked variable, so the post-loop code cannot touch an
    # undefined name.
    matched = None
    for path in glob.glob(puffer_config_dir, recursive=True):
        candidate = configparser.ConfigParser()
        # default.ini supplies base values; the specific file overrides them.
        candidate.read([puffer_default_config, path])
        if env_name in candidate["base"]["env_name"].split():
            matched = candidate
            break

    if matched is None:
        raise ValueError(f"No config for env_name {env_name}")

    def puffer_type(value):
        # Coerce INI strings to Python literals (ints, floats, bools,
        # lists, ...); plain identifiers stay strings. Only the two
        # exceptions literal_eval raises for bad input are caught —
        # never a bare except.
        try:
            return ast.literal_eval(value)
        except (ValueError, SyntaxError):
            return value

    args = defaultdict(dict)
    for section in matched.sections():
        for key in matched[section]:
            args[section][key] = puffer_type(matched[section][key])

    return args


# Export PufferDrive model weights to .bin
def export_weights():
    """Export a trained PufferDrive policy's parameters to a flat ``.bin`` file.

    Loads the environment config, rebuilds the policy (LSTM-wrapped when the
    config requests an RNN), restores the checkpoint, then concatenates every
    parameter tensor — in ``named_parameters()`` order, which the C loader
    must mirror — into one contiguous array written with ``numpy.tofile``.
    """
    parser = argparse.ArgumentParser(description="Export PufferDrive model weights to .bin")
    parser.add_argument("--env", type=str, default="puffer_drive", help="Environment name")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,  # no portable default exists; fail fast with a clear argparse error
        help="Path to .pt checkpoint",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="pufferlib/resources/drive/model_puffer_drive_000100.bin",
        help="Output .bin file path",
    )
    args = parser.parse_args()

    # Load configuration
    config = load_config(args.env)

    # Resolve the env module and creator to get observation/action spaces
    package = config["base"]["package"]
    module_name = "pufferlib.ocean" if package == "ocean" else f"pufferlib.environments.{package}"
    env_module = importlib.import_module(module_name)
    make_env = env_module.env_creator(args.env)

    # Use a valid dummy env to initialize the policy; env args/kwargs are
    # passed as expected by make()
    env_kwargs = config["env"]
    vecenv = pufferlib.vector.make(make_env, env_kwargs=env_kwargs, backend=pufferlib.vector.Serial, num_envs=1)

    try:
        # Initialize Policy
        print("Initializing Policy...")
        policy = Drive(vecenv.driver_env, **config["policy"])

        if config["base"]["rnn_name"]:
            print("Wrapping with LSTM...")
            policy = pufferlib.models.LSTMWrapper(vecenv.driver_env, policy, **config["rnn"])

        # Load Checkpoint
        print(f"Loading checkpoint from {args.checkpoint}...")
        checkpoint = torch.load(args.checkpoint, map_location="cpu")

        # Handle both full checkpoint dicts and raw state dicts
        if isinstance(checkpoint, dict) and "agent_state_dict" in checkpoint:
            state_dict = checkpoint["agent_state_dict"]
        else:
            state_dict = checkpoint

        # Strip torch.compile's "_orig_mod." prefix so keys match the eager model
        prefix = "_orig_mod."
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

        policy.load_state_dict(state_dict)
        policy.eval()

        # Export Weights
        out_dir = os.path.dirname(args.output)
        if out_dir:
            # Create the destination directory so tofile() cannot fail on a
            # missing path.
            os.makedirs(out_dir, exist_ok=True)

        print(f"Exporting weights to {args.output}...")
        weights = []
        total_params = 0
        for name, param in policy.named_parameters():
            param_flat = param.data.cpu().numpy().flatten()
            weights.append(param_flat)
            count = param_flat.size
            print(f"  {name}: {param.shape} -> {count} params")
            total_params += count

        weights = np.concatenate(weights)
        weights.tofile(args.output)
        print(f"Success! Saved {len(weights)} weights ({total_params} params) to {args.output}")
    finally:
        # NOTE(review): assumes the pufferlib vecenv exposes close() — confirm
        vecenv.close()


if __name__ == "__main__":
    export_weights()
Loading