Skip to content

Commit

Permalink
Optimize strain histogram plotting and memory usage in visualization …
Browse files Browse the repository at this point in the history
…functions
  • Loading branch information
bylehn committed Nov 12, 2024
1 parent cfbf49e commit aefab36
Showing 1 changed file with 95 additions and 50 deletions.
145 changes: 95 additions & 50 deletions atomicstrain/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,78 @@
import matplotlib.pyplot as plt
import numpy as np

def plot_strain_histograms(shear_strains, principal_strains, output_dir, chunk_size=1000):
    """Save linear and log10 histograms of the shear and principal strains.

    Four PNG files are written to ``<output_dir>/figures`` (created if
    missing): ``shear_strain_histograms.png`` and
    ``principal_strain_{1,2,3}_histograms.png``. Each file holds two panels:
    the raw values and ``log10(|value| + 1e-10)``.

    The data is never materialised in full. Each histogram is built in two
    streaming passes over leading-axis slices of ``chunk_size`` rows (one pass
    for the value range, one for the bin counts), so only a single chunk and
    the 30 bin counters are held in memory at a time. The price is that the
    input is read twice per panel.

    Args:
        shear_strains: Array-like sliced along its first axis; every element
            contributes to the shear histogram.
        principal_strains: Array-like indexed as ``[rows, :, component]``
            with components 0, 1 and 2.
        output_dir: Directory under which the ``figures`` folder is created.
        chunk_size: Number of leading-axis rows read per step.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    figures_dir = os.path.join(output_dir, 'figures')
    os.makedirs(figures_dir, exist_ok=True)

    _save_histogram_pair(
        shear_strains, chunk_size, None, 'Shear Strain',
        os.path.join(figures_dir, 'shear_strain_histograms.png'),
    )

    for i in range(3):
        _save_histogram_pair(
            principal_strains, chunk_size, i, f'Principal Strain {i+1}',
            os.path.join(figures_dir, f'principal_strain_{i+1}_histograms.png'),
        )


def _iter_chunks(strains, chunk_size, component=None):
    """Yield flattened chunks of ``strains``, optionally for one component."""
    for start in range(0, len(strains), chunk_size):
        block = strains[start:start + chunk_size]
        if component is not None:
            block = block[:, :, component]
        yield np.asarray(block).ravel()


def _log10_magnitude(values):
    """Return log10 of the absolute values, offset so zeros stay finite."""
    return np.log10(np.abs(values) + 1e-10)


def _streamed_histogram(strains, chunk_size, component=None, transform=None, bins=30):
    """Return ``(counts, edges)`` for a histogram built chunk by chunk.

    Non-finite values (NaN, +/-inf) are skipped so that they cannot poison
    the bin range.
    """
    def finite_values():
        for chunk in _iter_chunks(strains, chunk_size, component):
            values = chunk if transform is None else transform(chunk)
            values = values[np.isfinite(values)]
            if values.size:
                yield values

    # Pass 1: global range, needed so every chunk is binned on the same edges.
    lo, hi = np.inf, -np.inf
    for values in finite_values():
        lo = min(lo, float(values.min()))
        hi = max(hi, float(values.max()))

    if lo > hi:
        # No data at all: fall back to an empty histogram on [0, 1].
        lo, hi = 0.0, 1.0
    elif lo == hi:
        # Same widening numpy applies to a zero-width automatic range.
        lo, hi = lo - 0.5, hi + 0.5

    edges = np.linspace(lo, hi, bins + 1)

    # Pass 2: accumulate per-chunk counts on the shared edges.
    counts = np.zeros(bins, dtype=np.int64)
    for values in finite_values():
        counts += np.histogram(values, bins=edges)[0]

    return counts, edges


def _save_histogram_pair(strains, chunk_size, component, label, path):
    """Draw the linear and log10 histograms side by side and save to ``path``."""
    # plt.subplots returns (figure, axes); unpack both so the Axes are usable.
    fig, (ax_linear, ax_log) = plt.subplots(1, 2, figsize=(20, 6))
    try:
        panels = (
            (ax_linear, None, label),
            (ax_log, _log10_magnitude, f'Log10 {label}'),
        )
        for ax, transform, name in panels:
            counts, edges = _streamed_histogram(strains, chunk_size, component, transform)
            # Pre-binned data: one sample per bin, weighted by that bin's count.
            ax.hist(edges[:-1], bins=edges, weights=counts, edgecolor='black')
            ax.set_title(f'Histogram of {name}')
            ax.set_xlabel(name)
            ax.set_ylabel('Frequency')

        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Close only the figure created here; figures owned by callers survive.
        plt.close(fig)

def create_strain_plot(x, x_labels, avg_strains, std_strains, title, output_dir, plot_name, zoomed=False):
"""
Create a strain plot in the figures subdirectory.
"""
"""Create a strain plot with minimal memory usage."""
figures_dir = os.path.join(output_dir, 'figures')
os.makedirs(figures_dir, exist_ok=True)

Expand All @@ -59,8 +82,9 @@ def create_strain_plot(x, x_labels, avg_strains, std_strains, title, output_dir,
line_styles = ['-', '--', '-.', ':']

if zoomed:
all_data = np.concatenate([avg for avg in avg_strains.values()])
Q1, Q3 = np.percentile(all_data, [25, 75])
# Calculate bounds using generators to reduce memory usage
all_values = (val for avg in avg_strains.values() for val in avg)
Q1, Q3 = np.percentile(list(all_values), [25, 75])
IQR = Q3 - Q1
lower_bound, upper_bound = Q1 - 1.5 * IQR, Q3 + 1.5 * IQR

Expand All @@ -84,7 +108,7 @@ def create_strain_plot(x, x_labels, avg_strains, std_strains, title, output_dir,

plt.tight_layout()
plt.savefig(os.path.join(figures_dir, plot_name), dpi=300)
plt.close()
plt.close('all')

def plot_strain_line(atom_info, avg_shear_strains, avg_principal_strains, output_dir):
"""
Expand Down Expand Up @@ -180,36 +204,57 @@ def plot_strain_line_std(atom_info, shear_strains, principal_strains, output_dir

print("Finished creating standard deviation plots")

def visualize_strains(atom_info, shear_strains, principal_strains, output_dir, chunk_size=1000):
    """Create and save the strain visualizations with bounded memory usage.

    Histograms are produced by ``plot_strain_histograms``. Per-atom means and
    standard deviations are then accumulated over leading-axis chunks and
    passed to ``create_strain_plot``.

    Args:
        atom_info (list): Sequence of ``(residue_number, atom_name)`` pairs,
            one per atom; used to build the x-axis labels.
        shear_strains (np.memmap): Shear strains, indexed ``[frame, atom]``.
        principal_strains (np.memmap): Principal strains, indexed
            ``[frame, atom, component]`` with three components.
        output_dir (str): Directory in which the figures are saved.
        chunk_size (int): Number of frames read per step.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1 or the strain arrays
            contain no frames.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    print("Plotting histograms...")
    plot_strain_histograms(shear_strains, principal_strains, output_dir, chunk_size)

    print("Calculating average strains...")
    avg_shear, std_shear = _chunked_mean_std(shear_strains, chunk_size)
    avg_principal, std_principal = _chunked_mean_std(principal_strains, chunk_size)

    print("Plotting average strain lines...")
    x = range(len(atom_info))
    x_labels = [f"{info[0]}_{info[1]}" for info in atom_info]

    avg_strains = {
        'Shear Strain': avg_shear,
        'Principal Strain 1': avg_principal[:, 0],
        'Principal Strain 2': avg_principal[:, 1],
        'Principal Strain 3': avg_principal[:, 2],
    }
    # Real standard deviations; an all-zero placeholder would make the
    # "_std" plot indistinguishable from the plain average plot.
    std_strains = {
        'Shear Strain': std_shear,
        'Principal Strain 1': std_principal[:, 0],
        'Principal Strain 2': std_principal[:, 1],
        'Principal Strain 3': std_principal[:, 2],
    }

    print("Plotting strain lines with standard deviation...")
    print(f"Shape of shear_strains: {shear_strains.shape}")
    print(f"Shape of principal_strains: {principal_strains.shape}")
    print(f"Number of atoms: {len(atom_info)}")
    create_strain_plot(
        x, x_labels, avg_strains, std_strains,
        'Average Strains vs Residue Number_Atom Name',
        output_dir,
        'average_strains_line_plot_std.png'
    )

    # Clean up
    plt.close('all')

    print(f"Visualization figures have been saved in {output_dir}")


def _chunked_mean_std(strains, chunk_size):
    """Return the mean and population std of ``strains`` along axis 0.

    Sums and sums of squares are accumulated in float64, one chunk of
    ``chunk_size`` leading-axis rows at a time.
    """
    total = np.zeros(strains.shape[1:], dtype=np.float64)
    total_sq = np.zeros_like(total)
    n_frames = 0

    for start in range(0, len(strains), chunk_size):
        chunk = np.asarray(strains[start:start + chunk_size], dtype=np.float64)
        total += chunk.sum(axis=0)
        total_sq += np.square(chunk).sum(axis=0)
        n_frames += len(chunk)

    if n_frames == 0:
        raise ValueError("strain array contains no frames")

    mean = total / n_frames
    variance = total_sq / n_frames - np.square(mean)
    # Rounding can push a near-zero variance slightly negative.
    np.maximum(variance, 0.0, out=variance)
    return mean, np.sqrt(variance)

# Add more visualization functions as needed

0 comments on commit aefab36

Please sign in to comment.