Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions autoregressive_prova_generic_condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import torch
import matplotlib.pyplot as plt

from pina import Trainer
from pina.optim import TorchOptimizer
from pina.problem import AbstractProblem
from pina.condition.data_condition import DataCondition
from pina.solver import AutoregressiveSolver

NUM_TIMESTEPS = 100
NUM_FEATURES = 15
USE_TEST_MODEL = False

# ============================================================================
# DATA
# ============================================================================

torch.manual_seed(42)

y = torch.zeros(NUM_TIMESTEPS, NUM_FEATURES)
y[0] = torch.rand(NUM_FEATURES) # Random initial state

for t in range(NUM_TIMESTEPS - 1):
y[t + 1] = 0.95 * y[t] # + 0.05 * torch.sin(y[t].sum())

# ============================================================================
# TRAINING
# ============================================================================

class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(y.shape[1], 20),
torch.nn.ReLU(),
torch.nn.Dropout(0.2),
torch.nn.Linear(20, y.shape[1]),
)

def forward(self, x):
return x + self.layers(x)


class TestModel(torch.nn.Module):
"""
Debug model that implements the EXACT transformation rule.
y[t+1] = 0.95 * y[t]
Expected loss is zero
"""

def __init__(self, data_series=None):
super().__init__()
self.dummy_param = torch.nn.Parameter(torch.zeros(1))

def forward(self, x):
next_state = 0.95 * x # + 0.05 * torch.sin(x.sum(dim=1, keepdim=True))
return next_state + 0.0 * self.dummy_param


class Problem(AbstractProblem):
output_variables = None
input_variables = None
conditions = {
"data_condition_0":DataCondition(input=y),
"data_condition_1":DataCondition(input=y),
}

problem = Problem()

#for each condition, define unroll instructions with these keys:
# - unroll_length: length of each unroll window
# - num_unrolls: number of unroll windows to create (if None, use all possible)
# - randomize: whether to randomize the starting indices of the unroll windows
unroll_instructions = {
"data_condition_0": {
"unroll_length": 10,
"num_unrolls": 89,
"randomize": True,
"eps": 5.0
},
"data_condition_1": {
"unroll_length": 20,
"num_unrolls": 79,
"randomize": True,
"eps": 10.0
},
}

solver = AutoregressiveSolver(
unroll_instructions=unroll_instructions,
problem=problem,
model=TestModel() if USE_TEST_MODEL else SimpleModel(),
optimizer= TorchOptimizer(torch.optim.AdamW, lr=0.01),
eps=10.0,
)

trainer = Trainer(
solver, max_epochs=2000, accelerator="cpu", enable_model_summary=False, shuffle=False
)
trainer.train()

# ============================================================================
# VISUALIZATION
# ============================================================================

test_start_idx = 50
num_prediction_steps = 30

initial_state = y[test_start_idx] # Shape: [features]
predictions = solver.predict(initial_state, num_prediction_steps)
actual = y[test_start_idx : test_start_idx + num_prediction_steps + 1]

total_mse = torch.nn.functional.mse_loss(predictions[1:], actual[1:])
print(f"\nOverall MSE (all {num_prediction_steps} steps): {total_mse:.6f}")

# viauzlize single dof
dof_to_plot = [0, 3, 6, 9, 12]
colors = [
"r",
"g",
"b",
"c",
"m",
"y",
"k",
]
plt.figure(figsize=(10, 6))
for dof, color in zip(dof_to_plot, colors):
plt.plot(
range(test_start_idx, test_start_idx + num_prediction_steps + 1),
actual[:, dof].numpy(),
label="Actual",
marker="o",
color=color,
markerfacecolor="none",
)
plt.plot(
range(test_start_idx, test_start_idx + num_prediction_steps + 1),
predictions[:, dof].numpy(),
label="Predicted",
marker="x",
color=color,
)

plt.title(f"Autoregressive Predictions vs Actual, MRSE: {total_mse:.6f}")
plt.legend()
plt.xlabel("Timestep")
plt.savefig(f"autoregressive_predictions.png")
plt.close()
5 changes: 5 additions & 0 deletions pina/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"DeepEnsembleSupervisedSolver",
"DeepEnsemblePINN",
"GAROM",
"AutoregressiveSolver",
]

from .solver import SolverInterface, SingleSolverInterface, MultiSolverInterface
Expand All @@ -41,3 +42,7 @@
DeepEnsemblePINN,
)
from .garom import GAROM
from .autoregressive_solver import (
AutoregressiveSolver,
AutoregressiveSolverInterface,
)
4 changes: 4 additions & 0 deletions pina/solver/autoregressive_solver/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
__all__ = ["AutoregressiveSolver", "AutoregressiveSolverInterface"]

from .autoregressive_solver import AutoregressiveSolver
from .autoregressive_solver_interface import AutoregressiveSolverInterface
172 changes: 172 additions & 0 deletions pina/solver/autoregressive_solver/autoregressive_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import torch
from pina.utils import check_consistency
from pina.solver.solver import SingleSolverInterface
from pina.condition import DataCondition
from .autoregressive_solver_interface import AutoregressiveSolverInterface


