examples/scripts/2_structural_optimization.py — 78 changes: 72 additions & 6 deletions
@@ -31,7 +31,7 @@

# Number of steps to run
SMOKE_TEST = os.getenv("CI") is not None
-N_steps = 10 if SMOKE_TEST else 500
+N_steps = 10 if SMOKE_TEST else 100


# ============================================================================
@@ -111,7 +111,7 @@

# Run optimization
for step in range(N_steps):
-    if step % 100 == 0:
+    if step % (N_steps // 5) == 0:
        print(f"Step {step}: Potential energy: {state.energy[0].item()} eV")
    state = ts.fire_step(state=state, model=lj_model, dt_max=0.01)

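One note on the new `N_steps // 5` interval: it prints five progress lines regardless of `N_steps`, but `step % (N_steps // 5)` would raise ZeroDivisionError if `N_steps` ever dropped below 5. A defensive variant, hypothetical and not part of this diff (`print_every` is an illustrative name):

# Hypothetical guard, not in this PR: clamp the interval so a small N_steps
# (e.g. a stricter smoke test) cannot produce a zero modulus.
print_every = max(1, N_steps // 5)
for step in range(N_steps):
    if step % print_every == 0:
        print(f"Step {step}: Potential energy: {state.energy[0].item()} eV")
    state = ts.fire_step(state=state, model=lj_model, dt_max=0.01)
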
@@ -174,7 +174,7 @@

print("\nRunning FIRE:")
for step in range(N_steps):
-    if step % 20 == 0:
+    if step % (N_steps // 5) == 0:
        print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}")

    state = ts.fire_step(state=state, model=model, dt_max=0.01)
@@ -254,7 +254,7 @@

print("\nRunning batched unit cell gradient descent:")
for step in range(N_steps):
-    if step % 20 == 0:
+    if step % (N_steps // 5) == 0:
        P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3
        P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3
        P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3
@@ -308,7 +308,7 @@

print("\nRunning batched unit cell FIRE:")
for step in range(N_steps):
-    if step % 20 == 0:
+    if step % (N_steps // 5) == 0:
        P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3
        P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3
        P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3
@@ -360,7 +360,7 @@

print("\nRunning batched frechet cell filter with FIRE:")
for step in range(N_steps):
-    if step % 20 == 0:
+    if step % (N_steps // 5) == 0:
        P1 = -torch.trace(state.stress[0]) * UnitConversion.eV_per_Ang3_to_GPa / 3
        P2 = -torch.trace(state.stress[1]) * UnitConversion.eV_per_Ang3_to_GPa / 3
        P3 = -torch.trace(state.stress[2]) * UnitConversion.eV_per_Ang3_to_GPa / 3
@@ -386,6 +386,72 @@
print(f"Initial pressure: {initial_pressure} GPa")
print(f"Final pressure: {final_pressure} GPa")

# ============================================================================
# SECTION 7: Batched MACE L-BFGS
# ============================================================================
print("\n" + "=" * 70)
print("SECTION 7: Batched MACE L-BFGS")
print("=" * 70)

# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2))
si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape)

cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2))
cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape)

fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2))
fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape)

atoms_list = [si_dc, cu_dc, fe_dc]

state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)
results = model(state)
state = ts.lbfgs_init(state=state, model=model, alpha=70.0, step_size=1.0)

print("\nRunning L-BFGS:")
for step in range(N_steps):
    if step % (N_steps // 5) == 0:
        print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}")
    state = ts.lbfgs_step(state=state, model=model, max_history=100)

print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV")
print(f"Final energies: {[energy.item() for energy in state.energy]} eV")

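For context on what the `max_history=100` window controls: L-BFGS rebuilds a curvature-scaled step from the last m displacement/gradient-difference pairs via the standard two-loop recursion. The sketch below is illustrative only — plain PyTorch on a single flattened coordinate vector, not torch-sim's batched implementation:

import torch

def lbfgs_direction(grad, s_hist, y_hist):
    """Two-loop recursion: approximate the quasi-Newton step -H @ grad from
    the last m displacement (s) and gradient-difference (y) pairs. With an
    empty history this falls back to steepest descent, -grad."""
    rhos = [1.0 / torch.dot(y, s) for s, y in zip(s_hist, y_hist)]
    q = grad.clone()
    alphas = []
    for (s, y), rho in reversed(list(zip(zip(s_hist, y_hist), rhos))):
        alpha = rho * torch.dot(s, q)
        q = q - alpha * y
        alphas.append(alpha)  # stored newest-first
    if s_hist:  # scale by the standard initial guess gamma = s.y / y.y
        s, y = s_hist[-1], y_hist[-1]
        q = q * (torch.dot(s, y) / torch.dot(y, y))
    for (s, y), rho, alpha in zip(zip(s_hist, y_hist), rhos, reversed(alphas)):
        beta = rho * torch.dot(y, q)
        q = q + (alpha - beta) * s
    return -q  # descent direction; the caller applies a step size

Keeping `s_hist`/`y_hist` trimmed to the most recent `max_history` entries bounds both memory and the cost of each recursion.
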

# ============================================================================
# SECTION 8: Batched MACE BFGS
# ============================================================================
print("\n" + "=" * 70)
print("SECTION 8: Batched MACE BFGS")
print("=" * 70)

# Recreate structures with perturbations
si_dc = bulk("Si", "diamond", a=5.21).repeat((2, 2, 2))
si_dc.positions += 0.2 * rng.standard_normal(si_dc.positions.shape)

cu_dc = bulk("Cu", "fcc", a=3.85).repeat((2, 2, 2))
cu_dc.positions += 0.2 * rng.standard_normal(cu_dc.positions.shape)

fe_dc = bulk("Fe", "bcc", a=2.95).repeat((2, 2, 2))
fe_dc.positions += 0.2 * rng.standard_normal(fe_dc.positions.shape)

atoms_list = [si_dc, cu_dc, fe_dc]

state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)
results = model(state)
state = ts.bfgs_init(state=state, model=model, alpha=70.0)

print("\nRunning BFGS:")
for step in range(N_steps):
    if step % (N_steps // 5) == 0:
        print(f"Step {step}, Energy: {[energy.item() for energy in state.energy]}")
    state = ts.bfgs_step(state=state, model=model)

print(f"Initial energies: {[energy.item() for energy in results['energy']]} eV")
print(f"Final energies: {[energy.item() for energy in state.energy]} eV")

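The full BFGS variant instead carries a dense curvature approximation, which is why it takes no history-length argument. A textbook sketch of the inverse-Hessian update (illustrative only; the reading of `alpha` as an initial curvature scale is an assumption, not something this diff confirms):

import torch

def bfgs_inverse_hessian_update(H, s, y):
    """Textbook BFGS update of the inverse-Hessian approximation:
    H+ = (I - rho s y^T) H (I - rho y s^T) + rho s s^T, rho = 1 / (y . s)."""
    rho = 1.0 / torch.dot(y, s)
    eye = torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
    V = eye - rho * torch.outer(s, y)
    return V @ H @ V.T + rho * torch.outer(s, s)

# A plausible starting guess is H0 = I / alpha, with alpha acting as a
# stiffness scale; this would mirror the alpha=70.0 passed to ts.bfgs_init
# above, but that correspondence is an assumption about the convention.
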

print("\n" + "=" * 70)
print("Structural optimization examples completed!")
print("=" * 70)
tests/test_autobatching.py — 212 changes: 212 additions & 0 deletions
@@ -605,3 +605,215 @@ def test_in_flight_max_iterations(
    # Verify iteration_count tracking
    for idx in range(len(states)):
        assert batcher.iteration_count[idx] == max_iterations


@pytest.mark.parametrize(
    "num_steps_per_batch",
    [
        5,  # At 5 steps, not every state will converge before the next batch.
        10,  # At 10 steps, all states will converge before the next batch.
    ],
)
def test_in_flight_with_bfgs(
    si_sim_state: ts.SimState,
    fe_supercell_sim_state: ts.SimState,
    lj_model: LennardJonesModel,
    num_steps_per_batch: int,
) -> None:
    """Test InFlightAutoBatcher with BFGS optimizer (matching FIRE test structure)."""
    si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
    fe_bfgs_state = ts.bfgs_init(
        fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
    )

    bfgs_states = [si_bfgs_state, fe_bfgs_state] * 5
    bfgs_states = [state.clone() for state in bfgs_states]
    for state in bfgs_states:
        state.positions += torch.randn_like(state.positions) * 0.01

    batcher = InFlightAutoBatcher(
        model=lj_model,
        memory_scales_with="n_atoms",
        max_memory_scaler=6000,
    )
    batcher.load_states(bfgs_states)

    def convergence_fn(state: ts.BFGSState) -> torch.Tensor:
        """Return a per-system boolean tensor: max force norm below 5e-1."""
        system_wise_max_force = torch.zeros(
            state.n_systems, device=state.device, dtype=torch.float64
        )
        max_forces = state.forces.norm(dim=1)
        system_wise_max_force = system_wise_max_force.scatter_reduce(
            dim=0, index=state.system_idx, src=max_forces, reduce="amax"
        )
        return system_wise_max_force < 5e-1

    # Start with no state so the first next_batch call fetches the first batch,
    # rather than relying on the leftover loop variable from the setup above.
    all_completed_states, state, convergence_tensor = [], None, None
    while True:
        state, completed_states = batcher.next_batch(state, convergence_tensor)

        all_completed_states.extend(completed_states)
        if state is None:
            break

        for _ in range(num_steps_per_batch):
            state = ts.bfgs_step(state=state, model=lj_model)
        convergence_tensor = convergence_fn(state)

    assert len(all_completed_states) == len(bfgs_states)


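The `scatter_reduce` call in `convergence_fn` collapses per-atom force norms into one maximum per system, keyed by `system_idx`. A standalone illustration with hypothetical values (not part of the test suite):

import torch

# Two systems: atoms 0-1 belong to system 0, atoms 2-4 to system 1.
system_idx = torch.tensor([0, 0, 1, 1, 1])
force_norms = torch.tensor([0.3, 0.7, 0.1, 0.9, 0.2], dtype=torch.float64)

per_system_max = torch.zeros(2, dtype=torch.float64).scatter_reduce(
    dim=0, index=system_idx, src=force_norms, reduce="amax"
)
print(per_system_max)         # tensor([0.7000, 0.9000], dtype=torch.float64)
print(per_system_max < 5e-1)  # tensor([False, False])

Because force norms are non-negative, seeding the output with zeros under the default include_self=True does not distort the maxima.

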
def test_binning_auto_batcher_with_bfgs(
    si_sim_state: ts.SimState,
    fe_supercell_sim_state: ts.SimState,
    lj_model: LennardJonesModel,
) -> None:
    """Test BinningAutoBatcher with BFGS optimizer (matching FIRE test structure)."""
    si_bfgs_state = ts.bfgs_init(si_sim_state, lj_model, cell_filter=ts.CellFilter.unit)
    fe_bfgs_state = ts.bfgs_init(
        fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
    )

    bfgs_states = [si_bfgs_state, fe_bfgs_state] * 5
    bfgs_states = [state.clone() for state in bfgs_states]
    for state in bfgs_states:
        state.positions += torch.randn_like(state.positions) * 0.01

    batcher = BinningAutoBatcher(
        model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000
    )
    batcher.load_states(bfgs_states)

    all_finished_states: list[ts.SimState] = []
    total_batches = 0
    for batch, _ in batcher:
        total_batches += 1  # noqa: SIM113
        for _ in range(5):
            batch = ts.bfgs_step(state=batch, model=lj_model)
        all_finished_states.extend(batch.split())

    assert len(all_finished_states) == len(bfgs_states)


def _group_states_by_size(
    states: list[ts.SimState],
) -> list[list[tuple[int, ts.SimState]]]:
    """Group states by n_atoms, preserving original indices for order restoration.

    Used for L-BFGS, which requires same-sized systems in each batch because the
    history tensor shapes depend on n_atoms.
    """
    from itertools import groupby

    indexed_states = list(enumerate(states))
    sorted_states = sorted(indexed_states, key=lambda x: x[1].n_atoms)
    return [
        list(group) for _, group in groupby(sorted_states, key=lambda x: x[1].n_atoms)
    ]

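A quick sketch of how the helper behaves, using stand-in objects that carry only the `n_atoms` attribute it reads (hypothetical sizes; real inputs would be ts.SimState instances):

from types import SimpleNamespace

states = [
    SimpleNamespace(n_atoms=8),
    SimpleNamespace(n_atoms=16),
    SimpleNamespace(n_atoms=8),
]
for group in _group_states_by_size(states):
    print([(idx, s.n_atoms) for idx, s in group])
# [(0, 8), (2, 8)]
# [(1, 16)]

The original index travels with each state, so the caller can merge the per-group results back into input order afterwards.
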

@pytest.mark.parametrize(
    "num_steps_per_batch",
    [
        5,  # At 5 steps, not every state will converge before the next batch.
        10,  # At 10 steps, all states will converge before the next batch.
    ],
)
def test_in_flight_with_lbfgs(
    si_sim_state: ts.SimState,
    fe_supercell_sim_state: ts.SimState,
    lj_model: LennardJonesModel,
    num_steps_per_batch: int,
) -> None:
    """Test InFlightAutoBatcher with L-BFGS optimizer (matching FIRE test structure)."""
    si_lbfgs_state = ts.lbfgs_init(
        si_sim_state, lj_model, cell_filter=ts.CellFilter.unit
    )
    fe_lbfgs_state = ts.lbfgs_init(
        fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
    )

    lbfgs_states = [si_lbfgs_state, fe_lbfgs_state] * 5
    lbfgs_states = [state.clone() for state in lbfgs_states]
    for state in lbfgs_states:
        state.positions += torch.randn_like(state.positions) * 0.01

    batcher = InFlightAutoBatcher(
        model=lj_model,
        memory_scales_with="n_atoms",
        max_memory_scaler=6000,
    )
    batcher.load_states(lbfgs_states)

    def convergence_fn(state: ts.LBFGSState) -> torch.Tensor:
        """Return a per-system boolean tensor: max force norm below 5e-1."""
        system_wise_max_force = torch.zeros(
            state.n_systems, device=state.device, dtype=torch.float64
        )
        max_forces = state.forces.norm(dim=1)
        system_wise_max_force = system_wise_max_force.scatter_reduce(
            dim=0, index=state.system_idx, src=max_forces, reduce="amax"
        )
        return system_wise_max_force < 5e-1

    # Start with no state so the first next_batch call fetches the first batch.
    all_completed_states, state, convergence_tensor = [], None, None
    while True:
        state, completed_states = batcher.next_batch(state, convergence_tensor)

        all_completed_states.extend(completed_states)
        if state is None:
            break

        for _ in range(num_steps_per_batch):
            state = ts.lbfgs_step(state=state, model=lj_model)
        convergence_tensor = convergence_fn(state)

    assert len(all_completed_states) == len(lbfgs_states)


def test_binning_auto_batcher_with_lbfgs(
    si_sim_state: ts.SimState,
    fe_supercell_sim_state: ts.SimState,
    lj_model: LennardJonesModel,
) -> None:
    """Test BinningAutoBatcher with L-BFGS optimizer (matching FIRE test structure)."""
    si_lbfgs_state = ts.lbfgs_init(
        si_sim_state, lj_model, cell_filter=ts.CellFilter.unit
    )
    fe_lbfgs_state = ts.lbfgs_init(
        fe_supercell_sim_state, lj_model, cell_filter=ts.CellFilter.unit
    )

    lbfgs_states = [si_lbfgs_state, fe_lbfgs_state] * 5
    lbfgs_states = [state.clone() for state in lbfgs_states]
    for state in lbfgs_states:
        state.positions += torch.randn_like(state.positions) * 0.01

    # Group by size and process each group separately.
    size_groups = _group_states_by_size(lbfgs_states)
    all_finished_with_indices: list[tuple[int, ts.SimState]] = []
    total_batches = 0

    for group in size_groups:
        original_indices, group_states = zip(*group, strict=True)
        group_states_list = list(group_states)

        batcher = BinningAutoBatcher(
            model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=6000
        )
        batcher.load_states(group_states_list)

        finished_states = []
        for batch, _ in batcher:
            total_batches += 1
            for _ in range(5):
                batch = ts.lbfgs_step(state=batch, model=lj_model)
            finished_states.extend(batch.split())

        restored = batcher.restore_original_order(finished_states)
        all_finished_with_indices.extend(zip(original_indices, restored, strict=True))

    # Sort by original index to restore the global order.
    all_finished_with_indices.sort(key=lambda x: x[0])
    all_finished_states = [s for _, s in all_finished_with_indices]

    assert len(all_finished_states) == len(lbfgs_states)
    for restored, original in zip(all_finished_states, lbfgs_states, strict=True):
        assert torch.all(restored.atomic_numbers == original.atomic_numbers)