Source code for aido.training

from contextlib import nullcontext
import json
import os
from typing import Callable

import numpy as np
import pandas as pd
import torch

from aido.config import AIDOConfig
from aido.logger import logger
from aido.optimizer import Optimizer
from aido.simulation_helpers import SimulationParameterDictionary
from aido.surrogate import Surrogate, SurrogateDataset
from aido.monitoring.logger import WandbLogger, WandbTaskLogger


[docs] def pre_train(model: Surrogate, dataset: SurrogateDataset, n_epochs: int) -> list[float]: """Pre-train the Surrogate Model using a three-stage process. This function performs pre-training in three stages with different batch sizes and learning rates to ensure stable convergence. Parameters ---------- model : Surrogate The surrogate model to pre-train. dataset : SurrogateDataset The dataset to use for training. n_epochs : int Number of epochs to train in each stage. """ model.to("cuda" if torch.cuda.is_available() else "cpu") logger.info('Surrogate: Pre-Training 0') model.train_model(dataset, batch_size=512, n_epochs=n_epochs, lr=0.001) logger.info('Surrogate: Pre-Training 1') model.train_model(dataset, batch_size=1024, n_epochs=n_epochs, lr=0.001) logger.info('Surrogate: Pre-Training 2') model.train_model(dataset, batch_size=1024, n_epochs=n_epochs, lr=0.0003)
def train_or_load_surrogate(config: AIDOConfig, parameter_dict: SimulationParameterDictionary, surrogate_previous_path: str, surrogate_save_path: str, surrogate_df: pd.DataFrame, task_logger: WandbTaskLogger | None = None) -> tuple[Surrogate, SurrogateDataset]: if os.path.isfile(surrogate_save_path): surrogate: Surrogate = torch.load(surrogate_save_path) surrogate.mark_step_offset() surrogate_dataset = SurrogateDataset(surrogate_df, means=surrogate.means, stds=surrogate.stds) else: if os.path.isfile(surrogate_previous_path): surrogate: Surrogate = torch.load(surrogate_previous_path, weights_only=False) surrogate.mark_step_offset() surrogate_dataset = SurrogateDataset(surrogate_df, means=surrogate.means, stds=surrogate.stds) else: surrogate_dataset = SurrogateDataset(surrogate_df) surrogate = Surrogate(*surrogate_dataset.shape, surrogate_dataset.means, surrogate_dataset.stds) pre_train(surrogate, surrogate_dataset, config.surrogate.n_epoch_pre) logger.info("Surrogate Training") n_epochs_main = config.surrogate.n_epochs_main surrogate.train_model(surrogate_dataset, batch_size=1024, n_epochs=n_epochs_main // 2, lr=0.005) surrogate_loss = surrogate.train_model(surrogate_dataset, batch_size=1024, n_epochs=n_epochs_main, lr=0.0003) surrogate_lr = 0.001 * (1 if parameter_dict.iteration <= 50 else 0.5) while not surrogate.update_best_surrogate_loss(surrogate_loss): logger.info("Surrogate retraining") pre_train(surrogate, surrogate_dataset, config.surrogate.n_epoch_pre) surrogate.train_model( surrogate_dataset, batch_size=256, n_epochs=n_epochs_main // 5, lr=5 * surrogate_lr ) surrogate.train_model( surrogate_dataset, batch_size=1024, n_epochs=n_epochs_main // 2, lr=1 * surrogate_lr ) surrogate.train_model( surrogate_dataset, batch_size=1024, n_epochs=n_epochs_main // 2, lr=0.3 * surrogate_lr) surrogate_loss = surrogate.train_model( surrogate_dataset, batch_size=1024, n_epochs=n_epochs_main // 2, lr=0.1 * surrogate_lr, ) if task_logger is not None: task_logger.log_scalars("Surrogate Loss", surrogate.surrogate_loss[surrogate.step_offset:], surrogate.step_offset) torch.save(surrogate, surrogate_save_path) return surrogate, surrogate_dataset
[docs] def training_loop( reco_file_paths_dict: dict | str | os.PathLike, reconstruction_loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], constraints: None | Callable[[SimulationParameterDictionary], float | torch.Tensor] = None, wandb_logger: WandbLogger | None = None ): """Internal training of the Surrogate and Optimizer models Args: reco_file_paths_dict (dict | str | os.PathLike): Either the dict with all the file paths or a single filepath (str or os.PathLike) that we first have to read from JSON. reconstruction_loss_function (Callable): The user-defined loss function that provides the goodness of a given design. Has to take two Tensors (truth and predicted) and return a scalar Tensor used as the Optimizer loss. constraints (Callable, optional). Additional loss function to be applied on top of the regular loss function, for example to account for cost penalties. Default is None Returns: SimulationParameterDictionary: The updated values as proposed by the Optimizer model. Note: This function is integral to the correct training of the surrogate and optimizer models. The training itself consists of these steps: 1. Track all the file paths needed 2. Instantiate the Surrogate model if not done so, load it from .pt file if available from current iteration (if training was stopped), then train it. 3. Run the Optimizer 4. Save results """ if isinstance(reco_file_paths_dict, (str, os.PathLike)): with open(reco_file_paths_dict, "r") as file: reco_file_paths_dict = json.load(file) config = AIDOConfig.from_json(reco_file_paths_dict["config_path"]) results_dir = reco_file_paths_dict["results_dir"] output_df_path = reco_file_paths_dict["reco_output_df"] parameter_dict_input_path = reco_file_paths_dict["current_parameter_dict"] surrogate_previous_path = reco_file_paths_dict["surrogate_model_previous_path"] optimizer_previous_path = reco_file_paths_dict["optimizer_model_previous_path"] surrogate_save_path = reco_file_paths_dict["surrogate_model_save_path"] optimizer_save_path = reco_file_paths_dict["optimizer_model_save_path"] optimizer_loss_save_path = reco_file_paths_dict["optimizer_loss_save_path"] surrogate_loss_save_path = reco_file_paths_dict["surrogate_loss_save_path"] constraints_loss_save_path = reco_file_paths_dict["constraints_loss_save_path"] parameter_optimizer_savepath = os.path.join(results_dir, "models", "parameter_optimizer_df") # Surrogate parameter_dict = SimulationParameterDictionary.from_json(parameter_dict_input_path) surrogate_df = pd.read_parquet(output_df_path) with (wandb_logger.get_task_logger(task="surrogate") if wandb_logger else nullcontext()) as task_logger: surrogate, surrogate_dataset = train_or_load_surrogate( config, parameter_dict, surrogate_previous_path, surrogate_save_path, surrogate_df, task_logger=task_logger ) # Optimization optimizer = Optimizer(parameter_dict=parameter_dict) if os.path.isfile(optimizer_previous_path): checkpoint = torch.load(optimizer_previous_path) optimizer.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) updated_parameter_dict, is_optimal = optimizer.optimize( surrogate_model=surrogate, dataset=surrogate_dataset, batch_size=config.optimizer.batch_size, n_epochs=config.optimizer.n_epochs, reconstruction_loss=reconstruction_loss_function, additional_constraints=constraints, parameter_optimizer_savepath=parameter_optimizer_savepath, lr=config.optimizer.lr, wandb_logger=wandb_logger ) if not is_optimal: raise RuntimeError else: torch.save({"optimizer_state_dict": optimizer.optimizer.state_dict()}, optimizer_save_path) pd.DataFrame( np.array(surrogate.surrogate_loss), columns=["Surrogate Loss"] ).to_csv(surrogate_loss_save_path) pd.DataFrame( np.array(optimizer.optimizer_loss), columns=["Optimizer Loss"] ).to_csv(optimizer_loss_save_path) pd.DataFrame( np.array(optimizer.constraints_loss), columns=["Constraints Loss"] ).to_csv(constraints_loss_save_path) return updated_parameter_dict