Source code for autompc.costs.sum_cost
# Created by William Edwards (wre2@illinois.edu)
from collections.abc import Iterable
import numpy as np
from .cost import Cost
class SumCost(Cost):
    """
    A cost which is the sum of other cost terms.

    Usually created by combining other Cost objects with the ``+`` operator
    rather than by calling the constructor directly.
    """

    def __init__(self, system, costs):
        """
        A cost which is the sum of other cost terms. It can be created by
        combining other Cost objects with the `+` operator.

        Parameters
        ----------
        system : System
            System for the cost object.
        costs : List of Costs
            Cost objects to be summed.
        """
        super().__init__(system)
        # Store a private copy so later mutation of the caller's list cannot
        # silently change the terms of this cost.
        self._costs = list(costs)

    @property
    def costs(self):
        """A shallow copy of the list of summed cost terms."""
        return self._costs[:]

    def get_cost_matrices(self):
        """
        Return the summed quadratic cost matrices.

        Returns
        -------
        Q, R, F : numpy.ndarray
            Element-wise sums of the ``Q`` (obs_dim x obs_dim),
            ``R`` (ctrl_dim x ctrl_dim) and ``F`` (obs_dim x obs_dim)
            matrices reported by every term.

        Raises
        ------
        NotImplementedError
            If the sum is not quadratic (see `is_quad`).
        """
        if not self.is_quad:
            raise NotImplementedError
        obs_dim = self.system.obs_dim
        ctrl_dim = self.system.ctrl_dim
        Q = np.zeros((obs_dim, obs_dim))
        R = np.zeros((ctrl_dim, ctrl_dim))
        F = np.zeros((obs_dim, obs_dim))
        for cost in self._costs:
            # Terms return their matrices in (Q, R, F) order.
            Q_, R_, F_ = cost.get_cost_matrices()
            Q += Q_
            R += R_
            F += F_
        return Q, R, F

    def get_goal(self):
        """
        Return the goal shared by all terms.

        Returns
        -------
        goal or None
            The goal of the first term when `has_goal` is true (all terms
            then share an equal goal), otherwise ``None``.
        """
        if self.has_goal:
            # All terms share an equal goal here, so the first term's goal
            # is representative. Return the goal, not the Cost object.
            return self._costs[0].get_goal()
        return None

    def _goals_match(self):
        """Return True if every term's goal equals the first term's goal."""
        goal = self._costs[0].get_goal()
        # np.array_equal returns False on shape mismatch instead of raising,
        # unlike `(a == b).all()`.
        return all(np.array_equal(goal, cost.get_goal())
                   for cost in self._costs[1:])

    def _sum_results(self, arg, attr):
        """
        Call method ``attr`` on every term with ``arg`` and sum the results.

        If the first term's result is iterable, results are summed
        position-by-position (via ``zip``) and returned as a list. This is
        presumably the tuple-returning ``*_diff``/``*_hess`` methods; confirm
        against the Cost base class. Otherwise the results are summed
        directly.
        """
        results = [getattr(cost, attr)(arg) for cost in self._costs]
        if isinstance(results[0], Iterable):
            return [sum(vals) for vals in zip(*results)]
        return sum(results)

    def eval_obs_cost(self, obs):
        """Sum of every term's ``eval_obs_cost(obs)``."""
        return self._sum_results(obs, "eval_obs_cost")

    def eval_obs_cost_diff(self, obs):
        """Position-wise sum of every term's ``eval_obs_cost_diff(obs)``."""
        return self._sum_results(obs, "eval_obs_cost_diff")

    def eval_obs_cost_hess(self, obs):
        """Position-wise sum of every term's ``eval_obs_cost_hess(obs)``."""
        return self._sum_results(obs, "eval_obs_cost_hess")

    def eval_ctrl_cost(self, ctrl):
        """Sum of every term's ``eval_ctrl_cost(ctrl)``."""
        return self._sum_results(ctrl, "eval_ctrl_cost")

    def eval_ctrl_cost_diff(self, ctrl):
        """Position-wise sum of every term's ``eval_ctrl_cost_diff(ctrl)``."""
        return self._sum_results(ctrl, "eval_ctrl_cost_diff")

    def eval_ctrl_cost_hess(self, ctrl):
        """Position-wise sum of every term's ``eval_ctrl_cost_hess(ctrl)``."""
        return self._sum_results(ctrl, "eval_ctrl_cost_hess")

    def eval_term_obs_cost(self, obs):
        """Sum of every term's ``eval_term_obs_cost(obs)``."""
        return self._sum_results(obs, "eval_term_obs_cost")

    def eval_term_obs_cost_diff(self, obs):
        """Position-wise sum of every term's ``eval_term_obs_cost_diff(obs)``."""
        return self._sum_results(obs, "eval_term_obs_cost_diff")

    def eval_term_obs_cost_hess(self, obs):
        """Position-wise sum of every term's ``eval_term_obs_cost_hess(obs)``."""
        return self._sum_results(obs, "eval_term_obs_cost_hess")

    @property
    def is_quad(self):
        """True if every term is quadratic and all terms share the same goal."""
        if not all(cost.is_quad for cost in self._costs):
            return False
        return self._goals_match()

    @property
    def is_convex(self):
        """True if every term is convex."""
        return all(cost.is_convex for cost in self._costs)

    @property
    def is_diff(self):
        """True if every term is differentiable."""
        return all(cost.is_diff for cost in self._costs)

    @property
    def is_twice_diff(self):
        """True if every term is twice differentiable."""
        # Must query is_twice_diff; being once-differentiable is not enough.
        return all(cost.is_twice_diff for cost in self._costs)

    @property
    def has_goal(self):
        """True if every term has a goal and all those goals are equal."""
        if not all(cost.has_goal for cost in self._costs):
            return False
        return self._goals_match()

    def __add__(self, other):
        """Return a new SumCost with ``other`` appended, flattening SumCosts."""
        if isinstance(other, SumCost):
            return SumCost(self.system, [*self._costs, *other.costs])
        return SumCost(self.system, [*self._costs, other])

    def __radd__(self, other):
        """Return a new SumCost with ``other`` prepended, flattening SumCosts."""
        if isinstance(other, SumCost):
            return SumCost(self.system, [*other.costs, *self._costs])
        return SumCost(self.system, [other, *self._costs])