Skip to content

Add Differentiable Physics: Mass-Spring System example #1359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions differentiable_physics/mass_spring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import matplotlib.pyplot as plt
import os


class MassSpringSystem(nn.Module):
def __init__(self, num_particles, springs, mass=1.0, dt=0.01, gravity=9.81, device="cpu"):
super().__init__()
self.device = device
self.mass = mass
self.springs = springs
self.dt = dt
self.gravity = gravity

# Particle 0 is fixed at the origin
self.initial_position_0 = torch.tensor([0.0, 0.0], device=device)

# Remaining particles are trainable
self.initial_positions_rest = nn.Parameter(torch.randn(num_particles - 1, 2, device=device))

# Velocities
self.velocities = torch.zeros(num_particles, 2, device=device)

def forward(self, steps):
positions = torch.cat([self.initial_position_0.unsqueeze(0), self.initial_positions_rest], dim=0)
velocities = self.velocities

for _ in range(steps):
forces = torch.zeros_like(positions)

# Compute spring forces
for (i, j, rest_length, stiffness) in self.springs:
xi, xj = positions[i], positions[j]
dir_vec = xj - xi
dist = dir_vec.norm()
force = stiffness * (dist - rest_length) * dir_vec / (dist + 1e-6)
forces[i] += force
forces[j] -= force

# Apply gravity
forces[:, 1] -= self.gravity * self.mass

# Semi-implicit Euler integration
acceleration = forces / self.mass
velocities = velocities + acceleration * self.dt
positions = positions + velocities * self.dt

# Fix particle 0 at origin
positions[0] = self.initial_position_0
velocities[0] = torch.tensor([0.0, 0.0], device=positions.device)

return positions


def visualize_positions(initial, final, target, save_path="mass_spring_viz.png"):
plt.figure(figsize=(6, 4))
plt.scatter(initial[:, 0], initial[:, 1], c='blue', label='Initial', marker='x')
plt.scatter(final[:, 0], final[:, 1], c='green', label='Final', marker='o')
plt.scatter(target[:, 0], target[:, 1], c='red', label='Target', marker='*')
plt.title("Mass-Spring System Positions")
plt.xlabel("X")
plt.ylabel("Y")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig(save_path)
print(f"Saved visualization to {os.path.abspath(save_path)}")
plt.close()


def train(args):
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu")
print(f"Using device: {device}")
system = MassSpringSystem(
num_particles=args.num_particles,
springs=[(0, 1, 1.0, args.stiffness)],
mass=args.mass,
dt=args.dt,
gravity=args.gravity,
device=device,
)

optimizer = optim.Adam(system.parameters(), lr=args.lr)
target_positions = torch.tensor(
[[0.0, 0.0], [1.0, 0.0]], device=device
)

for epoch in range(args.epochs):
optimizer.zero_grad()
final_positions = system(args.steps)
loss = (final_positions - target_positions).pow(2).mean()
loss.backward()
optimizer.step()

if (epoch + 1) % args.log_interval == 0:
print(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss.item():.6f}")

# Visualization
initial_positions = torch.cat([system.initial_position_0.unsqueeze(0), system.initial_positions_rest.detach()], dim=0).cpu().numpy()
visualize_positions(initial_positions, final_positions.detach().cpu().numpy(), target_positions.cpu().numpy())

print("\nTraining completed.")
print(f"Final positions:\n{final_positions.detach().cpu().numpy()}")
print(f"Target positions:\n{target_positions.cpu().numpy()}")


def evaluate(args):
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.accelerator.current_accelerator() if torch.accelerator.is_available() else torch.device("cpu")
print(f"Using device: {device}")
system = MassSpringSystem(
num_particles=args.num_particles,
springs=[(0, 1, 1.0, args.stiffness)],
mass=args.mass,
dt=args.dt,
gravity=args.gravity,
device=device,
)

