Commit 59c5eb16 authored by metataro

shortest path observation

parent c333654f
import gym
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv
from envs.flatland.observations import Observation, register_obs

@register_obs("shortest_path")
class ShortestPathObservation(Observation):

    def __init__(self, config) -> None:
        super().__init__(config)
        self._config = config
        self._builder = ShortestPathForRailEnv(encode_one_hot=True)

    def builder(self) -> ObservationBuilder:
        return self._builder

    def observation_space(self) -> gym.Space:
        return gym.spaces.Tuple([
            gym.spaces.Box(low=0, high=1, shape=(4,)),  # shortest path direction (one-hot)
            gym.spaces.Box(low=0, high=1, shape=(1,)),  # shortest path distance to target (normalised)
            gym.spaces.Box(low=0, high=1, shape=(1,)),  # conflict when following shortest path (1=true, 0=false)
            gym.spaces.Box(low=0, high=1, shape=(4,)),  # other path direction (all zero if not available)
            gym.spaces.Box(low=0, high=1, shape=(1,)),  # other path distance to target (zero if not available)
            gym.spaces.Box(low=0, high=1, shape=(1,)),  # conflict when following other path (1=true, 0=false)
        ])

class ShortestPathForRailEnv(ObservationBuilder):

    def __init__(self, encode_one_hot=True):
        super().__init__()
        self._encode_one_hot = encode_one_hot

    def reset(self):
        pass

    def get(self, handle: int = 0):
        self.env: RailEnv = self.env  # type annotation only: the builder is attached to a RailEnv
        agent = self.env.agents[handle]

        if agent.status == RailAgentStatus.READY_TO_DEPART:
            agent_virtual_position = agent.initial_position
        elif agent.status == RailAgentStatus.ACTIVE:
            agent_virtual_position = agent.position
        elif agent.status == RailAgentStatus.DONE:
            agent_virtual_position = agent.target
        else:
            return None

        directions = list(range(4))
        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
        distance_map = self.env.distance_map.get()
        nan_inf_mask = (distance_map != np.inf) & ~np.isnan(distance_map)
        max_distance = np.max(distance_map[nan_inf_mask])
        assert not np.isnan(max_distance)
        assert max_distance != np.inf

        possible_steps = []

        # look in all directions for possible moves
        for movement in directions:
            if possible_transitions[movement]:
                next_move = movement
                pos = get_new_position(agent_virtual_position, movement)
                distance = distance_map[agent.handle][pos + (movement,)]  # new distance to target
                distance = max_distance if (distance == np.inf or np.isnan(distance)) else distance  # TODO: why does this happen?

                # look ahead if there is an agent between the agent and the next intersection
                # TODO: currently any train between the agent and the next intersection is reported. This includes
                #  those that are moving away from the agent and therefore are not really conflicting. Will be improved.
                conflict = self.env.agent_positions[pos] != -1
                next_possible_moves = self.env.rail.get_transitions(*pos, movement)
                while np.count_nonzero(next_possible_moves) == 1 and not conflict:
                    movement = np.argmax(next_possible_moves)
                    pos = get_new_position(pos, movement)
                    conflict = self.env.agent_positions[pos] != -1
                    next_possible_moves = self.env.rail.get_transitions(*pos, movement)

                if self._encode_one_hot:
                    next_move_one_hot = np.zeros(len(directions))
                    next_move_one_hot[next_move] = 1
                    next_move = next_move_one_hot

                possible_steps.append((next_move, [distance / max_distance], [int(conflict)]))

        if len(possible_steps) == 1:
            # pad with zeros: no alternative path exists at this cell
            return possible_steps[0] + (np.zeros(len(directions)), [.0], [0])
        elif len(possible_steps) == 2:
            possible_steps = sorted(possible_steps, key=lambda step: step[1])  # sort by distance, ascending
            return possible_steps[0] + possible_steps[1]
        else:
            raise ValueError(f"More than two possible steps at {agent_virtual_position}. Looks like a bug.")
@@ -7,7 +7,7 @@ from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions
-from envs.flatland.utils.gym_env import StepOutput
+from envs.flatland.utils.gym_env import StepOutput, FlatlandGymEnv
def available_actions(env: RailEnv, agent: EnvAgent, allow_noop=True) -> List[int]:
@@ -234,3 +234,61 @@ class DeadlockWrapper(gym.Wrapper):
    def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
        self._deadlocked_agents = []
        return self.env.reset(random_seed)

def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
    agent = env.agents[handle]

    if agent.status == RailAgentStatus.READY_TO_DEPART:
        agent_virtual_position = agent.initial_position
    elif agent.status == RailAgentStatus.ACTIVE:
        agent_virtual_position = agent.position
    elif agent.status == RailAgentStatus.DONE:
        agent_virtual_position = agent.target
    else:
        return None

    possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
    distance_map = env.distance_map.get()[handle]

    possible_steps = []
    for movement in list(range(4)):
        if possible_transitions[movement]:
            if movement == agent.direction:
                action = RailEnvActions.MOVE_FORWARD
            elif movement == (agent.direction + 1) % 4:
                action = RailEnvActions.MOVE_RIGHT
            elif movement == (agent.direction - 1) % 4:
                action = RailEnvActions.MOVE_LEFT
            else:
                raise ValueError(f"Unexpected movement {movement} relative to agent direction {agent.direction}.")
            distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)]
            possible_steps.append((action, distance))

    possible_steps = sorted(possible_steps, key=lambda step: step[1])  # sort by distance, ascending
    if len(possible_steps) == 1:
        # duplicate the single entry so both the "shortest" and "other" indices resolve
        return possible_steps * 2
    else:
        return possible_steps

class ShortestPathActionWrapper(gym.Wrapper):

    def __init__(self, env) -> None:
        super().__init__(env)
        print("Apply ShortestPathActionWrapper")
        self.action_space = gym.spaces.Discrete(n=3)  # stop, shortest path, other direction

    def step(self, action_dict: Dict[int, RailEnvActions]) -> StepOutput:
        rail_env: RailEnv = self.env.unwrapped.rail_env
        transformed_action_dict = {}
        for agent_id, action in action_dict.items():
            if action == 0:
                transformed_action_dict[agent_id] = action
            else:
                assert action in [1, 2]
                transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(rail_env, agent_id)[action - 1][0]
        step_output = self.env.step(transformed_action_dict)
        return step_output

    def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
        return self.env.reset(random_seed)
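
Below is a small, self-contained illustration (hypothetical values, plain integers standing in for RailEnvActions) of the remapping done in step(): action 0 is forwarded unchanged, while actions 1 and 2 index into the distance-sorted candidate list, which is duplicated when only one move exists:

def _remap(action, steps_sorted_by_distance):
    # mirrors the per-agent logic in ShortestPathActionWrapper.step()
    if action == 0:
        return action
    return steps_sorted_by_distance[action - 1][0]

fork = sorted([(2, 7.0), (1, 3.0)], key=lambda step: step[1])  # two candidate (action, distance) pairs
single = [(2, 5.0)] * 2                                        # one candidate, duplicated

assert _remap(1, fork) == 1                          # action 1 -> move with the shorter distance
assert _remap(2, fork) == 2                          # action 2 -> the alternative move
assert _remap(1, single) == _remap(2, single) == 2   # with one option, both choices coincide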
@@ -7,7 +7,8 @@ from envs.flatland.utils.env_generators import random_sparse_env_small
from envs.flatland.observations import make_obs
from envs.flatland.utils.gym_env import FlatlandGymEnv
-from envs.flatland.utils.gym_env_wrappers import SkipNoChoiceCellsWrapper, AvailableActionsWrapper
+from envs.flatland.utils.gym_env_wrappers import SkipNoChoiceCellsWrapper, AvailableActionsWrapper, DeadlockWrapper, \
+    SparseRewardWrapper, ShortestPathActionWrapper
class FlatlandRandomSparseSmall(MultiAgentEnv):
@@ -31,8 +32,15 @@ class FlatlandRandomSparseSmall(MultiAgentEnv):
            regenerate_rail_on_reset=env_config['regenerate_rail_on_reset'],
            regenerate_schedule_on_reset=env_config['regenerate_schedule_on_reset']
        )
        if env_config['observation'] == 'shortest_path':
            self._env = ShortestPathActionWrapper(self._env)
        if env_config.get('sparse_reward', False):
            self._env = SparseRewardWrapper(self._env, finished_reward=env_config.get('done_reward', 1),
                                            not_finished_reward=env_config.get('not_finished_reward', -1))
        if env_config.get('deadlock_reward', 0) != 0:
            self._env = DeadlockWrapper(self._env, deadlock_reward=env_config['deadlock_reward'])
        if env_config.get('skip_no_choice_cells', False):
-           self._env = SkipNoChoiceCellsWrapper(self._env, self._config.get('accumulate_skipped_rewards', False))
+           self._env = SkipNoChoiceCellsWrapper(self._env, env_config.get('accumulate_skipped_rewards', False))
        if env_config.get('available_actions_obs', False):
            self._env = AvailableActionsWrapper(self._env)
@@ -11,7 +11,8 @@ from envs.flatland import get_generator_config
from envs.flatland.observations import make_obs
from envs.flatland.utils.gym_env import FlatlandGymEnv, StepOutput
-from envs.flatland.utils.gym_env_wrappers import SkipNoChoiceCellsWrapper, AvailableActionsWrapper
+from envs.flatland.utils.gym_env_wrappers import SkipNoChoiceCellsWrapper, AvailableActionsWrapper, \
+    ShortestPathActionWrapper, SparseRewardWrapper, DeadlockWrapper
class FlatlandSingle(gym.Env):
@@ -27,8 +28,15 @@ class FlatlandSingle(gym.Env):
            regenerate_rail_on_reset=self._config['regenerate_rail_on_reset'],
            regenerate_schedule_on_reset=self._config['regenerate_schedule_on_reset']
        )
        if env_config['observation'] == 'shortest_path':
            self._env = ShortestPathActionWrapper(self._env)
        if env_config.get('sparse_reward', False):
            self._env = SparseRewardWrapper(self._env, finished_reward=env_config.get('done_reward', 1),
                                            not_finished_reward=env_config.get('not_finished_reward', -1))
        if env_config.get('deadlock_reward', 0) != 0:
            self._env = DeadlockWrapper(self._env, deadlock_reward=env_config['deadlock_reward'])
        if env_config.get('skip_no_choice_cells', False):
-           self._env = SkipNoChoiceCellsWrapper(self._env, self._config.get('accumulate_skipped_rewards', False))
+           self._env = SkipNoChoiceCellsWrapper(self._env, env_config.get('accumulate_skipped_rewards', False))
        if env_config.get('available_actions_obs', False):
            self._env = AvailableActionsWrapper(self._env)
@@ -12,7 +12,7 @@ from envs.flatland import get_generator_config
from envs.flatland.observations import make_obs
from envs.flatland.utils.gym_env import FlatlandGymEnv
from envs.flatland.utils.gym_env_wrappers import AvailableActionsWrapper, SkipNoChoiceCellsWrapper, SparseRewardWrapper, \
-    DeadlockWrapper
+    DeadlockWrapper, ShortestPathActionWrapper
class FlatlandSparse(MultiAgentEnv):
@@ -37,6 +37,8 @@ class FlatlandSparse(MultiAgentEnv):
            regenerate_rail_on_reset=self._config['regenerate_rail_on_reset'],
            regenerate_schedule_on_reset=self._config['regenerate_schedule_on_reset']
        )
        if env_config['observation'] == 'shortest_path':
            self._env = ShortestPathActionWrapper(self._env)
        if env_config.get('sparse_reward', False):
            self._env = SparseRewardWrapper(self._env, finished_reward=env_config.get('done_reward', 1),
                                            not_finished_reward=env_config.get('not_finished_reward', -1))
flatland-sparse-small-shortest_path-fc-apex:
    run: APEX
    env: flatland_sparse
    stop:
        timesteps_total: 100000000  # 1e8
    checkpoint_freq: 10
    checkpoint_at_end: True
    # keep_checkpoints_num: 5
    checkpoint_score_attr: episode_reward_mean
    config:
        num_workers: 15
        num_envs_per_worker: 5
        num_gpus: 0

        env_config:
            observation: shortest_path
            generator: sparse_rail_generator
            generator_config: small_v0

            wandb:
                project: flatland
                entity: masterscrat
                tags: ["small_v0", "tree_obs", "apex", "sparse_reward", "deadlock_reward"]  # TODO should be set programmatically

        model:
            fcnet_activation: relu
            fcnet_hiddens: [256, 256]
            vf_share_layers: True  # False

flatland-sparse-small-shortest_path-fc-ppo:
    run: PPO
    env: flatland_sparse
    stop:
        timesteps_total: 10000000  # 1e7
    checkpoint_freq: 10
    checkpoint_at_end: True
    checkpoint_score_attr: episode_reward_mean
    config:
        clip_rewards: False
        # clip_param: 0.1
        # vf_clip_param: 500.0
        vf_clip_param: 10.0
        entropy_coeff: 0.01
        # effective batch_size: train_batch_size * num_agents_in_each_environment
        # see https://github.com/ray-project/ray/issues/4628
        train_batch_size: 1000  # 5000
        rollout_fragment_length: 50  # 100
        sgd_minibatch_size: 100  # 500
        num_sgd_iter: 10
        num_workers: 15
        num_envs_per_worker: 5
        batch_mode: truncate_episodes
        observation_filter: NoFilter
        vf_share_layers: True
        vf_loss_coeff: 0.5
        num_gpus: 0

        env_config:
            observation: shortest_path
            generator: sparse_rail_generator
            generator_config: small_v0

            wandb:
                project: flatland
                entity: masterscrat
                tags: ["small_v0", "tree_obs", "ppo", "sparse_reward", "deadlock_reward"]  # TODO should be set programmatically

        model:
            fcnet_activation: relu
            fcnet_hiddens: [256, 256]
            vf_share_layers: True  # False
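
These YAML files follow the RLlib/Tune experiment format. As a rough, hypothetical Python equivalent of the PPO experiment above (assuming the flatland_sparse environment is registered with RLlib elsewhere in the repo, e.g. via ray.tune.registry.register_env), a launch sketch could look like:

import ray
from ray import tune

ray.init()
tune.run(
    "PPO",
    name="flatland-sparse-small-shortest_path-fc-ppo",
    stop={"timesteps_total": 10_000_000},
    checkpoint_freq=10,
    checkpoint_at_end=True,
    checkpoint_score_attr="episode_reward_mean",
    config={
        "env": "flatland_sparse",  # assumed to be registered beforehand
        "num_workers": 15,
        "num_envs_per_worker": 5,
        "env_config": {
            "observation": "shortest_path",
            "generator": "sparse_rail_generator",
            "generator_config": "small_v0",
        },
        "model": {"fcnet_activation": "relu", "fcnet_hiddens": [256, 256]},
    },
)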