# Created by William Edwards (wre2@illinois.edu), 2021-01-09
# Standard library includes
import sys, os
import pickle
# External library includes
import numpy as np
import matplotlib.animation as animation
# Project includes
from .benchmark import Benchmark
from ..utils.data_generation import *
from .. import System
from ..tasks import Task
from ..costs import BoxThresholdCost,ThresholdCost
def cartpole_simp_dynamics(y, u, g = 9.8, m = 1, L = 1, b = 0.1):
"""
Parameters
----------
y : states
u : control
Returns
-------
A list describing the dynamics of the cart cart pole
"""
theta, omega, x, dx = y
return np.array([omega,
g * np.sin(theta)/L - b * omega / (m*L**2) + u * np.cos(theta)/L,
dx,
u])
def dt_cartpole_dynamics(y,u,dt,g=9.8,m=1,L=1,b=1.0):
y += dt * cartpole_simp_dynamics(y,u[0],g,m,L,b)
return y
[docs]class CartpoleSwingupV2Benchmark(Benchmark):
"""
This benchmark uses the cartpole system and differs from CartpoleSwingupBenchmark
in that the performance metric requires the cartpole to stay within the [-10, 10]
range.
"""
def __init__(self, data_gen_method="uniform_random"):
name = "cartpole_swingup"
system = ampc.System(["theta", "omega", "x", "dx"], ["u"])
system.dt = 0.05
limits = np.array([[-0.2, 0.2], [-0.2, 0.2], [-10.0, 10.0], [-np.inf, np.inf]])
cost = BoxThresholdCost(system, limits, goal=np.zeros(4))
task = Task(system)
task.set_cost(cost)
task.set_ctrl_bound("u", -20.0, 20.0)
init_obs = np.array([3.1, 0.0, 0.0, 0.0])
task.set_init_obs(init_obs)
task.set_num_steps(200)
super().__init__(name, system, task, data_gen_method)
def dynamics(self, x, u):
return dt_cartpole_dynamics(x,u,self.system.dt,g=0.8,m=1,L=1,b=1.0)
[docs] def visualize(self, fig, ax, traj, margin=5.0):
"""
Visualize the cartpole trajectory.
Parameters
----------
fig : matplotlib.figure.Figure
Figure to generate visualization in.
ax : matplotlib.axes.Axes
Axes to create visualization in.
traj : Trajectory
Trajectory to visualize
margin : float
Shift the viewing window by this amount when the
cartpole reaches the edge of the screen
"""
ax.plot([-10000, 10000.0], [0.0, 0.0], "k-", lw=1)
ax.set_xlim([-10.0, 10.0])
ax.set_ylim([-2.0, 2.0])
ax.set_aspect("equal")
dt = self.system.dt
line, = ax.plot([0.0, 0.0], [0.0, -1.0], 'o-', lw=2)
time_text = ax.text(0.02, 0.85, '', transform=ax.transAxes)
ctrl_text = ax.text(0.7, 0.85, '', transform=ax.transAxes)
def init():
line.set_data([0.0, 0.0], [0.0, -1.0])
time_text.set_text('')
return line, time_text
nframes = traj.size + 50
def animate(i):
i %= nframes
i = min(i, traj.size-1)
if i == 0:
ax.set_xlim([-10.0, 10.0])
#i = min(i, ts.shape[0])
line.set_data([traj[i,"x"], traj[i,"x"]+np.sin(traj[i,"theta"]+np.pi)],
[0, -np.cos(traj[i,"theta"] + np.pi)])
time_text.set_text('t={:.2f}'.format(dt*i))
ctrl_text.set_text("u={:.2f}".format(traj[i,"u"]))
xmin, xmax = ax.get_xlim()
if traj[i, "x"] < xmin:
ax.set_xlim([traj[i,"x"] - margin, traj[i,"x"] + 20.0 - margin])
if traj[i, "x"] > xmax:
ax.set_xlim([traj[i,"x"] - 20.0 + margin, traj[i,"x"] + margin])
return line, time_text
anim = animation.FuncAnimation(fig, animate, frames=6*nframes, interval=dt*1000.0,
blit=False, init_func=init)
return anim
def _gen_trajs(self, n_trajs, traj_len, rng):
init_min = np.array([-1.0, 0.0, 0.0, 0.0])
init_max = np.array([1.0, 0.0, 0.0, 0.0])
if self._data_gen_method == "uniform_random":
return uniform_random_generate(self.system, self.task, self.dynamics, rng,
init_min=init_min, init_max=init_max,
traj_len=traj_len, n_trajs=n_trajs)
elif self._data_gen_method == "periodic_control":
return periodic_control_generate(self.system, self.task, self.dynamics, rng,
init_min=init_min, init_max=init_max, U_1=np.ones(1),
traj_len=traj_len, n_trajs=n_trajs)
elif self._data_gen_method == "multisine":
return multisine_generate(self.system, self.task, self.dynamics, rng,
init_min=init_min, init_max=init_max, n_freqs=20,
traj_len=traj_len, n_trajs=n_trajs)
elif self._data_gen_method == "random_walk":
return random_walk_generate(self.system, self.task, self.dynamics, rng,
init_min=init_min, init_max=init_max, walk_rate=1.0,
traj_len=traj_len, n_trajs=n_trajs)
def gen_trajs(self, seed, n_trajs, traj_len=200):
rng = np.random.default_rng(seed)
return self._gen_trajs(n_trajs, traj_len, rng)
def get_cached_tune_result(self):
dirname = os.path.dirname(__file__)
pklname = os.path.join(dirname,
"../../assets/cached_tunes/cartpole_tune_result.pkl")
with open(pklname, "rb") as f:
tune_result = pickle.load(f)
return tune_result
@staticmethod
def data_gen_methods():
return ["uniform_random", "periodic_control", "multisine", "random_walk"]