Skip to content

Commit

Permalink
fixed runs
Browse files Browse the repository at this point in the history
  • Loading branch information
bylehn committed Nov 9, 2024
1 parent 18e0d93 commit 0f52360
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 76 deletions.
189 changes: 139 additions & 50 deletions atomicstrain/analysis.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
import numpy as np
from MDAnalysis.analysis.base import AnalysisBase
from .compute import compute_strain_tensor, compute_principal_strains_and_shear
from .compute import process_frame_data
from .utils import create_selections
from .io import write_strain_files, write_pdb_with_strains
from tqdm import tqdm
import os

# analysis.py
import numpy as np
from MDAnalysis.analysis.base import AnalysisBase
import os
from .compute import compute_strain_tensor, compute_principal_strains_and_shear
from .utils import create_selections
from .io import write_strain_files, write_pdb_with_strains

class StrainAnalysis(AnalysisBase):
def __init__(self, reference, deformed, residue_numbers, output_dir, min_neighbors=3, n_frames=None, use_all_heavy=False, **kwargs):
self.ref = reference
Expand All @@ -25,32 +17,24 @@ def __init__(self, reference, deformed, residue_numbers, output_dir, min_neighbo
self.output_dir = output_dir
self.has_ref_trajectory = hasattr(self.ref, 'trajectory') and len(self.ref.trajectory) > 1

# Store n_frames but don't use it directly for array initialization
self.requested_n_frames = n_frames
# Create necessary directories
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, 'data'), exist_ok=True)

super().__init__(self.defm.trajectory, **kwargs)

def _prepare(self):
# Create output directory if it doesn't exist
os.makedirs(self.output_dir, exist_ok=True)

# Determine the number of atoms we're analyzing
n_atoms = len(self.selections)

# Calculate actual number of frames that will be analyzed
# Calculate frame range
start = self.start if self.start is not None else 0
stop = self.stop if self.stop is not None else len(self.defm.trajectory)
step = self.step if self.step is not None else 1

# Calculate actual number of frames based on start, stop, and stride
actual_n_frames = len(range(start, stop, step))
print(f"Preparing analysis for {actual_n_frames} frames")
n_atoms = len(self.selections)

# Use the data subdirectory for memory-mapped files
data_dir = os.path.join(self.output_dir, 'data')
os.makedirs(data_dir, exist_ok=True)
print(f"Preparing analysis for {actual_n_frames} frames and {n_atoms} atoms")

# Create memory-mapped arrays with correct size
# Create memory-mapped arrays
data_dir = os.path.join(self.output_dir, 'data')
self.results.shear_strains = np.memmap(
f"{data_dir}/shear_strains.npy",
dtype='float32',
Expand All @@ -65,54 +49,159 @@ def _prepare(self):
shape=(actual_n_frames, n_atoms, 3)
)

# Store atom info
self.results.atom_info = [(ref_center.resid, ref_center.name)
for (_, ref_center), _ in self.selections]

# Initialize frame counter
self._frame_counter = 0

def _single_frame(self):
    """Compute strains for the current frame and store them in the result arrays.

    Gathers the neighbor positions and center positions of every selection,
    hands the whole frame to ``process_frame_data`` in one call, and writes
    the returned shear and principal strains into row ``self._frame_counter``
    of the memory-mapped result arrays.
    """
    # Update reference frame if needed
    if self.has_ref_trajectory:
        # NOTE(review): this indexes the reference trajectory with
        # self._frame_index, not with the deformed trajectory's frame
        # number -- confirm this is intended when start/stride are used.
        self.ref.trajectory[self._frame_index]

    # Collect all positions and centers for the frame
    ref_positions = []
    ref_centers = []
    def_positions = []
    def_centers = []
    for (ref_sel, ref_center), (defm_sel, defm_center) in self.selections:
        ref_positions.append(ref_sel.positions)
        ref_centers.append(ref_center.position)
        def_positions.append(defm_sel.positions)
        def_centers.append(defm_center.position)

    # Process entire frame at once (replaces the former per-atom loop, whose
    # results were overwritten by this call anyway)
    frame_shear, frame_principal = process_frame_data(
        ref_positions,
        ref_centers,
        def_positions,
        def_centers
    )

    # Rows are addressed by the running counter so that strided runs fill
    # the output arrays contiguously.
    row = self._frame_counter
    self.results.shear_strains[row] = frame_shear
    self.results.principal_strains[row] = frame_principal

    # Periodically flush to disk
    if row % 100 == 0:
        self.results.shear_strains.flush()
        self.results.principal_strains.flush()

    self._frame_counter += 1

