"""Second-generation simulator. More functional, cleaner code, faster"""
|
|
|
|
from typing import NamedTuple, Optional, Tuple, Union, Dict, Any
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import chex
|
|
from flax import struct
|
|
from jax import lax, vmap
|
|
from gymnax.environments import environment
|
|
from gymnax.environments import spaces
|
|
|
|
from solarcarsim.physsim import CarParams, fractal_noise_1d
|
|
import solarcarsim.physsim as sim


@struct.dataclass
class SimState(environment.EnvState):
    position: jnp.ndarray
    velocity: jnp.ndarray
    realtime: jnp.ndarray
    energy: jnp.ndarray
    # distance_remaining: jnp.ndarray
    # time_remaining: jnp.ndarray
    slope: jnp.ndarray


@struct.dataclass
class SimParams(environment.EnvParams):
    car: CarParams = CarParams()
    goal_time: int = 600
    goal_dist: int = 8000
    map_size: int = 10000
    time_step: float = 1.0
    terrain_lookahead: int = 100
    # skip wind for now
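

# Note: SimParams is a frozen flax.struct dataclass, so scenario variants are
# built by constructing a new instance or via .replace() rather than by mutating
# fields in place, e.g. SimParams(goal_dist=4000, goal_time=300).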


class Snax(environment.Environment[SimState, SimParams]):
    """JAX version of the solar race simulator."""

    @property
    def default_params(self) -> SimParams:
        return SimParams()

    def action_space(self, params: Optional[SimParams] = None):
        # Normalized throttle in [-1, 1], scaled to max_speed in step_env.
        return spaces.Box(low=-1.0, high=1.0, shape=(1,))

    def observation_space(self, params: Optional[SimParams] = None) -> spaces.Box:
        if params is None:
            params = self.default_params
        # Flat Box: [pos, time, energy, dist_to_goal, time_remaining] followed
        # by terrain_lookahead slope samples (105 entries with the defaults).
        shape = 5 + params.terrain_lookahead
        low = jnp.array(
            [0, 0, -1e11, 0, 0] + [-1.0] * params.terrain_lookahead, dtype=jnp.float32
        )
        high = jnp.array(
            [params.map_size, params.goal_time, 0, params.goal_dist, params.goal_time]
            + [1.0] * params.terrain_lookahead,
            dtype=jnp.float32,
        )
        return spaces.Box(low, high, shape=(shape,))

    def state_space(self, params: Optional[SimParams] = None) -> spaces.Dict:
        if params is None:
            params = self.default_params
        return spaces.Dict(
            {
                "position": spaces.Box(0.0, params.map_size, (), jnp.float32),
                "realtime": spaces.Box(0.0, params.goal_time + 100, (), jnp.float32),
                "energy": spaces.Box(-1e11, 0.0, (), jnp.float32),
                # "dist_to_goal": spaces.Box(0.0, params.goal_dist, (), jnp.float32),
                # "time_remaining": spaces.Box(0.0, params.goal_time, (), jnp.float32),
                "slope": spaces.Box(
                    -1.0, 1.0, shape=(params.map_size,), dtype=jnp.float32
                ),
                "time": spaces.Discrete(int(params.goal_time / params.time_step)),
            }
        )

    def reset_env(
        self, key: chex.PRNGKey, params: Optional[SimParams] = None
    ) -> Tuple[chex.Array, SimState]:
        if params is None:
            params = self.default_params
        # Generate a fresh terrain profile for the episode. The length is
        # hard-coded so it stays a static shape under jit tracing; keep it in
        # sync with the default params.map_size (10000).
        slope = fractal_noise_1d(key, 10000, scale=1200, height_scale=0.08)
        init_state = SimState(
            position=jnp.array(0.0),
            velocity=jnp.array(0.0),
            time=0,
            realtime=jnp.array(0.0),
            energy=jnp.array(0.0),
            # distance_remaining=jnp.array(params.goal_dist),
            # time_remaining=jnp.array(params.goal_time),
            slope=slope,
        )
        return self.get_obs(init_state, key, params), init_state

    def get_obs(
        self, state: SimState, key: chex.PRNGKey, params: Optional[SimParams] = None
    ) -> chex.Array:
        if params is None:
            params = self.default_params

        # Integer grid index of the car's current position.
        pos_int = state.position.astype(jnp.int32)

        # Slice out the slope samples ahead of the car. dynamic_slice requires
        # a static slice size, so the window is hard-coded; keep it in sync
        # with the default params.terrain_lookahead (100).
        terrain_view = jax.lax.dynamic_slice(state.slope, (pos_int,), (100,))
        dist_to_goal = jnp.abs(params.goal_dist - state.position)
        time_remaining = jnp.abs(params.goal_time - state.realtime)
        main_state = jnp.array(
            [state.position, state.realtime, state.energy, dist_to_goal, time_remaining]
        )
        return jnp.concatenate([main_state, terrain_view]).squeeze()

    def step_env(
        self,
        key: chex.PRNGKey,
        state: SimState,
        action: Union[int, float, chex.Array],
        params: SimParams,
    ) -> Tuple[chex.Array, SimState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
        pos = state.position.astype(jnp.int32)
        theta = state.slope[pos]
        # The action is a normalized throttle; the commanded velocity is
        # reached instantly (no acceleration dynamics in this generation).
        velocity = jnp.array([action * params.car.max_speed]).squeeze()

        # Longitudinal force balance: aerodynamic drag + rolling resistance
        # + the gravity component along the slope (1.184 kg/m^3 is air
        # density at roughly 25 degrees C).
        dragf = sim.drag_force(
            velocity, params.car.frontal_area, params.car.drag_coeff, 1.184
        )
        rollf = sim.rolling_force(params.car.mass, theta, params.car.rolling_coeff)
        hillf = sim.downslope_force(params.car.mass, theta)
        total_f = dragf + rollf + hillf
        # Per-motor wheel torque, then total electrical power draw.
        tau = params.car.wheel_radius * total_f / params.car.n_motors
        p_draw = (
            sim.bldc_power_draw(tau, velocity, params.car.motor) * params.car.n_motors
        )

        new_energy = state.energy - params.time_step * p_draw
        new_position = state.position + jnp.cos(theta) * velocity * params.time_step
        new_state = SimState(
            position=new_position.squeeze(),
            velocity=velocity.squeeze(),
            realtime=state.realtime + params.time_step,
            energy=new_energy.squeeze(),
            slope=state.slope,
            time=state.time + 1,
        )

        # Reward, written branchlessly so it stays jit/vmap-friendly: progress
        # toward the goal each step, a finish bonus shaped by the remaining
        # time and the energy spent (energy is negative), and a flat penalty
        # for running past the goal time.
        reward = (
            new_state.position / params.goal_dist
            + (new_state.position >= params.goal_dist)
            * (100 + params.goal_time - new_state.realtime + 1e-6 * new_state.energy)
            + (new_state.realtime >= params.goal_time) * -500
        )
        reward = reward.squeeze()
        terminal = self.is_terminal(new_state, params)
        return (
            lax.stop_gradient(self.get_obs(new_state, key, params)),
            lax.stop_gradient(new_state),
            reward,
            terminal,
            {},
        )

    def is_terminal(self, state: SimState, params: SimParams) -> jnp.ndarray:
        # Episode ends when the car reaches the goal or the step budget runs out.
        finish = state.position >= params.goal_dist
        timeout = state.time >= params.max_steps_in_episode
        return jnp.logical_or(finish, timeout).squeeze()
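

if __name__ == "__main__":
    # Smoke-test rollout: a minimal sketch that assumes the stock gymnax
    # Environment wrapper API (env.reset / env.step with auto-reset); the
    # constant half-throttle action is chosen purely for illustration.
    env = Snax()
    env_params = env.default_params

    rng = jax.random.PRNGKey(0)
    rng, reset_rng = jax.random.split(rng)
    obs, state = env.reset(reset_rng, env_params)

    step_fn = jax.jit(env.step)
    for _ in range(10):
        rng, step_rng = jax.random.split(rng)
        action = jnp.array([0.5])  # hold half throttle
        obs, state, reward, done, _ = step_fn(step_rng, state, action, env_params)
        print(f"t={int(state.time)} reward={float(reward):+.3f} done={bool(done)}")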