# solarcarsim/src/solarcarsim/simv2.py
"""Second-generation simulator. More functional, cleaner code, faster"""
from typing import NamedTuple, Optional, Tuple, Union, Dict, Any
import jax
import jax.numpy as jnp
import chex
from flax import struct
from jax import lax, vmap
from gymnax.environments import environment
from gymnax.environments import spaces
from solarcarsim.physsim import CarParams, fractal_noise_1d
import solarcarsim.physsim as sim
@struct.dataclass
class SimState(environment.EnvState):
    position: jnp.ndarray
    velocity: jnp.ndarray
    realtime: jnp.ndarray
    energy: jnp.ndarray
    # distance_remaining: jnp.ndarray
    # time_remaining: jnp.ndarray
    slope: jnp.ndarray


@struct.dataclass
class SimParams(environment.EnvParams):
    car: CarParams = CarParams()
    goal_time: int = 600
    goal_dist: int = 8000
    map_size: int = 10000
    time_step: float = 1.0
    terrain_lookahead: int = 100
    # skip wind for now


class Snax(environment.Environment[SimState, SimParams]):
    """JAX version of the solar race simulator."""

    @property
    def default_params(self) -> SimParams:
        return SimParams()

    def action_space(self, params: Optional[SimParams] = None) -> spaces.Box:
        # A single throttle command in [-1, 1], scaled to max_speed in step_env.
        return spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def observation_space(self, params: Optional[SimParams] = None) -> spaces.Box:
        if params is None:
            params = self.default_params
        # Flat Box: [pos, realtime, energy, dist_to_goal, time_remaining]
        # followed by `terrain_lookahead` upcoming slope samples.
        shape = 5 + params.terrain_lookahead
        low = jnp.array(
            [0, 0, -1e11, 0, 0] + [-1.0] * params.terrain_lookahead, dtype=jnp.float32
        )
        high = jnp.array(
            [params.map_size, params.goal_time, 0, params.goal_dist, params.goal_time]
            + [1.0] * params.terrain_lookahead,
            dtype=jnp.float32,
        )
        return spaces.Box(low, high, shape=(shape,))

    def state_space(self, params: Optional[SimParams] = None) -> spaces.Dict:
        if params is None:
            params = self.default_params
        return spaces.Dict(
            {
                "position": spaces.Box(0.0, params.map_size, (), jnp.float32),
                "realtime": spaces.Box(0.0, params.goal_time + 100, (), jnp.float32),
                "energy": spaces.Box(-1e11, 0.0, (), jnp.float32),
                # "dist_to_goal": spaces.Box(0.0, params.goal_dist, (), jnp.float32),
                # "time_remaining": spaces.Box(0.0, params.goal_time, (), jnp.float32),
                "slope": spaces.Box(
                    -1.0, 1.0, shape=(params.map_size,), dtype=jnp.float32
                ),
                "time": spaces.Discrete(int(params.goal_time / params.time_step)),
            }
        )

    def reset_env(
        self, key: chex.PRNGKey, params: Optional[SimParams] = None
    ) -> Tuple[chex.Array, SimState]:
        if params is None:
            params = self.default_params
        # Generate this episode's terrain: a fractal-noise slope profile with
        # one sample per unit of track, sized to the map rather than hardcoded.
        slope = fractal_noise_1d(key, params.map_size, scale=1200, height_scale=0.08)
        init_state = SimState(
            position=jnp.array(0.0),
            velocity=jnp.array(0.0),
            time=0,
            realtime=jnp.array(0.0),
            energy=jnp.array(0.0),
            # distance_remaining=jnp.array(params.goal_dist),
            # time_remaining=jnp.array(params.goal_time),
            slope=slope,
        )
        return self.get_obs(init_state, key, params), init_state

    def get_obs(
        self, state: SimState, key: chex.PRNGKey, params: SimParams
    ) -> chex.Array:
        if params is None:
            params = self.default_params
        # Slice the upcoming terrain window starting at the rounded position;
        # dynamic_slice clamps the start index, so the window stays in bounds
        # near the end of the map.
        pos_int = jnp.astype(state.position, jnp.int32)
        terrain_view = jax.lax.dynamic_slice(
            state.slope, (pos_int,), (params.terrain_lookahead,)
        )
        dist_to_goal = jnp.abs(params.goal_dist - state.position)
        time_remaining = jnp.abs(params.goal_time - state.realtime)
        main_state = jnp.array(
            [state.position, state.realtime, state.energy, dist_to_goal, time_remaining]
        )
        return jnp.concatenate([main_state, terrain_view]).squeeze()

    def step_env(
        self,
        key: chex.PRNGKey,
        state: SimState,
        action: Union[int, float, chex.Array],
        params: SimParams,
    ) -> Tuple[chex.Array, SimState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
        pos = jnp.astype(state.position, jnp.int32)
        theta = state.slope[pos]
        # The car tracks the commanded speed directly; there are no
        # acceleration dynamics.
        velocity = jnp.array([action * params.car.max_speed]).squeeze()
        # Resistive forces: aerodynamic drag (air density 1.184 kg/m^3),
        # rolling resistance, and the gravity component along the slope.
        dragf = sim.drag_force(
            velocity, params.car.frontal_area, params.car.drag_coeff, 1.184
        )
        rollf = sim.rolling_force(params.car.mass, theta, params.car.rolling_coeff)
        hillf = sim.downslope_force(params.car.mass, theta)
        total_f = dragf + rollf + hillf
        # Torque per motor, then electrical power draw summed over all motors.
        tau = params.car.wheel_radius * total_f / params.car.n_motors
        p_draw = (
            sim.bldc_power_draw(tau, velocity, params.car.motor) * params.car.n_motors
        )
        new_energy = state.energy - params.time_step * p_draw
        new_position = state.position + jnp.cos(theta) * velocity * params.time_step
        new_state = SimState(
            position=new_position.squeeze(),
            velocity=velocity.squeeze(),
            realtime=state.realtime + params.time_step,
            energy=new_energy.squeeze(),
            slope=state.slope,
            time=state.time + 1,
        )
        # Reward: a dense progress term, plus a branchless finish bonus
        # (rewarding time margin and remaining energy, which is negative) and
        # a flat penalty for running out of time. Boolean masks replace the
        # if-branches so the whole expression stays vectorizable.
        reward = (
            new_state.position / params.goal_dist
            + (new_state.position >= params.goal_dist)
            * (100 + params.goal_time - new_state.realtime + 1e-6 * new_state.energy)
            + (new_state.realtime >= params.goal_time) * -500
        )
        reward = reward.squeeze()
        # Terminal status must reflect the post-step state, not the state we
        # entered the step with.
        terminal = self.is_terminal(new_state, params)
        return (
            lax.stop_gradient(self.get_obs(new_state, key, params)),
            lax.stop_gradient(new_state),
            reward,
            terminal,
            {},
        )

    def is_terminal(self, state: SimState, params: SimParams) -> jnp.ndarray:
        finish = state.position >= params.goal_dist
        timeout = state.time >= params.max_steps_in_episode
        return jnp.logical_or(finish, timeout).squeeze()
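

if __name__ == "__main__":
    # Minimal rollout sketch, not part of the original module. It assumes the
    # standard gymnax Environment.reset/Environment.step wrappers around
    # reset_env/step_env, and drives the car with a constant half-throttle
    # action purely for illustration.
    rng = jax.random.PRNGKey(0)
    env = Snax()
    env_params = env.default_params
    obs, env_state = env.reset(rng, env_params)

    def rollout_step(carry, _):
        key, state = carry
        key, step_key = jax.random.split(key)
        action = jnp.array([0.5])  # constant half throttle
        _obs, state, reward, _done, _info = env.step(
            step_key, state, action, env_params
        )
        return (key, state), reward

    (_, final_state), rewards = lax.scan(
        rollout_step, (rng, env_state), None, length=100
    )
    print("total reward over 100 steps:", float(rewards.sum()))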