def run(self, start=None, stop=None, stride=None, verbose=True):
    """
    Run the analysis with progress reporting and error handling.

    Parameters
    ----------
    start : int, optional
        First frame of trajectory to analyze, default: None (start at beginning)
    stop : int, optional
        Last frame of trajectory to analyze, default: None (end of trajectory)
    stride : int, optional
        Number of frames to skip between each analyzed frame, default: None (use every frame)
    verbose : bool, optional
        Show detailed progress, default: True

    Returns
    -------
    self : StrainAnalysis
        Return self to allow for method chaining

    Raises
    ------
    Exception
        Any exception raised while the analysis itself runs is re-raised
        after diagnostics are printed and the memory-mapped result arrays
        are released.
    """
    import time
    import traceback

    # Store parameters; _prepare reads them to size the output arrays
    self.start = start
    self.stop = stop
    self.step = stride

    # Calculate total frames to be analyzed
    trajectory_length = len(self.defm.trajectory)
    start_frame = start if start is not None else 0
    end_frame = stop if stop is not None else trajectory_length
    step_size = stride if stride is not None else 1

    total_frames = len(range(start_frame, end_frame, step_size))

    if verbose:
        print("\nAnalysis Setup:")
        print(f"Total trajectory frames: {trajectory_length}")
        print(f"Analyzing frames {start_frame} to {end_frame} with stride {step_size}")
        print(f"Total frames to analyze: {total_frames}")
        if self.has_ref_trajectory:
            print("Using reference trajectory")
        print(f"Number of atoms to analyze: {len(self.selections)}")

        # Memory usage estimate
        mem_per_frame = (len(self.selections) * 4 * 4)  # 4 bytes per float32, 4 values per atom
        total_mem_estimate = (mem_per_frame * total_frames) / (1024 * 1024)  # Convert to MB
        print(f"Estimated memory usage: {total_mem_estimate:.2f} MB")

        print("\nStarting analysis...")

    start_time = time.perf_counter()

    # Only the analysis itself is guarded: a failure in the optional
    # reporting below must not be reported as an analysis error or trigger
    # deletion of results that were computed successfully.
    try:
        result = super().run(start=start, stop=stop, step=stride, verbose=verbose)
    except Exception as e:
        print(f"\nError during analysis: {str(e)}")

        # Additional error context; the counter does not exist yet if the
        # failure happened before _prepare ran.
        frames_done = getattr(self, '_frame_counter', 0)
        if frames_done > 0:
            print(f"Error occurred after processing {frames_done} frames")
            print("Partial results may be available")

        print("\nFull traceback:")
        traceback.print_exc()

        # Try to clean up memory-mapped files; each array is released
        # independently so one missing entry does not skip the other.
        for name in ('shear_strains', 'principal_strains'):
            try:
                delattr(self.results, name)
            except (AttributeError, KeyError):
                pass

        raise

    if verbose:
        duration = time.perf_counter() - start_time
        print(f"\nAnalysis completed in {duration:.2f} seconds")
        if duration > 0:  # avoid ZeroDivisionError on very short runs
            frames_per_second = total_frames / duration
            print(f"Average processing speed: {frames_per_second:.2f} frames/second")

        # Memory usage report (psutil is optional)
        try:
            import psutil
        except ImportError:
            psutil = None
        if psutil is not None:
            memory_usage = psutil.Process().memory_info().rss / (1024 * 1024)  # Convert to MB
            print(f"Final memory usage: {memory_usage:.2f} MB")

    return result

def _conclude(self):
# Compute average strains using the actual number of frames analyzed
# Ensure all data is written
self.results.shear_strains.flush()
self.results.principal_strains.flush()

# Compute averages
self.results.avg_shear_strains = np.mean(self.results.shear_strains[:self._frame_counter], axis=0)
self.results.avg_principal_strains = np.mean(self.results.principal_strains[:self._frame_counter], axis=0)

# Store final arrays for visualization
self.results.final_shear_strains = np.array(
self.results.shear_strains[:self._frame_counter],
dtype=np.float32
)
self.results.final_principal_strains = np.array(
self.results.principal_strains[:self._frame_counter],
dtype=np.float32
)

# Save a copy of the strains before writing files
self.results.final_shear_strains = np.array(self.results.shear_strains[:self._frame_counter])
self.results.final_principal_strains = np.array(self.results.principal_strains[:self._frame_counter])

# Write output files
write_strain_files(
self.output_dir,
self.results.shear_strains[:self._frame_counter],
Expand All @@ -133,6 +222,6 @@ def _conclude(self):
self.use_all_heavy
)

# Clean up memory-mapped arrays
# Clean up
del self.results.shear_strains
del self.results.principal_strains
111 changes: 85 additions & 26 deletions atomicstrain/compute.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,99 @@
# compute.py
import jax.numpy as jnp
from jax import jit
from jax import jit, vmap, device_put
from jax.scipy.linalg import eigh
from jax import config
import numpy as np
import jax

