import os
from typing import Callable, Dict, Tuple
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from aido.logger import logger
from aido.monitoring.logger import WandbLogger
from aido.optimization_helpers import ParameterModule
from aido.simulation_helpers import SimulationParameterDictionary
from aido.surrogate import Surrogate, SurrogateDataset
[docs]
class Optimizer(torch.nn.Module):
'''
The optimizer uses the surrogate model to optimise the detector parameters in batches.
It is also linked to a generator object, to check if the parameters are still in bounds using
the function is_local(parameters) of the generator.
Once the parameters are not local anymore, the optimizer will return the last parameters that were local and stop.
For this purpose, the surrogate model will need to be applied using fixed weights.
Then the reconstruction model loss will be applied based on the surrogate model output.
The gradient w.r.t. the detector parameters will be calculated and the parameters will be updated.
'''
[docs]
def __init__(
self,
parameter_dict: SimulationParameterDictionary,
device: str | None = None
):
"""
Initializes the optimizer with the given surrogate model and parameters.
Args:
starting_parameter_dict (Dict): A dictionary containing the initial parameters.
device (str): Defaults to 'cuda'
"""
super().__init__()
self.parameter_dict = parameter_dict
dev = "cuda" if torch.cuda.is_available() else "cpu"
self.device = dev or torch.device(dev)
self.parameter_module = ParameterModule(self.parameter_dict).to(self.device)
self.optimizer = torch.optim.Adam(self.parameter_module.parameters())
[docs]
def to(self, device: str | torch.device, **kwargs) -> "Optimizer":
""" Move all Tensors and modules to 'device'.
"""
self.device = device if isinstance(device, torch.device) else torch.device(device)
super().to(self.device, **kwargs)
return self
[docs]
def check_parameters_are_local(self, updated_parameters: torch.Tensor, scale=1.0) -> bool:
""" Assure that the predicted parameters by the optimizer are within the bounds of the covariance
matrix spanned by the 'sigma' of each parameter.
"""
diff = updated_parameters - self.starting_parameters_continuous
diff = diff.detach().cpu().numpy()
return np.dot(diff, np.dot(np.linalg.inv(self.parameter_dict.covariance), diff)) < scale
@property
def boundaries(self) -> torch.Tensor:
""" Adds penalties for parameters that are outside of the boundaries spaned by 'self.parameter_box'. This
ensures that the optimizer does not propose new values that are outside of the scope of the Surrogate and
therefore largely unknown to the current iteration.
Returns:
--------
torch.Tensor
"""
parameter_box = self.parameter_module.constraints.to(self.device)
if len(parameter_box) != 0:
parameters_continuous_tensor = self.parameter_module.continuous_tensors()
lower_boundary_loss = torch.mean(
0.5 * torch.nn.ReLU()(parameter_box[:, 0] - parameters_continuous_tensor)**2
)
upper_boundary_loss = torch.mean(
0.5 * torch.nn.ReLU()(parameters_continuous_tensor - parameter_box[:, 1])**2
)
return lower_boundary_loss + upper_boundary_loss
else:
return torch.Tensor([0.0])
[docs]
def other_constraints(
self,
constraints_func: None | Callable[[SimulationParameterDictionary, Dict], torch.Tensor],
parameter_dict_as_tensor: Dict[str, torch.nn.Parameter | torch.Tensor]
) -> torch.Tensor:
""" Adds user-defined constraints defined in 'interface.py:AIDOUserInterface.constraints()'. If no constraints
were added manually, this method defaults to calculating constraints based on the cost per parameter specified
in ParameterDict. Returns a float or torch.Tensor which can be considered as a penalty loss.
"""
if constraints_func is None:
loss = self.parameter_module.cost_loss
else:
loss = constraints_func(self.parameter_dict, parameter_dict_as_tensor)
return loss if loss is not None else torch.tensor(0.0)
def save_parameters(
self,
epoch: int,
batch_index: int,
loss: float,
filepath: str | os.PathLike = "parameter_optimizer_df.parquet",
) -> None:
df = self.parameter_dict.to_df(display_discrete="as_probabilities")
df["Epoch"] = epoch
df["Batch"] = batch_index
df["Surrogate_Prediction"] = loss
if not os.path.exists(filepath):
df.to_parquet(filepath)
else:
try:
updated_parameter_optimizer_df = pd.concat([pd.read_parquet(filepath), df], ignore_index=True)
updated_parameter_optimizer_df.to_parquet(filepath)
except KeyboardInterrupt:
logger.warning("Saving the Optimizer Parameters to file before stopping...")
updated_parameter_optimizer_df.to_parquet(filepath)
raise
def print_grads(self) -> None:
for name, param in self.named_parameters():
if param.requires_grad and not name.startswith("surrogate_model"):
logger.debug(f"Optimizer {name}: Data={param.data}, grads={param.grad}")
[docs]
def optimize(
self,
surrogate_model: Surrogate,
dataset: SurrogateDataset,
batch_size: int,
n_epochs: int,
reconstruction_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
additional_constraints: None | Callable[[SimulationParameterDictionary, Dict], torch.Tensor] = None,
parameter_optimizer_savepath: str | os.PathLike | None = None,
device: str | None = None,
lr: float = 0.01,
wandb_logger: WandbLogger | None = None) -> Tuple[SimulationParameterDictionary, bool]:
""" Perform the optimization step.
1. The ParameterModule().forward() method generates new parameters.
2. The Surrogate Model computes the corresponding Reconstruction Loss (based on its interpolation).
3. The Optimizer Loss is the Sum of the Reconstruction Loss, user-defined Parameter Loss
(e.g. cost constraints) and the Parameter Box Loss (which ensures that the Parameters stay
within acceptable boundaries during training).
4. The optimizer applies backprogation and updates the current ParameterDict
Returns:
--------
SimulationParameterDictionary
bool
"""
self.starting_parameter_dict = self.parameter_dict
self.surrogate_model = surrogate_model
self.device = device or self.device
self.to(self.device)
self.starting_parameters_continuous = self.parameter_module.continuous_tensors().clone().detach()
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
self.surrogate_model.eval()
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
self.optimizer_loss = []
self.constraints_loss = []
gradients_norm = []
gradients_min = []
gradients_max = []
for epoch in range(n_epochs):
epoch_loss = 0.0
epoch_constraints_loss = 0.0
stop_epoch = False
for batch_idx, (_parameters, context, targets, _reconstructed) in enumerate(data_loader):
context: torch.Tensor = context.to(self.device)
targets: torch.Tensor = targets.to(self.device)
parameters_batch: torch.Tensor = self.parameter_module()
surrogate_output = self.surrogate_model.sample_forward(
parameters_batch,
context,
targets
)
surrogate_reconstruction_loss = reconstruction_loss(
dataset.unnormalize_features(targets, index=2),
dataset.unnormalize_features(surrogate_output, index=2)
)
loss = surrogate_reconstruction_loss.mean()
surrogate_loss_detached = loss.item()
constraints_loss = self.other_constraints(
additional_constraints,
self.parameter_module.current_values()
)
loss += constraints_loss
loss += self.boundaries
loss.backward()
if wandb_logger is not None:
flat_grads = torch.cat([p.grad.detach().flatten() for p in self.parameter_module.parameters() if p.grad is not None])
gradients_norm.append(flat_grads.norm().item())
gradients_min.append(flat_grads.abs().min().item())
gradients_max.append(flat_grads.abs().max().item())
if np.isnan(loss.item()):
logger.error("Optimizer: NaN loss, exiting.")
self.optimizer.step()
return self.parameter_dict, False
self.optimizer.step()
self.optimizer.zero_grad()
self.parameter_dict.update_current_values(self.parameter_module.physical_values(format="dict"))
self.parameter_dict.update_probabilities(self.parameter_module.probabilities)
self.save_parameters(epoch, batch_idx, surrogate_loss_detached, parameter_optimizer_savepath)
epoch_loss += loss.item()
epoch_constraints_loss += constraints_loss.item()
if not self.check_parameters_are_local(
updated_parameters=self.parameter_module.continuous_tensors(),
scale=0.8
):
stop_epoch = True
logger.error("Optimizer: Parameters are not local")
break
logger.info(
f"Optimizer Epoch: {epoch} \tLoss: {surrogate_loss_detached:.5f} (reco)\t"
+ f"+ {(constraints_loss.item()):.5f} (constraints)\t"
+ f"+ {(self.boundaries.item()):.5f} (boundaries)\t"
+ f"= {loss.item():.5f} (total)"
)
epoch_loss /= batch_idx + 1
epoch_constraints_loss /= batch_idx + 1
self.optimizer_loss.append(epoch_loss)
self.constraints_loss.append(epoch_constraints_loss)
if stop_epoch:
break
if wandb_logger is not None:
wandb_logger.log_scalars("Optimizer Loss", self.optimizer_loss)
wandb_logger.log_scalars("Constraints Loss", self.constraints_loss)
if len(gradients_norm) > 0:
wandb_logger.log_gradients("Gradients", gradients_norm, gradients_min, gradients_max)
self.parameter_dict.covariance = self.parameter_module.adjust_covariance(
self.parameter_module.continuous_tensors().to(self.device)
- self.starting_parameters_continuous.to(self.device)
).astype(float)
return self.parameter_dict, True
@property
def boosted_parameter_dict(self) -> SimulationParameterDictionary:
r""" Compute a new set of parameters by taking the current parameter dict and boosting it along
the direction of change between the previous and the current values (only continuous parameters).
Formula:
\[
p_{n+1} = p_{opt} + \frac{1}{2} \left( p_{opt} - p_n \right)
\]
Where:
- \( p_{n+1} \) is the updated parameter dict.
- \( p_{opt} \) is the current (optimized) parameter dict.
- \( p_n \) is the starting parameter dict.
"""
current_values = self.parameter_dict.get_current_values("dict", types="continuous")
previous_values = self.starting_parameter_dict.get_current_values("dict", types="continuous")
return self.parameter_dict.update_current_values(
{key: current_values[key] + 0.5 * (current_values[key] - previous_values[key]) for key in current_values}
)