# Created by William Edwards (wre2@illinois.edu), 2021-01-09
# Standard library includes
import sys
# External library includes
import numpy as np
import matplotlib.animation as animation
# Project includes
from .benchmark import Benchmark
from ..utils.data_generation import *
from .. import System
from ..tasks import Task
from ..costs import ThresholdCost
def cartpole_simp_dynamics(y, u, g=9.8, m=1, L=1, b=0.1):
    """
    Continuous-time dynamics of a simplified cart-pole system.

    The control ``u`` acts directly as the cart's acceleration; the pole is
    coupled to it through the ``u * cos(theta) / L`` term. With this model
    ``theta = 0`` is the unstable equilibrium (the gravity term
    ``g * sin(theta) / L`` pushes small angles further away from 0).

    Parameters
    ----------
    y : array_like of length 4
        State ``(theta, omega, x, dx)``: pole angle, pole angular velocity,
        cart position and cart velocity.
    u : float
        Scalar control input (the cart acceleration).
    g : float
        Gravitational acceleration.
    m : float
        Mass parameter; in this model it only enters the damping term.
    L : float
        Pole length.
    b : float
        Coefficient of the viscous (linear in ``omega``) damping term.

    Returns
    -------
    numpy.ndarray of shape (4,)
        Time derivative of the state: ``(omega, domega, dx, u)``.
    """
    theta, omega, x, dx = y
    return np.array([omega,
                     g * np.sin(theta) / L - b * omega / (m * L**2)
                     + u * np.cos(theta) / L,
                     dx,
                     u])


def dt_cartpole_dynamics(y, u, dt, g=9.8, m=1, L=1, b=1.0):
    """
    Advance the cart-pole state by one explicit (forward) Euler step of
    :func:`cartpole_simp_dynamics`.

    Parameters
    ----------
    y : array_like of length 4
        Current state ``(theta, omega, x, dx)``. It is NOT modified.
    u : sequence of length >= 1
        Control vector; only ``u[0]`` is used.
    dt : float
        Integration time step.
    g, m, L, b : float
        Physical parameters forwarded to :func:`cartpole_simp_dynamics`.
        Note that the damping default here (``b=1.0``) differs from the
        continuous-time function's default (``b=0.1``).

    Returns
    -------
    numpy.ndarray of shape (4,)
        The next state, as a newly allocated float array.
    """
    # Build a new float array rather than doing ``y += ...``. The in-place
    # form mutated the caller's state array. It also raised a casting error
    # for integer arrays. For a plain list it *extended* the list instead of
    # adding element-wise.
    y = np.asarray(y, dtype=float)
    return y + dt * cartpole_simp_dynamics(y, u[0], g, m, L, b)
