126 lines
4.7 KiB
Python
126 lines
4.7 KiB
Python
import gymnasium as gym
|
|
import solarcarsim.physsim as sim
|
|
import jax
|
|
import jax.numpy as jnp
|
|
from jax import jit
|
|
from functools import partial
|
|
|
|
@partial(jit, static_argnames=["params"])
def forward(state, control, delta_time, wind, elevation, slope, params: sim.CarParams):
    """Advance the simulation by one step at the commanded speed.

    Args:
        state: Indexable array whose first three entries are read as
            ``[position, time, energy]``; the remaining entries are
            recomputed and not read.
        control: Speed command; multiplied by ``params.max_speed`` to obtain
            the velocity used for this step.
        delta_time: Duration of the step (presumably seconds -- confirm
            against the caller's ``timestep``).
        wind: Array indexed as ``wind[position, time]`` with rounded indices.
        elevation: Accepted for signature compatibility; not used here.
        slope: Array indexed as ``slope[position]`` with a rounded index.
        params: Car parameters. Declared static for ``jit``, so the object
            must be hashable.

    Returns:
        ``jnp.array([position, time, energy, dist_remaining, time_remaining])``
        for the state after the step.
    """
    pos = jnp.astype(jnp.round(state[0]), "int32")
    time = jnp.astype(jnp.round(state[1]), "int32")
    theta = slope[pos]

    velocity = control * params.max_speed

    # Sum up the forces acting on the car.
    windspeed = wind[pos, time]
    # NOTE(review): 1.184 is presumably the air density passed to the drag
    # model -- confirm against sim.drag_force.
    dragf = sim.drag_force(
        velocity + windspeed, params.frontal_area, params.drag_coeff, 1.184
    )
    rollf = sim.rolling_force(params.mass, theta, params.rolling_coeff)
    hillforce = sim.downslope_force(params.mass, theta)
    totalf = dragf + rollf + hillforce

    # With the sum of forces, determine the torque needed at the wheels and
    # from that the power drawn by the motor.
    tau = params.wheel_radius * totalf
    pdraw = sim.bldc_power_draw(tau, velocity, params.motor)

    # Energy balance after drawing that power for the whole step (joules).
    energy = state[2] - delta_time * pdraw

    # Only the along-track component of the motion advances the position;
    # the position is never allowed to drop below the start of the route.
    new_pos = jnp.maximum(state[0] + jnp.cos(theta) * velocity * delta_time, 0)
    new_time = state[1] + delta_time

    # Remaining distance is measured from the *new* position. The previous
    # expression added state[0] to the already-absolute new position, which
    # counted the distance travelled so far twice.
    dist_remaining = 10000.0 - new_pos
    time_remaining = 600 - new_time

    return jnp.array([new_pos, new_time, energy, dist_remaining, time_remaining])
|
|
|
|
|
|
def reward(state, prev_state):
    """Return the per-step reward: the position entry scaled by 1/8000.

    Args:
        state: Indexable state; only ``state[0]`` (position) is read.
        prev_state: Previous state; accepted but currently unused.

    Returns:
        ``state[0] / 8000`` added to an integer zero baseline.
    """
    progress = state[0] / 8000
    return 0 + progress
|
|
|
|
class SolarRaceV1(gym.Env):
    """Primitive hill-climbing race environment.

    The agent commands a speed fraction each step and is rewarded for
    progress along the route, with terminal adjustments for arrival time and
    energy use.
    """

    def __init__(self, car: sim.CarParams = sim.CarParams(), timestep: float = 1.0, seed=1234):
        """Build the simulated route and declare the gym spaces.

        Args:
            car: Car parameters forwarded to the simulation step.
            timestep: Step duration forwarded to the simulation step.
            seed: Seed for the JAX key used to generate the initial route.
        """
        self._reset_sim(jax.random.key(seed))
        self._car = car
        self._timestep = timestep
        self._simreward = reward
        self._simstep = forward

        spaces = gym.spaces
        self.observation_space = spaces.Dict(
            {
                "position": spaces.Box(-100, 10100.0, shape=(1,)),
                "time": spaces.Box(0, 1000.0),
                "energy": spaces.Box(-1.0e6, 0.0),
                "dist_remaining": spaces.Box(0.0, 10100.0),
                "time_remaining": spaces.Box(0.0, 600.0),
                # Look-ahead windows starting at the current position/time.
                "terrain": spaces.Box(-1.0, 1.0, shape=(100,)),  # slope
                "wind": spaces.Box(-10.0, 10.0, shape=(100, 100)),
            }
        )

        # Single action: fraction of the car's maximum speed.
        self.action_space = spaces.Box(0.0, 1.0, shape=(1,))

    # ---- simulator helpers -------------------------------------------------

    def _reset_sim(self, key):
        """Generate a new environment from ``key`` and rewind the state."""
        # Rows: position, time, energy, distance remaining, time remaining.
        # Each row is a length-1 vector, giving a (5, 1) state array.
        initial_rows = [[0], [0], [0], [10000.0], [600.0]]
        self._environment = sim.make_environment(key)
        self._state = jnp.array(initial_rows)

    def _vision_function(self):
        """Return the ``(slope_view, wind_view)`` look-ahead windows.

        ``slope_view`` is a 100-element slice of the slope array and
        ``wind_view`` a 100x100 slice of the wind array, both starting at the
        rounded current position (and time, for wind).
        """
        wind_field = self._environment[0]
        slope_profile = self._environment[2]

        pos = jnp.astype(jnp.round(self._state[0]), "int32")
        time = jnp.astype(jnp.round(self._state[1]), "int32")

        window_start = jnp.hstack([pos, time])
        wind_view = jax.lax.dynamic_slice(wind_field, window_start, (100, 100))
        slope_view = jax.lax.dynamic_slice(slope_profile, pos, (100,))
        return slope_view, wind_view

    def _get_obs(self):
        """Assemble the observation dict from the state and look-ahead views."""
        slope_view, wind_view = self._vision_function()
        state_fields = (
            "position",
            "time",
            "energy",
            "dist_remaining",
            "time_remaining",
        )
        obs = {name: self._state[row] for row, name in enumerate(state_fields)}
        obs["terrain"] = slope_view
        obs["wind"] = wind_view
        return obs

    # ---- gym API -----------------------------------------------------------

    def reset(self, *, seed=None, options=None):
        """Regenerate the environment and return ``(observation, info)``.

        A missing (or zero) ``seed`` maps to key 0 for the route generator.
        """
        route_seed = seed if seed else 0
        self._reset_sim(jax.random.key(route_seed))
        super().reset(seed=seed, options=options)
        observation = self._get_obs()
        return observation, {}

    def step(self, action):
        """Apply ``action`` for one step.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
        """
        wind, elevation, slope = self._environment

        previous = self._state
        self._state = self._simstep(
            previous, action, self._timestep, wind, elevation, slope, self._car
        )
        total = self._simreward(self._state, previous)[0]

        # Finish line: bonus, then penalise arriving earlier than t = 600
        # and add the (negative) net energy scaled down.
        terminated = bool(jnp.all(self._state[0] > 8000))
        if terminated:
            total += 500
            total -= 600 - self._state[1][0]
            total += 1e-6 * (self._state[2][0])

        # Ran out of time.
        truncated = bool(jnp.all(self._state[1] > 600))
        if truncated:
            total -= 500

        return self._get_obs(), total, terminated, truncated, {}
|
|
|
|
|