# solarcarsim/src/solarcarsim/simv1.py
# 2024-12-20 16:58:51 -06:00
#
# 126 lines
# 4.7 KiB
# Python

import gymnasium as gym
import solarcarsim.physsim as sim
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
@partial(jit, static_argnames=["params"])
def forward(state, control, delta_time, wind, elevation, slope, params: sim.CarParams):
    """Advance the car state by one time step.

    Args:
        state: State laid out along its first axis as
            ``[position, time, energy, dist_remaining, time_remaining]``
            (the environment stores it as a ``(5, 1)`` column).
        control: Throttle command, multiplied by ``params.max_speed`` to get
            the commanded velocity.
        delta_time: Length of the step; added to the time entry.
        wind: Wind field indexed as ``wind[position, time]`` on the rounded
            integer position/time.
        elevation: Not used by this function.
        slope: Slope indexed by the rounded integer position and fed to
            ``jnp.cos`` (presumably an angle in radians — confirm against
            ``sim.make_environment``).
        params: Car parameters. Marked static for ``jit``, so the value must
            be hashable.

    Returns:
        The next state, in the same layout as ``state``.
    """
    # Grid lookups use the nearest integer cell.
    pos = jnp.astype(jnp.round(state[0]), "int32")
    time = jnp.astype(jnp.round(state[1]), "int32")
    theta = slope[pos]
    velocity = control * params.max_speed
    # sum up the forces acting on the car
    windspeed = wind[pos, time]
    # NOTE(review): 1.184 is presumably air density (kg/m^3 near 25 C) — confirm
    # against sim.drag_force's signature.
    dragf = sim.drag_force(
        velocity + windspeed, params.frontal_area, params.drag_coeff, 1.184
    )
    rollf = sim.rolling_force(params.mass, theta, params.rolling_coeff)
    hillforce = sim.downslope_force(params.mass, theta)
    totalf = dragf + rollf + hillforce
    # with the sum of forces, determine the needed torque at the wheels, and then power
    tau = params.wheel_radius * totalf
    pdraw = sim.bldc_power_draw(tau, velocity, params.motor)
    # Energy spent over this step is subtracted from the running total.
    energy = state[2] - delta_time * pdraw  # joules
    new_time = state[1] + delta_time
    # Move by cos(slope) * velocity * dt, never going behind the start line.
    new_pos = jnp.maximum(state[0] + jnp.cos(theta) * velocity * delta_time, 0)
    # new_pos already includes state[0]; measure what is left from it alone.
    dist_remaining = 10000.0 - new_pos
    time_remaining = 600 - new_time
    return jnp.array([new_pos, new_time, energy, dist_remaining, time_remaining])
def reward(state, prev_state):
    """Return a shaping reward equal to the position divided by 8000.

    Args:
        state: Current simulator state; only ``state[0]`` (position) is read.
        prev_state: Previous simulator state; currently unused.

    Returns:
        ``state[0] / 8000`` added to an integer zero.
    """
    progress_fraction = state[0] / 8000
    # Same integer starting point the original accumulator used.
    return 0 + progress_fraction
