# Source code for autompc.costs.sum_cost_factory

# Created by William Edwards (wre2@illinois.edu), 2021-01-24

# Standard library includes
from pdb import set_trace

# Internal library includes
from .cost_factory import CostFactory
from .sum_cost import SumCost
from ..utils.cs_utils import *
from . import QuadCost

# External library includes
import numpy as np
import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH
import ConfigSpace.conditions as CSC

class SumCostFactory(CostFactory):
    """
    A factory which produces sums of other cost terms. A SumCostFactory
    can be created by combining other cost factories with the `+` operator.

    Parameters
    ----------
    system
        System passed through to the base ``CostFactory`` constructor and
        later to every ``SumCostFactory``/``SumCost`` built from this one.
    factories : iterable
        Component factories whose produced costs are summed. The iterable
        is copied into a new list, so later changes to the caller's
        container do not affect this factory.
    """
    def __init__(self, system, factories):
        super().__init__(system)
        # Copy into a fresh list: accepts any iterable and isolates this
        # factory from later mutation of the caller's container.
        self._factories = list(factories)

    @property
    def factories(self):
        """Return a shallow copy of the list of component factories."""
        return self._factories[:]

    def get_configuration_space(self, *args, **kwargs):
        """
        Build the combined configuration space.

        Each component factory's space is added to a fresh
        ``CS.ConfigurationSpace`` via ``add_configuration_space`` under the
        label ``"_sum_{i}"``, where ``i`` is the factory's index. All
        arguments are forwarded to every component's
        ``get_configuration_space``.
        """
        cs = CS.ConfigurationSpace()
        for i, factory in enumerate(self._factories):
            _fact_cs = factory.get_configuration_space(*args, **kwargs)
            add_configuration_space(cs, "_sum_{}".format(i), _fact_cs)
        return cs

    def is_compatible(self, *args, **kwargs):
        """
        Return True only if every component factory reports compatibility.

        Evaluation stops at the first incompatible factory. An empty sum is
        compatible.
        """
        return all(factory.is_compatible(*args, **kwargs)
                   for factory in self._factories)

    def __call__(self, cfg, task, trajs):
        """
        Build each component cost from its slice of ``cfg`` and sum them.

        For the i-th factory, its default configuration is overwritten with
        the values that ``set_subspace_configuration`` extracts from ``cfg``
        under the ``"_sum_{i}"`` label, and the factory is called with it.
        The resulting costs are added onto an empty ``SumCost`` for
        ``self.system``.
        """
        costs = []
        for i, factory in enumerate(self._factories):
            # Note: called without the extra args that
            # get_configuration_space may have received.
            fact_cs = factory.get_configuration_space()
            fact_cfg = fact_cs.get_default_configuration()
            set_subspace_configuration(cfg, "_sum_{}".format(i), fact_cfg)
            costs.append(factory(fact_cfg, task, trajs))
        return sum(costs, SumCost(self.system, []))

    def __add__(self, other):
        """
        Return ``self + other`` as a new, flattened SumCostFactory.

        If ``other`` is itself a SumCostFactory, its components are spliced
        in rather than nested. Otherwise ``other`` (presumably a
        CostFactory) is appended as a single component.
        """
        if isinstance(other, SumCostFactory):
            return SumCostFactory(self.system,
                                  [*self._factories, *other.factories])
        return SumCostFactory(self.system, [*self._factories, other])

    def __radd__(self, other):
        """
        Return ``other + self`` as a new, flattened SumCostFactory.

        ``other`` is placed before this factory's components. A
        SumCostFactory ``other`` is spliced in rather than nested.
        """
        if isinstance(other, SumCostFactory):
            return SumCostFactory(self.system,
                                  [*other.factories, *self._factories])
        return SumCostFactory(self.system, [other, *self._factories])