benchmarks package
Base Classes
Benchmark
- class autompc.benchmarks.Benchmark(name, system, task, data_gen_method)
Represents a benchmark for testing AutoMPC, including the system, the task, and a method for generating data.
- abstract dynamics(x, u)
Benchmark dynamics
- Parameters:
x (np array of size self.system.obs_dim) – Current observation
u (np array of size self.system.ctrl_dim) – Control input
- Returns:
xnew – New observation.
- Return type:
np array of size self.system.obs_dim
- abstract gen_trajs(seed, n_trajs, traj_len=None)
Generate trajectories.
- Parameters:
seed (int) – Seed for trajectory generation
n_trajs (int) – Number of trajectories to generate
traj_len (int) – Length of trajectories to generate. Default varies by benchmark.
- Returns:
Benchmark training set
- Return type:
List of Trajectory
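The two abstract methods above define the contract a concrete benchmark has to satisfy. The following is a minimal sketch of that contract, not AutoMPC code: `ToyDoubleIntegratorBenchmark`, its `dt` and `default_traj_len` values, and the use of plain lists of `(observation, control)` pairs in place of `Trajectory` objects are all assumptions made so the example runs without AutoMPC installed.

```python
import random


class ToyDoubleIntegratorBenchmark:
    """Hypothetical stand-in mirroring the Benchmark interface (not part of AutoMPC)."""

    obs_dim = 2            # observation: [position, velocity]
    ctrl_dim = 1           # control: [acceleration]
    dt = 0.1               # integration step (assumed value)
    default_traj_len = 50  # assumed per-benchmark default

    def dynamics(self, x, u):
        # One explicit Euler step; returns the new observation.
        pos, vel = x
        return [pos + self.dt * vel, vel + self.dt * u[0]]

    def gen_trajs(self, seed, n_trajs, traj_len=None):
        # A private, seeded generator makes the data set reproducible.
        rng = random.Random(seed)
        if traj_len is None:
            traj_len = self.default_traj_len
        trajs = []
        for _ in range(n_trajs):
            x = [rng.uniform(-1.0, 1.0), 0.0]
            traj = []
            for _ in range(traj_len):
                # Uniform random controls, by analogy with 'uniform_random'.
                u = [rng.uniform(-1.0, 1.0)]
                traj.append((x, u))
                x = self.dynamics(x, u)
            trajs.append(traj)
        return trajs
```

In this sketch, calling `gen_trajs` twice with the same `seed` returns identical data, and omitting `traj_len` falls back to the class default, which is the behavior the parameter descriptions above suggest.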
Available Benchmarks
CartpoleSwingupBenchmark
- class autompc.benchmarks.CartpoleSwingupBenchmark(data_gen_method='uniform_random')
This benchmark uses the cartpole system and is consistent with the experiments in the ICRA 2021 paper. The task is to move the pole from the down position to the up position. The performance metric returns 1 for every observation that is more than 0.2 away from the goal in either the angle or the angular velocity dimension, and 0 otherwise.
- visualize(fig, ax, traj, margin=5.0)
Visualize the cartpole trajectory.
- Parameters:
fig (matplotlib.figure.Figure) – Figure to generate visualization in.
ax (matplotlib.axes.Axes) – Axes to create visualization in.
traj (Trajectory) – Trajectory to visualize
margin (float) – Shift the viewing window by this amount when the cartpole reaches the edge of the screen
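The performance metric described for CartpoleSwingupBenchmark can be sketched in a few lines. This is an illustration of the rule as stated, not the library's implementation: the function name `swingup_cost`, the goal values of 0 for both angle and angular velocity, the index positions of those two dimensions, and the choice to sum the per-observation values over a trajectory are all assumptions.

```python
def swingup_cost(observations, goal_angle=0.0, goal_angvel=0.0,
                 threshold=0.2, angle_idx=0, angvel_idx=1):
    """Count observations that are not close to the goal.

    Each observation scores 1 if its angle OR its angular velocity is
    more than `threshold` away from the goal, and 0 otherwise.
    Index positions and goal values are assumed, not taken from AutoMPC.
    """
    cost = 0
    for obs in observations:
        off_angle = abs(obs[angle_idx] - goal_angle) > threshold
        off_angvel = abs(obs[angvel_idx] - goal_angvel) > threshold
        if off_angle or off_angvel:
            cost += 1
    return cost


# Four observations given as [angle, angular velocity].
example = [[3.14, 0.0], [0.1, 0.5], [0.1, -0.1], [0.0, 0.0]]
total = swingup_cost(example)  # the first two observations each score 1
```

Under this reading, a lower total means the pole spent more time near the upright goal.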
CartpoleSwingupV2Benchmark
- class autompc.benchmarks.CartpoleSwingupV2Benchmark(data_gen_method='uniform_random')
This benchmark uses the cartpole system and differs from CartpoleSwingupBenchmark in that the performance metric requires the cartpole to stay within the [-10, 10] range.
- visualize(fig, ax, traj, margin=5.0)
Visualize the cartpole trajectory.
- Parameters:
fig (matplotlib.figure.Figure) – Figure to generate visualization in.
ax (matplotlib.axes.Axes) – Axes to create visualization in.
traj (Trajectory) – Trajectory to visualize
margin (float) – Shift the viewing window by this amount when the cartpole reaches the edge of the screen
HalfcheetahBenchmark
- class autompc.benchmarks.halfcheetah.HalfcheetahBenchmark(data_gen_method='uniform_random')
This benchmark uses the OpenAI gym halfcheetah benchmark and is consistent with the experiments in the ICRA 2021 paper. The benchmark requires OpenAI gym and mujoco_py to be installed. The performance metric is \(200-R\), where \(R\) is the gym reward.
- visualize(traj, repeat)
Visualize the half-cheetah trajectory using Gym functions.
- Parameters:
traj (Trajectory) – Trajectory to visualize
repeat (int) – Number of times to repeat trajectory in visualization
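The HalfcheetahBenchmark performance metric \(200-R\) turns a reward that should be maximized into a cost that should be minimized. A minimal sketch of that conversion follows; the function name `halfcheetah_cost` is hypothetical, and the reward values are made-up inputs rather than measured results.

```python
def halfcheetah_cost(gym_reward, offset=200.0):
    """Convert a gym reward R into the cost 200 - R (lower is better)."""
    return offset - gym_reward


low_reward_cost = halfcheetah_cost(150.0)   # 200 - 150
high_reward_cost = halfcheetah_cost(250.0)  # 200 - 250
```

A higher reward gives a lower cost, and the cost becomes negative once the reward exceeds 200.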