class AutoregressiveSolver(
AutoregressiveSolverInterface, SingleSolverInterface
):
"""
Autoregressive Solver class.
"""

accepted_conditions_types = DataCondition

def __init__(
self,
unroll_instructions,
problem,
model,
eps=None,
loss=None,
optimizer=None,
scheduler=None,
weighting=None,
use_lt=False,
):
"""
Initialization of the :class:`AutoregressiveSolver` class.
:param dict unroll_instructions: A dictionary specifying how to unroll each condition.
this is supposed to map condition names to dict objects with unroll instructions.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The model to be trained.
:param torch.nn.Module or LossInterface or None loss: The loss function to be minimized. If None, defaults to MSELoss.
:param TorchOptimizer or None optimizer: The optimizer to be used. If None, no optimization is performed.
:param TorchScheduler or None scheduler: The learning rate scheduler to be used. If None, no scheduling is performed.
:param Weighting or None weighting: The weighting scheme for combining losses from different conditions. If None, equal weighting is applied.
:param bool use_lt: Whether to use learning rate tuning.
"""

super().__init__(
unroll_instructions=unroll_instructions,
problem=problem,
model=model,
loss=loss,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
use_lt=use_lt,
)

def loss_data(self, data, condition_unroll_instructions):
"""
Compute the data loss for the recursive autoregressive solver.
This will be applied to each condition individually.
:param torch.Tensor data: all training data.
:param dict condition_unroll_instructions: instructions on how to unroll the model for this condition.
:return: Computed loss value.
:rtype: torch.Tensor
"""

initial_data, unroll_data = self.create_unroll_windows(
data, condition_unroll_instructions
)

unroll_length = condition_unroll_instructions["unroll_length"]
current_state = initial_data # [num_unrolls, features]

losses = []
for step in range(unroll_length):

predicted_state = self.forward(current_state) # [num_unrolls, features]
target_state = unroll_data[:, step, :] # [num_unrolls, features]
step_loss = self._loss_fn(predicted_state, target_state)
losses.append(step_loss)
current_state = predicted_state

step_losses = torch.stack(losses) # [unroll_length]

with torch.no_grad():
weights = self.compute_adaptive_weights(step_losses.detach(), condition_unroll_instructions)

weighted_loss = (step_losses * weights).sum()
return weighted_loss

def create_unroll_windows(self, data, condition_unroll_instructions):
"""
Create unroll windows for each condition from the data based on the provided instructions.
:param torch.Tensor data: The full data tensor.
:param dict condition_unroll_instructions: Instructions on how to unroll the model for this condition.
:return: Tuple of initial data and unroll data tensors.
:rtype: (torch.Tensor, torch.Tensor)
"""

unroll_length = condition_unroll_instructions["unroll_length"]

start_list = []
unroll_list = []
for starting_index in self.decide_starting_indices(
data, condition_unroll_instructions
):
idx = starting_index.item()
start = data[idx]
target_start = idx + 1
unroll = data[target_start : target_start + unroll_length, :]
start_list.append(start)
unroll_list.append(unroll)
initial_data = torch.stack(start_list) # [num_unrolls, features]
unroll_data = torch.stack(unroll_list) # [num_unrolls, unroll_length, features]
return initial_data, unroll_data

def decide_starting_indices(self, data, condition_unroll_instructions):
"""
Decide the starting indices for unrolling based on the provided instructions.
:param torch.Tensor data: The full data tensor.
:param dict condition_unroll_instructions: Instructions on how to unroll the model for this condition.
:return: Tensor of starting indices.
:rtype: torch.Tensor
"""
n_step, n_features = data.shape
num_unrolls = condition_unroll_instructions.get("num_unrolls", None)
unroll_length = condition_unroll_instructions["unroll_length"]
randomize = condition_unroll_instructions.get("randomize", True)

max_start = n_step - unroll_length
indices = torch.arange(max_start, device=data.device)

if num_unrolls is not None and num_unrolls < len(indices):
indices = indices[:num_unrolls]

if randomize:
indices = indices[torch.randperm(len(indices), device=data.device)]

return indices

def compute_adaptive_weights(self, step_losses, condition_unroll_instructions):
"""
Compute adaptive weights for each time step based on cumulative losses.
:param torch.Tensor step_losses: Tensor of shape [unroll_length] containing losses at each time step.
:return: Tensor of shape [unroll_length] containing normalized weights.
:rtype: torch.Tensor
"""
num_steps = len(step_losses)
eps = condition_unroll_instructions.get("eps", None)
if eps is None:
weights = torch.ones_like(step_losses)
else:
weights = torch.exp(-eps * torch.cumsum(step_losses, dim=0))

return weights / weights.sum()

def predict(self, initial_state, num_steps):
"""
Make recursive predictions starting from an initial state.
:param torch.Tensor initial_state: Initial state tensor.
:param int num_steps: Number of steps to predict ahead.
:return: Tensor of predictions.
:rtype: torch.Tensor
"""
self.eval() # Set model to evaluation mode

current_state = initial_state
predictions = [current_state]

with torch.no_grad():
for step in range(num_steps):
next_state = self.forward(current_state)
predictions.append(next_state)
current_state = next_state

return torch.stack(predictions)
Loading