class CartpoleSwingupBenchmark(Benchmark):
    """
    This benchmark uses the cartpole system and is consistent with the
    experiments in the ICRA 2021 paper. The task is to move the pole
    from the down position to the up position. The performance metric
    returns 1 for every observation which is more than 0.2 away from the goal
    in either the angle or angular velocity dimensions, and 0 otherwise.

    The state is ``(theta, omega, x, dx)`` with a single control ``u``.
    The task starts from ``theta = 3.1`` and runs for 200 steps of
    ``dt = 0.05`` with ``u`` bounded to ``[-20, 20]``.
    """

    def __init__(self, data_gen_method="uniform_random"):
        """
        Parameters
        ----------
        data_gen_method : str
            Name of the trajectory-generation strategy used by
            :meth:`gen_trajs`; one of :meth:`data_gen_methods`.
        """
        name = "cartpole_swingup"
        # Use the explicitly imported ``System``. The previous
        # ``ampc.System`` referenced a name never imported in this module;
        # it could only arrive through the wildcard import above, which is
        # fragile (NameError if that module stops binding ``ampc``).
        system = System(["theta", "omega", "x", "dx"], ["u"])
        system.dt = 0.05
        # NOTE(review): confirm how ThresholdCost interprets
        # obs_range=(0, 3). The class docstring mentions only the angle and
        # angular-velocity dimensions.
        cost = ThresholdCost(system, goal=np.zeros(4), threshold=0.2,
                             obs_range=(0, 3))
        task = Task(system)
        task.set_cost(cost)
        task.set_ctrl_bound("u", -20.0, 20.0)
        # Start with the pole (approximately) hanging down; theta = pi is
        # drawn straight down by visualize().
        init_obs = np.array([3.1, 0.0, 0.0, 0.0])
        task.set_init_obs(init_obs)
        task.set_num_steps(200)
        super().__init__(name, system, task, data_gen_method)

    def dynamics(self, x, u):
        """
        Advance state ``x`` by one step of ``self.system.dt`` under control
        ``u`` using the forward-Euler cart-pole model with damping
        ``b=1.0``.
        """
        return dt_cartpole_dynamics(x, u, self.system.dt, g=9.8, m=1, L=1, b=1.0)

    def visualize(self, fig, ax, traj, margin=5.0):
        """
        Visualize the cartpole trajectory.

        Parameters
        ----------
        fig : matplotlib.figure.Figure
            Figure to generate visualization in.
        ax : matplotlib.axes.Axes
            Axes to create visualization in.
        traj : Trajectory
            Trajectory to visualize
        margin : float
            Shift the viewing window by this amount when the
            cartpole reaches the edge of the screen

        Returns
        -------
        matplotlib.animation.FuncAnimation
            Animation that replays the trajectory six times, holding the
            final state for 50 extra frames at the end of each replay.
        """
        # Ground line, long enough to stay visible as the window scrolls.
        ax.plot([-10000, 10000.0], [0.0, 0.0], "k-", lw=1)
        ax.set_xlim([-10.0, 10.0])
        ax.set_ylim([-2.0, 2.0])
        ax.set_aspect("equal")
        dt = self.system.dt
        line, = ax.plot([0.0, 0.0], [0.0, -1.0], 'o-', lw=2)
        time_text = ax.text(0.02, 0.85, '', transform=ax.transAxes)
        ctrl_text = ax.text(0.7, 0.85, '', transform=ax.transAxes)

        def init():
            # Reset every artist that animate() updates.
            line.set_data([0.0, 0.0], [0.0, -1.0])
            time_text.set_text('')
            ctrl_text.set_text('')
            return line, time_text, ctrl_text

        # Each replay lasts traj.size frames plus 50 frames holding the
        # final state.
        nframes = traj.size + 50

        def animate(i):
            i %= nframes
            i = min(i, traj.size - 1)
            if i == 0:
                # Restore the initial viewport at the start of each replay.
                ax.set_xlim([-10.0, 10.0])
            x = traj[i, "x"]
            theta = traj[i, "theta"]
            # Pole from the cart at (x, 0) to the tip at
            # (x - sin(theta), cos(theta)); theta = 0 points straight up.
            line.set_data([x, x + np.sin(theta + np.pi)],
                          [0, -np.cos(theta + np.pi)])
            time_text.set_text('t={:.2f}'.format(dt * i))
            ctrl_text.set_text("u={:.2f}".format(traj[i, "u"]))
            # Scroll the 20-unit-wide window when the cart leaves it.
            xmin, xmax = ax.get_xlim()
            if x < xmin:
                ax.set_xlim([x - margin, x + 20.0 - margin])
            if x > xmax:
                ax.set_xlim([x - 20.0 + margin, x + margin])
            return line, time_text, ctrl_text

        anim = animation.FuncAnimation(fig, animate, frames=6 * nframes,
                                       interval=dt * 1000.0,
                                       blit=False, init_func=init)
        return anim

    def _gen_trajs(self, n_trajs, traj_len, rng):
        """
        Generate ``n_trajs`` trajectories of length ``traj_len`` using the
        data-generation method chosen at construction time.

        Raises
        ------
        ValueError
            If the configured method is not one of :meth:`data_gen_methods`.
        """
        # Initial-state bounds handed to the generators: only theta varies
        # (between -1 and 1); omega, x and dx start at 0.
        init_min = np.array([-1.0, 0.0, 0.0, 0.0])
        init_max = np.array([1.0, 0.0, 0.0, 0.0])
        if self._data_gen_method == "uniform_random":
            return uniform_random_generate(self.system, self.task, self.dynamics, rng,
                                           init_min=init_min, init_max=init_max,
                                           traj_len=traj_len, n_trajs=n_trajs)
        elif self._data_gen_method == "periodic_control":
            return periodic_control_generate(self.system, self.task, self.dynamics, rng,
                                             init_min=init_min, init_max=init_max, U_1=np.ones(1),
                                             traj_len=traj_len, n_trajs=n_trajs)
        elif self._data_gen_method == "multisine":
            return multisine_generate(self.system, self.task, self.dynamics, rng,
                                      init_min=init_min, init_max=init_max, n_freqs=20,
                                      traj_len=traj_len, n_trajs=n_trajs)
        elif self._data_gen_method == "random_walk":
            return random_walk_generate(self.system, self.task, self.dynamics, rng,
                                        init_min=init_min, init_max=init_max, walk_rate=1.0,
                                        traj_len=traj_len, n_trajs=n_trajs)
        # Previously an unknown method fell through and silently returned
        # None, deferring the failure to some unrelated later line.
        raise ValueError(
            "Unknown data_gen_method {!r}; expected one of {}".format(
                self._data_gen_method, self.data_gen_methods()))

    def gen_trajs(self, seed, n_trajs, traj_len=200):
        """
        Generate ``n_trajs`` trajectories of length ``traj_len`` from a
        NumPy generator seeded with ``seed``, so results are reproducible
        for a fixed seed.
        """
        rng = np.random.default_rng(seed)
        return self._gen_trajs(n_trajs, traj_len, rng)

    @staticmethod
    def data_gen_methods():
        """Return the names of the supported data-generation methods."""
        return ["uniform_random", "periodic_control", "multisine", "random_walk"]