class SolarRaceV1(gym.Env):
    """A primitive hill climber. Aims to solve the given route optimizing
    for energy usage and on-time arrival.

    The simulator state is a ``(5, 1)`` array laid out as
    ``[position, time, energy, dist_remaining, time_remaining]`` and is
    advanced one ``timestep`` per ``step`` call by ``forward``.
    """

    # these are some simulator helpers
    def _reset_sim(self, key):
        """Regenerate the environment from ``key`` and rewind the car state."""
        # Unpacked as (wind, elevation, slope) in step().
        self._environment = sim.make_environment(key)
        # position 0, time 0, energy 0, distance remaining 10000, time remaining 600
        self._state = jnp.array([[0], [0], [0], [10000.0], [600.0]])

    def _vision_function(self):
        """Return the look-ahead windows ``(slope_view, wind_view)``.

        ``slope_view`` is the 100 slope samples starting at the rounded
        position; ``wind_view`` is the 100x100 wind patch whose corner is the
        rounded (position, time). ``jax.lax.dynamic_slice`` clamps start
        indices so the windows stay inside the arrays near the edges.
        """

        def slookup(x):
            return jax.lax.dynamic_slice(self._environment[0], x, (100, 100))

        pos = jnp.astype(jnp.round(self._state[0]), "int32")
        time = jnp.astype(jnp.round(self._state[1]), "int32")
        wind_view = slookup(jnp.hstack([pos, time]))
        slope_view = jax.lax.dynamic_slice(self._environment[2], pos, (100,))
        return slope_view, wind_view

    def _get_obs(self):
        """Build the observation dict described by ``observation_space``."""
        slope_view, wind_view = self._vision_function()
        return {
            "position": self._state[0],
            "time": self._state[1],
            "energy": self._state[2],
            "dist_remaining": self._state[3],
            "time_remaining": self._state[4],
            "terrain": slope_view,
            "wind": wind_view,
        }

    def __init__(self, car: sim.CarParams = sim.CarParams(), timestep: float = 1.0, seed=1234):
        """Create the environment.

        Args:
            car: Car parameters passed to ``forward`` (must be hashable, since
                ``forward`` treats them as a static ``jit`` argument).
            timestep: Simulated time advanced per ``step``.
            seed: Seed for the initial environment generated here.
        """
        self._reset_sim(jax.random.key(seed))
        self._timestep = timestep
        self._car = car
        self._simstep = forward
        self._simreward = reward
        self.observation_space = gym.spaces.Dict(
            {
                "position": gym.spaces.Box(-100, 10100.0, shape=(1,)),
                "time": gym.spaces.Box(0, 1000.0),
                "energy": gym.spaces.Box(-1.0e6, 0.0),
                "dist_remaining": gym.spaces.Box(0.0, 10100.0),
                "time_remaining": gym.spaces.Box(0.0, 600.0),
                # This is the window into the future/ahead spatially.
                "terrain": gym.spaces.Box(-1.0, 1.0, shape=(100,)),  # slope
                "wind": gym.spaces.Box(-10.0, 10.0, shape=(100, 100)),
            }
        )
        # Throttle as a fraction of car.max_speed (forward scales it).
        self.action_space = gym.spaces.Box(0.0, 1.0, shape=(1,))

    def reset(self, *, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        The base class is reset first so ``self.np_random`` is (re)seeded.
        An explicit ``seed`` regenerates the environment deterministically.
        With ``seed=None``, a fresh seed is drawn from ``self.np_random``, so
        successive unseeded resets give different environments.
        """
        super().reset(seed=seed, options=options)
        if seed is None:
            seed = int(self.np_random.integers(0, 2**31 - 1))
        self._reset_sim(jax.random.key(seed))
        return self._get_obs(), {}

    def step(self, action):
        """Advance one timestep; return ``(obs, reward, terminated, truncated, info)``.

        The episode terminates once the position exceeds 8000 and is
        truncated once the time exceeds 600. The reward is returned as a
        Python ``float``.
        """
        wind, elevation, slope = self._environment
        old_state = self._state
        self._state = self._simstep(
            self._state, action, self._timestep, wind, elevation, slope, self._car
        )
        reward = self._simreward(self._state, old_state)[0]
        terminated = False
        truncated = False
        if jnp.all(self._state[0] > 8000):
            reward += 500
            terminated = True
            # Arrival-only terms:
            # we want the time to be as close to 600 as possible
            reward -= 600 - self._state[1][0]
            reward += 1e-6 * (self._state[2][0])  # net energy is negative.
        if jnp.all(self._state[1] > 600):
            reward -= 500
            truncated = True
        return self._get_obs(), float(reward), terminated, truncated, {}