A PyTorch-based surrogate model for space charge simulation, enabling rapid approximation of electric field calculations from charge density distributions. This repository provides a full pipeline: data generation (with Julia/Distgen), preprocessing, model training, evaluation, and visualization.
```bash
# Clone the repository
git clone https://github.com/ndwang/SC_surrogate.git
cd SC_surrogate

# Create and activate the conda environment
conda env create -f environment.yml
conda activate sc_surrogate
```

Project structure:

```
SC_surrogate/
├── configs/ # YAML configuration files
│ ├── training_config.yaml # Main training & preprocessing config
│ ├── generation_config.yaml # Data generation config
│ └── distgen_template.yaml # Distgen beam template
├── data/
│ ├── raw/ # Raw simulation data (HDF5, group-per-sample)
│ └── processed/ # Preprocessed train/val/test data (HDF5, monolithic)
├── generation/
│ └── generate_data.py # Data generation script (Julia/Distgen)
├── preprocessing/
│ └── preprocess_data.py # Data preprocessing pipeline
├── modeling/
│ ├── models/ # Neural network model definitions
│ │ ├── __init__.py # Model registry and factory
│ │ ├── cnn3d.py # 3D CNN architecture
│ │ ├── unet3d.py # 3D U-Net architecture
│ │ └── vae2d.py # 2D Variational Autoencoder (15-channel inputs)
│ ├── dataset.py # PyTorch Dataset & DataLoader utilities
│ └── train.py # Model training script
├── evaluation/
│ ├── evaluate.py # Model evaluation script
│ └── visualization/ # Visualization modules
│ ├── raw_data.py # Raw data visualization
│ ├── predict_efield.py # Model prediction visualization
│ └── training_curves.py # Training/validation curves
├── scripts/ # CLI entry points for main tasks and visualization
│ ├── generate_dataset.py # Generate synthetic data
│ ├── preprocess_dataset.py # Preprocess data
│ ├── train_model.py # Train the model
│ ├── evaluate_model.py # Evaluate the model
│ ├── visualize_raw_data.py # Visualize raw data
│ ├── visualize_predict_efield.py # Visualize model predictions
│ └── visualize_training_curves.py # Visualize training/validation loss curves
├── saved_models/ # Model checkpoints, scalers
├── tests/
│ ├── test_data_pipeline.py # Data pipeline test suite
│ └── test_model_training.py # Model training test suite
├── environment.yml # Conda environment definition
└── README.md
```
Generate synthetic space charge simulation data using Julia and Distgen:
```bash
python scripts/generate_dataset.py configs/generation_config.yaml
```

- Config: `configs/generation_config.yaml` controls output location, grid size, number of samples, parameter ranges, and device (CPU/GPU).
- Template: Uses `configs/distgen_template.yaml` for beam/particle settings.
- Output: HDF5 file in `data/raw/` with group-per-sample structure: `run_00001/rho`, `run_00001/efield`, `run_00001/parameters`.
Tip: Requires Julia and the SpaceCharge Julia package. See Julia/Distgen setup below if needed.
Convert raw simulation data to a format suitable for PyTorch training:
```bash
python scripts/preprocess_dataset.py --config configs/training_config.yaml
```

Or in Python:

```python
from preprocessing.preprocess_data import Preprocessor

Preprocessor('configs/training_config.yaml').run()
```

Pipeline steps:
- Reads raw HDF5 data
- Converts to monolithic format for efficient loading
- Applies normalization using configurable scalers (StandardScaler or SymlogScaler)
- Splits into train/val/test sets
- Saves processed data to `data/processed/` and scalers to `saved_models/`
Scaler configuration:
- You can specify the normalization method for both input (charge density) and target (electric field) data in `configs/training_config.yaml` as a dictionary with a `type` key and optional parameters:

```yaml
preprocessing:
  input_scaler:
    type: 'standard'  # Options: 'standard', 'symlog'
  target_scaler:
    type: 'symlog'    # Use 'symlog' for symmetric log scaling, or 'standard' for StandardScaler
    linthresh: 0.005  # (optional) Linear threshold for symlog
    percentile: 90    # (optional) Percentile for automatic linthresh selection
```

- `standard`: StandardScaler (mean=0, std=1, suitable for most data)
- `symlog`: SymlogScaler (handles data with both positive and negative values spanning orders of magnitude)
- `linthresh`: (float, optional) Linear threshold for the symlog transform. If not provided, it is determined from the data using `percentile`.
- `percentile`: (float, optional) Percentile (0-100) of |x| to use for `linthresh`. Default is 90.
If not specified, both default to 'standard'.
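To illustrate what symmetric log scaling does, here is a minimal sketch of a symlog-style transform and its inverse. This is an illustration of the idea only, not the repository's `SymlogScaler` implementation, whose exact formula may differ:

```python
import numpy as np

def symlog(x, linthresh=0.005):
    # Odd, monotonic map: approximately linear for |x| << linthresh,
    # logarithmic for |x| >> linthresh, so signed values spanning many
    # orders of magnitude are compressed into a trainable range.
    return np.sign(x) * np.log1p(np.abs(x) / linthresh)

def symlog_inverse(y, linthresh=0.005):
    # Exact inverse of symlog above; used to undo scaling on predictions.
    return np.sign(y) * linthresh * np.expm1(np.abs(y))
```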
Train a neural network model on the preprocessed data:
```bash
python scripts/train_model.py --config configs/training_config.yaml
```

Training Pipeline:
- Automatically runs preprocessing if needed
- Creates model from config (supports CNN3D, UNet3D, and VAE2D)
- Sets up data loaders, optimizer, scheduler, and loss function
- Includes validation, checkpointing, early stopping, and logging
- Saves best model, training history, and logs
Loss Function Configuration:
- Loss functions are extensible and defined in `modeling/loss.py`.
- Standard losses: `mse`, `l1`/`mae`, `huber`.
- Custom/combined losses can be specified in the config as a dict, e.g.:

```yaml
training:
  loss_function:
    type: "combined"
    losses:
      - type: "mse"
      - type: "l1"
    weights: [0.7, 0.3]
```

- Add your own loss functions in `modeling/loss.py` and register them for use in the config.
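In PyTorch terms, the `combined` entry above amounts to a weighted sum of criteria. A minimal sketch (illustrative only; the actual class and its registration live in `modeling/loss.py` and may be named differently):

```python
import torch.nn as nn

class CombinedLoss(nn.Module):
    """Weighted sum of component losses, mirroring the YAML example above."""
    def __init__(self, losses, weights):
        super().__init__()
        self.losses = nn.ModuleList(losses)
        self.weights = weights

    def forward(self, pred, target):
        # Accumulate weight * loss over all configured components
        return sum(w * loss(pred, target)
                   for w, loss in zip(self.weights, self.losses))

# Equivalent to the config example: 0.7 * MSE + 0.3 * L1
criterion = CombinedLoss([nn.MSELoss(), nn.L1Loss()], [0.7, 0.3])
```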
Key Features:
- Model-agnostic: Easily switch architectures via config
- Reproducible: Seed control and deterministic operations
- Robust: Automatic device selection, gradient clipping, error handling
- Monitored: Progress bars, logging, loss tracking
Training Output:
- `saved_models/best_model.pth` - Best model checkpoint
- `saved_models/checkpoint_epoch_XXX.pth` - Periodic checkpoints
- `saved_models/training_history.pkl` - Loss curves and metrics
- `logs/training.log` - Detailed training logs
Evaluate a trained model on the test set:
```bash
python scripts/evaluate_model.py --config configs/training_config.yaml
```

Optional: specify a specific checkpoint:

```bash
python scripts/evaluate_model.py --config configs/training_config.yaml --checkpoint saved_models/best_model.pth
```

Evaluation Pipeline:
- Automatically finds best model if no checkpoint specified
- Computes comprehensive metrics: MSE, MAE, R², RMSE
- Per-component metrics for each electric field component (Ex, Ey, Ez)
- Generates visualizations and saves predictions for analysis
- Creates human-readable reports
Evaluation Output:
- `saved_models/evaluation/evaluation_results.txt` - Summary report
- `saved_models/evaluation/evaluation_metrics.pkl` - Detailed metrics
- `saved_models/evaluation/predictions.npy` - Model predictions
- `saved_models/evaluation/plots/` - Visualization plots
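For custom analysis, the saved predictions can be scored per field component with plain NumPy. A sketch, assuming prediction and ground-truth arrays of shape `(N, 3, 32, 32, 32)` (the evaluation script computes these metrics internally):

```python
import numpy as np

def per_component_metrics(pred, true):
    """MSE, MAE, and R^2 for each electric field component."""
    metrics = {}
    for i, name in enumerate(['Ex', 'Ey', 'Ez']):
        p, t = pred[:, i].ravel(), true[:, i].ravel()
        ss_res = np.sum((t - p) ** 2)
        ss_tot = np.sum((t - t.mean()) ** 2)
        metrics[name] = {
            'mse': np.mean((p - t) ** 2),
            'mae': np.mean(np.abs(p - t)),
            'r2': 1.0 - ss_res / ss_tot,
        }
    return metrics
```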
Load processed data for custom PyTorch training:
```python
from modeling.dataset import SpaceChargeDataset, create_data_loaders

# Load a single dataset
dataset = SpaceChargeDataset('data/processed/train.h5')
input_tensor, target_tensor = dataset[0]  # input: (1, 32, 32, 32), target: (3, 32, 32, 32)

# Create DataLoaders for training/validation/testing
train_loader, val_loader, test_loader = create_data_loaders(
    'data/processed/train.h5',
    'data/processed/val.h5',
    'data/processed/test.h5',
    batch_size=8,
)
```

The framework supports easy addition of new model architectures:
Create a new file `modeling/models/your_model.py`:

```python
import torch.nn as nn
from typing import Dict, Any

class YourModel(nn.Module):
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        model_config = config.get('model', {})
        # Initialize your model here, e.g. a single 3D convolution
        self.net = nn.Conv3d(1, 3, kernel_size=3, padding=1)

    def forward(self, x):
        # Your forward pass
        return self.net(x)

    def get_model_summary(self):
        # Return model information
        return {'model_name': 'YourModel'}
```

Edit `modeling/models/__init__.py`:

```python
from .your_model import YourModel

MODEL_REGISTRY = {
    'cnn3d': CNN3D,
    'your_model': YourModel,  # Add this line
}
```

Set the architecture in `configs/training_config.yaml`:

```yaml
model:
  architecture: "your_model"
  # your model-specific parameters
```

Use the same training and evaluation commands - the framework automatically uses your new model!
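Under the hood, the registry enables a config-driven factory lookup along these lines (a sketch; the actual factory in `modeling/models/__init__.py` may differ in naming and error handling):

```python
# Hypothetical factory sketch, sitting alongside MODEL_REGISTRY
def create_model(config):
    arch = config.get('model', {}).get('architecture', 'cnn3d')
    try:
        model_cls = MODEL_REGISTRY[arch]
    except KeyError:
        raise ValueError(
            f"Unknown architecture '{arch}'; available: {sorted(MODEL_REGISTRY)}"
        )
    return model_cls(config)
```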
The repository provides a collection of interactive visualization tools to help you explore raw data, training progress, and model predictions. All tools are located in evaluation/visualization/ and can be used both via command-line scripts and as Python modules.
- Raw Data Visualization (`raw_data.py`): Visualize charge density, electric field, or both from raw HDF5 simulation files.
- Model Prediction Visualization (`predict_efield.py`): Visualize model predictions versus ground truth, or inspect predicted fields for any test sample.
- Training Curve Visualization (`training_curves.py`): Plot training and validation loss curves from saved training history.
Visualize raw data (density, efield, or both):
```bash
python scripts/visualize_raw_data.py data/raw/simulations.h5 --plot both --run run_00000
```

- `--plot`: Choose `density`, `efield`, or `both`
- `--run`: Specify the sample/run to visualize
Visualize model predictions (compare or predict mode):
```bash
python scripts/visualize_predict_efield.py data/processed/test.h5 --sample_idx 0 --checkpoint saved_models/best_model.pth --scalers data/processed/scalers.pkl --config configs/training_config.yaml --mode compare
```

- `--mode compare`: Interactive comparison of predicted and ground truth E-field
- `--mode predict`: Visualize charge density and predicted E-field only
Plot training and validation loss curves:
```bash
python scripts/visualize_training_curves.py saved_models/training_history.pkl
```

Plot a histogram of raw (unscaled) electric field values across all samples:
```bash
# Use path from config
python scripts/plot_raw_output_histogram.py --config configs/training_config.yaml --bins 200 --component all

# Or provide raw file directly (positional argument)
python scripts/plot_raw_output_histogram.py data/raw/SC_10k.h5 --bins 200 --component Ex --save saved_models/evaluation/plots/raw_Ex_hist.png
```

- `--config`: Uses `paths.raw_data_path` like `preprocess_data.py`
- Positional file argument: direct path to raw HDF5 (group-per-sample)
- `--bins`: Number of histogram bins
- `--component`: `all`, `Ex`, `Ey`, or `Ez`
- `--save`: Save plot to PNG instead of showing interactively
You can visualize the learned convolutional kernels of a trained CNN3D model using the script:
```bash
python scripts/visualize_kernels.py --checkpoint saved_models/best_model.pth --config configs/training_config.yaml --layer_type encoder --layer_idx 0
```

- `--checkpoint`: Path to the model checkpoint (default: `saved_models/best_model.pth`)
- `--config`: Path to the model config (default: `configs/training_config.yaml`)
- `--layer_type`: `encoder` or `decoder` (default: `encoder`)
- `--layer_idx`: Index of the layer to visualize (default: `0`)
The script will display slices of each 3D kernel for the selected layer and channel.
- `output_dir`, `output_filename`: Where to save raw data
- `template_file`: Path to distgen template
- `device`: `cpu` or `gpu`
- `grid_size`: Simulation grid resolution
- `n_samples`: Number of samples to generate
- `min_bound`, `max_bound`: Grid bounds (meters)
- `sigma_mins`, `sigma_maxs`: Parameter sampling ranges
- `seed`: Random seed for reproducibility
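Put together, a `configs/generation_config.yaml` might look like the following. Every value here is an illustrative placeholder; check the file shipped with the repository for the real defaults:

```yaml
# Illustrative layout only; values are placeholders
output_dir: data/raw
output_filename: simulations.h5
template_file: configs/distgen_template.yaml
device: cpu            # 'cpu' or 'gpu'
grid_size: 32
n_samples: 10000
min_bound: -0.01       # meters
max_bound: 0.01        # meters
sigma_mins: [0.001, 0.001, 0.001]
sigma_maxs: [0.005, 0.005, 0.005]
seed: 42
```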
`configs/training_config.yaml` sections:
- Paths: Raw/processed data, model save dir, logs
- Preprocessing: Split ratios, normalization, chunking
- Model: Architecture, channels, layers, activation, dropout, etc.
- Training: Batch size, epochs, optimizer, scheduler, loss, device
- Resume: Optional resume-from-checkpoint settings
- Evaluation: Metrics, plotting, saving predictions
- Logging: Level, file, Tensorboard/W&B integration
`configs/distgen_template.yaml` defines the beam/particle distribution for simulation (see the file for details).
Enable resume behavior in `configs/training_config.yaml`:

```yaml
training:
  resume:
    enabled: true                                           # enable resume behavior
    checkpoint_path: saved_models/checkpoint_epoch_040.pth  # or null
    use_best: false                                         # set true to resume from best_model.pth
```

Resolution order:
- If `use_best: true`, the trainer loads `saved_models/best_model.pth`.
- Else if `checkpoint_path` is provided, it loads that file.
- Else it tries `saved_models/latest_checkpoint.pth`.
Restored state:
- Model weights, optimizer, scheduler, loss history, and epoch counter.
- Training resumes from the next epoch after the checkpoint's `epoch`.
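For reference, the generic PyTorch pattern behind this kind of resume bundles all of that state into a single checkpoint dict. This is a sketch with a placeholder model, not the trainer's exact checkpoint schema:

```python
import torch
import torch.nn as nn

model = nn.Conv3d(1, 3, kernel_size=3, padding=1)  # placeholder model
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Save: everything needed to continue training later
torch.save({
    'epoch': 40,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
}, 'checkpoint_epoch_040.pth')

# Resume: restore all state, then start from the next epoch
ckpt = torch.load('checkpoint_epoch_040.pth')
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
scheduler.load_state_dict(ckpt['scheduler_state_dict'])
start_epoch = ckpt['epoch'] + 1
```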
Raw files (group-per-sample), one group `run_XXXXX/` per sample:
- `rho`: Charge density, shape `(32, 32, 32)`, dtype `float64`
- `efield`: Electric field, shape `(32, 32, 32, 3)`, dtype `float64`
- `parameters`: Beam parameters, shape `(3,)`, dtype `float64`

Processed files (monolithic):
- `charge_density`: shape `(N, 32, 32, 32)`, dtype `float32`, normalized
- `electric_field`: shape `(N, 3, 32, 32, 32)`, dtype `float32`, normalized
- For PyTorch: input `(1, Nx, Ny, Nz)`, target `(3, Nx, Ny, Nz)`
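For quick inspection, the raw group-per-sample file can be read directly with `h5py` (a usage sketch; the file path is illustrative, and the dataset names follow the layout above):

```python
import h5py

# Path is illustrative; point this at your generated raw file
with h5py.File('data/raw/simulations.h5', 'r') as f:
    rho = f['run_00001/rho'][:]             # (32, 32, 32) charge density
    efield = f['run_00001/efield'][:]       # (32, 32, 32, 3) electric field
    params = f['run_00001/parameters'][:]   # (3,) beam parameters
    print(rho.shape, efield.shape, params.shape)
```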
Run the full test suite:
```bash
# Test data pipeline
pytest tests/test_data_pipeline.py -v

# Test model training pipeline
pytest tests/test_model_training.py -v

# Run all tests
pytest tests/ -v
```

Data Pipeline Tests:
- End-to-end pipeline validation
- Dataset/DataLoader integration
- Normalization correctness
- Error handling and edge cases
Model Training Tests:
- Model instantiation and forward pass
- Training loop functionality
- Checkpoint saving/loading
- Evaluation pipeline
- End-to-end training process
Linting & Type Checking:
```bash
ruff check --fix
mypy preprocessing/ modeling/ --ignore-missing-imports
```

Julia/Distgen setup:
- Install Julia
- Install the Julia packages:
  - `SpaceCharge` (for field calculation)
  - `CUDA` (if using GPU)
- The Python package `juliacall` (installed via pip) is used for Python-Julia interop.
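A quick way to check that the interop works (a minimal `juliacall` usage sketch; assumes Julia and the SpaceCharge package are installed):

```python
from juliacall import Main as jl

print(jl.seval("1 + 1"))       # evaluate arbitrary Julia code from Python
jl.seval("using SpaceCharge")  # load the field-calculation package
```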