Source code for aido.config

import json
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Any, Dict, Self, Tuple

from aido.logger import logger


@dataclass
class OptimizerConfig:
    """Hyperparameters of the optimizer.

    Addressed through :class:`AIDOConfig` with the ``optimizer.`` prefix,
    for example ``"optimizer.lr"``.
    """

    # Learning rate; documented as required to be > 0.
    lr: float = 0.02
    batch_size: int = 512
    n_epochs: int = 40
@dataclass
class SurrogateConfig:
    """Training settings of the surrogate model.

    Addressed through :class:`AIDOConfig` with the ``surrogate.`` prefix,
    for example ``"surrogate.n_epochs_main"``.
    """

    n_epoch_pre: int = 24
    n_epochs_main: int = 40
@dataclass
class SimulationConfig:
    """Settings controlling how new values are sampled for the Simulation Task.

    Addressed through :class:`AIDOConfig` with the ``simulation.`` prefix,
    for example ``"simulation.sigma"``.
    """

    # Documented as required to be > 0.
    generate_scaling: float = 1.2
    # Documented as required to be > 0.
    sigma: float = 1.5
    # Documented values: "flat" or "scale".
    sigma_mode: str = "flat"
[docs] @dataclass class SchedulerConfig: training_num_retries: int = 20 training_delay_between_retries: int | float = 60
@dataclass
class AIDOConfig:
    """Config dataclass storing the internal parameters of AIDO.

    Holds the hyperparameters of the different models and the settings that
    control how new values are sampled for the Simulation Task. This class is
    serializable to and from json. In order to be picked up by the AIDO
    scheduler, the json file with updated values must be placed in the AIDO
    root directory.

    Default fields:
        - Optimizer:
            - optimizer.lr: float = 0.02 (>0)
            - optimizer.batch_size: int = 512
            - optimizer.n_epochs: int = 40
        - Surrogate:
            - surrogate.n_epoch_pre: int = 24
            - surrogate.n_epochs_main: int = 40
        - Simulation:
            - simulation.generate_scaling: float = 1.2 (>0)
            - simulation.sigma: float = 1.5 (>0)
            - simulation.sigma_mode: str = "flat" (or "scale")
        - Scheduler:
            - scheduler.training_num_retries: int = 20
            - scheduler.training_delay_between_retries: int | float = 60 (in seconds)
    """

    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    surrogate: SurrogateConfig = field(default_factory=SurrogateConfig)
    simulation: SimulationConfig = field(default_factory=SimulationConfig)
    scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)

    @classmethod
    def from_json(cls, file_path: str) -> Self:
        """Create a new instance from a json file.

        Sections or keys missing from the file fall back to their default
        values, and unknown keys inside a section are ignored with a warning.
        Older or partially written config files therefore still load instead
        of crashing.

        Args:
            file_path (str): The input file path

        Returns:
            AIDOConfig: New instance of this class

        Notes:
            If the file could not be found, a warning is logged instead of
            raising a FileNotFoundError, and the default configuration is
            returned. This keeps the code running despite an invalid config.

            If any Exception is raised while reading the json file (e.g. it is
            invalid or wrongly formatted), or the file does not contain a json
            object, an error including the message is logged and the default
            configuration is returned.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as file:
                data = json.load(file)
        except FileNotFoundError:
            logger.warning(f"Config file {file_path} not found. Using default configuration.")
            return cls()
        except Exception as e:
            logger.error(f"Error reading config file {file_path}. Using default configuration.\nError: {e}")
            return cls()

        if not isinstance(data, dict):
            # E.g. a file containing a json list: looking up sections by name would crash.
            logger.error(
                f"Error reading config file {file_path}. Using default configuration.\n"
                f"Error: expected a json object at the top level, got {type(data).__name__}"
            )
            return cls()

        return cls(
            optimizer=cls._section_from_dict(OptimizerConfig, data, "optimizer"),
            surrogate=cls._section_from_dict(SurrogateConfig, data, "surrogate"),
            simulation=cls._section_from_dict(SimulationConfig, data, "simulation"),
            scheduler=cls._section_from_dict(SchedulerConfig, data, "scheduler"),
        )

    @staticmethod
    def _section_from_dict(section_cls: type, data: Dict[str, Any], name: str) -> Any:
        """Build the sub-config ``section_cls`` from ``data[name]``, falling back to defaults.

        Args:
            section_cls (type): The sub-config dataclass to instantiate.
            data (dict): The parsed top-level json object.
            name (str): The section key inside ``data``.

        Returns:
            An instance of ``section_cls``. Fields absent from the section keep
            their dataclass defaults.
        """
        section_data = data.get(name)
        if section_data is None:
            logger.warning(f"Config section '{name}' not found. Using its default values.")
            return section_cls()
        if not isinstance(section_data, dict):
            logger.error(f"Config section '{name}' must be a json object. Using its default values.")
            return section_cls()

        known = {fld.name for fld in fields(section_cls)}
        unknown = sorted(set(section_data) - known)
        if unknown:
            # Passing these through would make the dataclass constructor raise TypeError.
            logger.warning(f"Ignoring unknown keys {unknown} in config section '{name}'.")
        return section_cls(**{key: value for key, value in section_data.items() if key in known})

    def to_json(self, file_path: str) -> None:
        """Write the current values to a json file.

        Args:
            file_path (str): The output file path
        """
        with open(file_path, "w", encoding="utf-8") as file:
            json.dump(self.as_dict(), file, indent=4)

    def get_key(self, key: str) -> Tuple[Any, str]:
        """Helper method to find the sub-config that corresponds to a dot-separated name.

        Args:
            key (str): A dot-separated name. For example `"optimizer.lr"`.

        Returns:
            tuple: The object holding the attribute (for a two-part key, one of
            the sub-config instances), together with the attribute name.

        Raises:
            AttributeError: If an intermediate part of ``key`` does not exist.

        Example:
            >>> config = AIDOConfig()
            >>> config.get_key("optimizer.lr")
            (OptimizerConfig(lr=0.02, batch_size=512, n_epochs=40), 'lr')

            The first entry is the :class:`OptimizerConfig` instance held by
            ``config``, and the second entry is the attribute name on it.

            >>> getattr(*config.get_key("optimizer.lr"))
            0.02
        """
        keys = key.split(".")
        current_config_subclass = self
        for k in keys[:-1]:
            current_config_subclass = getattr(current_config_subclass, k)
        return current_config_subclass, keys[-1]

    def set_value(self, key: str, value: Any) -> None:
        """Change the value of a field by providing a dot-separated key and value.

        If ``value`` is a dict and ``key`` names a sub-config (e.g. ``"optimizer"``),
        the dict's entries are applied to that sub-config field by field. The
        sub-config is not replaced by a plain dict, which means
        ``config.from_dict(config.as_dict())`` is a valid round trip.

        Args:
            key (str): dot-separated name, for example `"optimizer.lr"` will adjust
                the attribute `"lr"` of the sub-class :class:`OptimizerConfig`.
            value (Any): The updated value

        Raises:
            AttributeError: If ``key`` does not name an existing config field.
                Without this check a typo would silently add an attribute that is
                never used and never written by :meth:`to_json`.
        """
        target, attribute = self.get_key(key)
        if not (is_dataclass(target) and attribute in {fld.name for fld in fields(target)}):
            raise AttributeError(f"Unknown config key '{key}'")

        current = getattr(target, attribute)
        if isinstance(value, dict) and is_dataclass(current):
            # Nested form, e.g. {"lr": 0.01} for key "optimizer": update field by field.
            for sub_key, sub_value in value.items():
                self.set_value(f"{key}.{sub_key}", sub_value)
            return
        setattr(target, attribute, value)

    def get_value(self, key: str) -> Any:
        """Get the value from one of the sub-configs.

        Args:
            key (str): A dot-separated name

        Returns:
            Any: The attribute value corresponding to the key
        """
        return getattr(*self.get_key(key))

    def __getitem__(self, key: str) -> Any:
        """See :meth:`AIDOConfig.get_value`"""
        return self.get_value(key)

    def from_dict(self, new_dict: dict) -> None:
        """Update values of this instance with values from the provided dict.

        Args:
            new_dict (dict): A dictionary with the new values. Keys use the
                dot-separated format of :meth:`AIDOConfig.set_value`
                (e.g. ``{"optimizer.lr": 0.01}``). Nested dicts as returned by
                :meth:`AIDOConfig.as_dict` are accepted as well.

        Raises:
            AttributeError: If a key does not name an existing config field.
                Entries processed before the failing key remain applied.
        """
        for key, value in new_dict.items():
            self.set_value(key, value)

    def as_dict(self) -> Dict[str, Any]:
        """Return all values from this Config class as a dict.

        Returns:
            dict: Nested dictionary with the sub-configs also being dicts
        """
        return asdict(self)
if __name__ == "__main__":
    # Script entry point: overwrite "config.json" (relative to the current
    # working directory) with the default configuration values.
    AIDOConfig().to_json("config.json")