@jit
def compute_strain_tensor(Am, Bm):
"""
Compute the strain tensor for given reference and deformed configurations.
# Configure JAX
DEVICES = jax.devices()
print(f"JAX devices available: {DEVICES}")

Args:
Am (jnp.ndarray): Reference configuration matrix.
Bm (jnp.ndarray): Deformed configuration matrix.
# Try to use GPU if available; fall back to CPU when no GPU device is listed
# or when selecting the platform fails.
try:
    if any(d.platform == 'gpu' for d in DEVICES):
        config.update('jax_platform_name', 'gpu')
        print("Using GPU for computations")
    else:
        config.update('jax_platform_name', 'cpu')
        print("No GPU found, using CPU")
except Exception:
    # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit still
    # propagate during import.
    config.update('jax_platform_name', 'cpu')
    print("Failed to set GPU, using CPU")

Returns:
jnp.ndarray: Computed strain tensor.
"""
@jit
def compute_strain_tensor(Am, Bm):
    """Compute the symmetrized strain tensor using JAX.

    Args:
        Am (jnp.ndarray): Reference configuration matrix. As called from
            ``process_frame_data`` this is an (n_neighbors, 3) array of
            centered neighbor positions.
        Bm (jnp.ndarray): Deformed configuration matrix with the same shape
            as ``Am``.

    Returns:
        jnp.ndarray: Strain tensor, explicitly symmetrized.
    """
    D = jnp.linalg.inv(Am.T @ Am)
    C = Bm @ Bm.T - Am @ Am.T
    Q = 0.5 * (D @ Am.T @ C @ Am @ D)
    # Explicitly symmetrize the tensor to remove floating-point asymmetry
    return 0.5 * (Q + Q.T)

@jit
def compute_principal_strains_and_shear(Q):
    """Compute principal strains and shear from the strain tensor using JAX.

    Args:
        Q (jnp.ndarray): Strain tensor (square, symmetric).

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]: A tuple containing:
            - shear: ``trace(Q @ Q) - (1/3) * trace(Q)**2``.
            - principal strains: eigenvalues of ``Q`` in descending order.
    """
    eigenvalues, _ = eigh(Q)
    shear = jnp.trace(Q @ Q) - (1/3) * (jnp.trace(Q))**2
    # Sort in descending order
    return shear, jnp.sort(eigenvalues)[::-1]

Args:
Q (jnp.ndarray): Strain tensor.
def pad_positions(positions, max_length):
    """Return a float32 (max_length, 3) array holding the leading rows of *positions*.

    Rows beyond ``len(positions)`` are left as zeros; rows of *positions*
    beyond ``max_length`` are dropped.
    """
    n_rows = len(positions)
    if n_rows > max_length:
        n_rows = max_length

    out = np.zeros((max_length, 3), dtype=np.float32)
    out[:n_rows] = positions[:n_rows]
    return out

Returns:
Tuple[jnp.ndarray, jnp.ndarray]: A tuple containing:
- shear (jnp.ndarray): Computed shear strain.
- eigenvalues (jnp.ndarray): Principal strains (eigenvalues of the strain tensor).
def process_frame_data(ref_positions, ref_centers, def_positions, def_centers):
    """
    Process all positions for a frame efficiently using JAX vectorization.

    Args:
        ref_positions: Sequence with one array of reference neighbor
            positions per atom.
        ref_centers: Sequence with one reference center position per atom.
        def_positions: Sequence with one array of deformed neighbor
            positions per atom.
        def_centers: Sequence with one deformed center position per atom.

    Returns:
        Tuple[np.ndarray, np.ndarray]: float32 shear strains of shape
        (n_atoms,) and float32 principal strains of shape (n_atoms, 3).
        Atoms with fewer than three reference neighbors keep zeros, and all
        entries stay zero if the vectorized computation raises.
    """
    n_atoms = len(ref_positions)

    # Initialize output arrays
    shear_strains = np.zeros(n_atoms, dtype=np.float32)
    principal_strains = np.zeros((n_atoms, 3), dtype=np.float32)

    # Valid atoms are those with enough neighbors
    valid_indices = [i for i, pos in enumerate(ref_positions) if len(pos) >= 3]

    # Bail out before taking max(): with no valid atom (or an empty frame)
    # max() over an empty sequence would raise ValueError.
    if not valid_indices:
        return shear_strains, principal_strains

    # Find maximum number of neighbors so every atom can be padded to one shape
    max_neighbors = max(len(ref_positions[i]) for i in valid_indices)

    # Center and pad the positions.
    # NOTE(review): deformed positions are padded/truncated to the
    # reference-derived length; confirm reference and deformed selections
    # always contain the same number of atoms.
    valid_ref_data = [
        pad_positions(ref_positions[i] - ref_centers[i], max_neighbors)
        for i in valid_indices
    ]
    valid_def_data = [
        pad_positions(def_positions[i] - def_centers[i], max_neighbors)
        for i in valid_indices
    ]

    # Convert to JAX arrays
    ref_array = device_put(jnp.asarray(np.stack(valid_ref_data)))
    def_array = device_put(jnp.asarray(np.stack(valid_def_data)))

    try:
        # Vectorized computation for all valid atoms
        Q_tensors = vmap(compute_strain_tensor)(ref_array, def_array)
        shear_vals, principal_vals = vmap(compute_principal_strains_and_shear)(Q_tensors)

        # Move results back to CPU and store in output arrays
        shear_strains[valid_indices] = np.array(shear_vals)
        principal_strains[valid_indices] = np.array(principal_vals)
    except Exception as e:
        # Best effort: report and fall through, returning whatever is in the
        # output arrays
        print(f"Error in strain computation: {str(e)}")

    return shear_strains, principal_strains

0 comments on commit 0f52360

Please sign in to comment.