with torch.no_grad():
final_positions = system(args.steps)
print(f"Final positions after {args.steps} steps:\n{final_positions.cpu().numpy()}")


def parse_args():
parser = argparse.ArgumentParser(description="Differentiable Physics: Mass-Spring System")
parser.add_argument("--epochs", type=int, default=1000, help="Number of training epochs")
parser.add_argument("--steps", type=int, default=50, help="Number of simulation steps per forward pass")
parser.add_argument("--lr", type=float, default=0.01, help="Learning rate")
parser.add_argument("--dt", type=float, default=0.01, help="Time step for integration")
parser.add_argument("--mass", type=float, default=1.0, help="Mass of each particle")
parser.add_argument("--stiffness", type=float, default=10.0, help="Spring stiffness constant")
parser.add_argument("--num_particles", type=int, default=2, help="Number of particles in the system")
parser.add_argument("--mode", choices=["train", "eval"], default="train", help="Mode: train or eval")
parser.add_argument("--log_interval", type=int, default=100, help="Print loss every n epochs")
parser.add_argument("--gravity", type=float, default=9.81, help="Gravity strength")
return parser.parse_args()


def main():
args = parse_args()

if args.mode == "train":
train(args)
elif args.mode == "eval":
evaluate(args)


if __name__ == "__main__":
main()
Binary file added differentiable_physics/mass_spring_viz.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
68 changes: 68 additions & 0 deletions differentiable_physics/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Differentiable Physics: Mass-Spring System

This example demonstrates a simple differentiable **mass-spring system** using PyTorch.

A set of particles is connected via springs and evolves over time under the influence of:
- **Spring forces** (via Hooke’s Law)
- **Gravity** (acting in the negative Y-direction)

The system is fully differentiable, enabling **gradient-based optimization** of the **initial positions** of the particles so that their **final positions** match a desired **target configuration**.

This idea is inspired by differentiable simulation frameworks such as those presented in recent research (see reference below).

---

## Files

- `mass_spring.py` — Implements the simulation, training loop, and evaluation logic.
- `README.md` — Description, instructions, and visualization output.
- `mass_spring_viz.png` — Output visualization of the final vs target configuration.

---

## Key Concepts

| Term | Description |
|-------------------|-----------------------------------------------------------------------------|
| Initial Position | Learnable 2D coordinates (x, y) of each particle before simulation begins. |
| Target Position | Desired final 2D position after simulation. Used to compute loss. |
| Gravity | Constant force `[0, -9.8]` pulling particles downward in Y direction. |
| Spring Forces | Modeled using Hooke’s Law. Particles connected by springs exert forces. |
| Dimensionality | All particle positions and forces are 2D vectors. |

---

## Requirements

- Python 3.8+
- PyTorch ≥ 2.0

Install requirements (if needed):

pip install -r requirements.txt


## Usage

First, ensure PyTorch is installed.

#### Train the system


python mass_spring.py --mode train


![Mass-Spring System Visualization](mass_spring_viz.png)

*Mass-Spring System Visualization comparing final vs target positions.*



## References

[1] Sanchez-Gonzalez, A. et al. (2020).
Learning to Simulate Complex Physics with Graph Networks.
arXiv preprint arXiv:2002.09405.
Available: https://arxiv.org/abs/2002.09405


3 changes: 3 additions & 0 deletions differentiable_physics/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
torch>=2.6
matplotlib

6 changes: 6 additions & 0 deletions run_python_examples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,11 @@ function gat() {
uv run main.py --epochs 1 --dry-run || error "graph attention network failed"
}

function differentiable_physics() {
uv run mass_spring.py --mode train --epochs 5 --steps 3 || error "differentiable_physics example failed"
}


eval "base_$(declare -f stop)"

function stop() {
Expand Down Expand Up @@ -217,6 +222,7 @@ function run_all() {
run fx
run gcn
run gat
run differentiable_physics
}

# by default, run all examples
Expand Down