Commit e23d8f00 authored by nimishsantosh107's avatar nimishsantosh107
Browse files

wrapper fixes - incomplete

parent 6d0f42a7
Pipeline #8697 canceled with stages
...@@ -16,7 +16,8 @@ from flatland.core.grid.grid4_utils import get_new_position ...@@ -16,7 +16,8 @@ from flatland.core.grid.grid4_utils import get_new_position
# First of all we import the Flatland rail environment # First of all we import the Flatland rail environment
from flatland.utils.rendertools import RenderTool, AgentRenderVariant from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.envs.agent_utils import EnvAgent, RailAgentStatus from flatland.envs.agent_utils import EnvAgent
from flatland.envs.step_utils.states import TrainState
from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_env import RailEnv, RailEnvActions
...@@ -24,20 +25,13 @@ def possible_actions_sorted_by_distance(env: RailEnv, handle: int): ...@@ -24,20 +25,13 @@ def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
agent = env.agents[handle] agent = env.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART: if agent.state == TrainState.READY_TO_DEPART:
agent_virtual_position = agent.initial_position agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE: elif agent.state.is_on_map_state():
agent_virtual_position = agent.position agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
agent_virtual_position = agent.target
else: else:
print("no action possible!") print("no action possible!")
if agent.status == RailAgentStatus.DONE_REMOVED: print("agent status: ", agent.state)
print(f"agent status: DONE_REMOVED for agent {agent.handle}")
print("to solve this problem, do not input actions for removed agents!")
return [(RailEnvActions.DO_NOTHING, 0)] * 2
print("agent status:")
print(RailAgentStatus(agent.status))
#return None #return None
# NEW: if agent is at target, DO_NOTHING, and distance is zero. # NEW: if agent is at target, DO_NOTHING, and distance is zero.
# NEW: (needs to be tested...) # NEW: (needs to be tested...)
...@@ -58,25 +52,18 @@ def possible_actions_sorted_by_distance(env: RailEnv, handle: int): ...@@ -58,25 +52,18 @@ def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
elif movement == (agent.direction - 1) % 4: elif movement == (agent.direction - 1) % 4:
action = RailEnvActions.MOVE_LEFT action = RailEnvActions.MOVE_LEFT
else: else:
# MICHEL: prints for debugging
print(f"An error occured. movement is: {movement}, agent direction is: {agent.direction}") print(f"An error occured. movement is: {movement}, agent direction is: {agent.direction}")
if movement == (agent.direction + 2) % 4 or (movement == agent.direction - 2) % 4: if movement == (agent.direction + 2) % 4 or (movement == agent.direction - 2) % 4:
print("it seems that we are turning by 180 degrees. Turning in a dead end?") print("it seems that we are turning by 180 degrees. Turning in a dead end?")
# MICHEL: can this happen when we turn 180 degrees in a dead end?
# i.e. can we then have movement == agent.direction + 2 % 4 (resp. ... == - 2 % 4)?
# TRY OUT: ASSIGN MOVE_FORWARD HERE...
action = RailEnvActions.MOVE_FORWARD action = RailEnvActions.MOVE_FORWARD
print("Here we would have a ValueError...") print("Here we would have a ValueError...")
#raise ValueError("Wtf, debug this shit.")
distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)] distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)]
possible_steps.append((action, distance)) possible_steps.append((action, distance))
possible_steps = sorted(possible_steps, key=lambda step: step[1]) possible_steps = sorted(possible_steps, key=lambda step: step[1])
# MICHEL: what is this doing?
# if there is only one path to target, this is both the shortest one and the second shortest path. # if there is only one path to target, this is both the shortest one and the second shortest path.
if len(possible_steps) == 1: if len(possible_steps) == 1:
return possible_steps * 2 return possible_steps * 2
...@@ -186,16 +173,9 @@ class ShortestPathActionWrapper(RailEnvWrapper): ...@@ -186,16 +173,9 @@ class ShortestPathActionWrapper(RailEnvWrapper):
super().__init__(env) super().__init__(env)
#self.action_space = gym.spaces.Discrete(n=3) # 0:stop, 1:shortest path, 2:other direction #self.action_space = gym.spaces.Discrete(n=3) # 0:stop, 1:shortest path, 2:other direction
# MICHEL: we have to make sure that not agents with agent.status == DONE_REMOVED are in the action dict. # MICHEL: we have to make sure that not agents with agent.state == DONE_REMOVED are in the action dict.
# otherwise, possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] will crash. # otherwise, possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] will crash.
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
########## MICHEL: NEW (just for debugging) ########
for agent_id, action in action_dict.items():
agent = self.agents[agent_id]
# assert agent.status != RailAgentStatus.DONE_REMOVED # this comes with agent.position == None...
# assert agent.status != RailAgentStatus.DONE # not sure about this one...
print(f"agent: {agent} with status: {agent.status}")
######################################################
# input: action dict with actions in [0, 1, 2]. # input: action dict with actions in [0, 1, 2].
transformed_action_dict = {} transformed_action_dict = {}
...@@ -207,21 +187,14 @@ class ShortestPathActionWrapper(RailEnvWrapper): ...@@ -207,21 +187,14 @@ class ShortestPathActionWrapper(RailEnvWrapper):
# MICHEL: how exactly do the indices work here? # MICHEL: how exactly do the indices work here?
#transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.rail_env, agent_id)[action - 1][0] #transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.rail_env, agent_id)[action - 1][0]
#print(f"possible actions sorted by distance(...) is: {possible_actions_sorted_by_distance(self.env, agent_id)}") #print(f"possible actions sorted by distance(...) is: {possible_actions_sorted_by_distance(self.env, agent_id)}")
#assert agent.status != RailAgentStatus.DONE_REMOVED
# MICHEL: THIS LINE CRASHES WITH A "NoneType is not subscriptable" error... # MICHEL: THIS LINE CRASHES WITH A "NoneType is not subscriptable" error...
assert possible_actions_sorted_by_distance(self.env, agent_id) is not None assert possible_actions_sorted_by_distance(self.env, agent_id) is not None
assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None
transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0] transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0]
obs, rewards, dones, info = self.env.step(transformed_action_dict) obs, rewards, dones, info = self.env.step(transformed_action_dict)
return obs, rewards, dones, info return obs, rewards, dones, info
#def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
#return self.rail_env.reset(random_seed)
# MICHEL: should not be needed, as we inherit that from RailEnvWrapper...
#def reset(self, **kwargs) -> Tuple[Dict, Dict]:
# obs, info = self.env.reset(**kwargs)
# return obs, info
def find_all_cells_where_agent_can_choose(env: RailEnv): def find_all_cells_where_agent_can_choose(env: RailEnv):
...@@ -236,19 +209,11 @@ def find_all_cells_where_agent_can_choose(env: RailEnv): ...@@ -236,19 +209,11 @@ def find_all_cells_where_agent_can_choose(env: RailEnv):
for h in range(env.height): for h in range(env.height):
for w in range(env.width): for w in range(env.width):
# MICHEL: THIS SEEMS TO BE A BUG. WRONG ODER OF COORDINATES.
# will not show up in quadratic environments.
# should be pos = (h, w)
#pos = (w, h)
# MICHEL: changed this
pos = (h, w) pos = (h, w)
is_switch = False is_switch = False
# Check for switch: if there is more than one outgoing transition # Check for switch: if there is more than one outgoing transition
for orientation in directions: for orientation in directions:
#print(f"env is: {env}")
#print(f"env.rail is: {env.rail}")
possible_transitions = env.rail.get_transitions(*pos, orientation) possible_transitions = env.rail.get_transitions(*pos, orientation)
num_transitions = np.count_nonzero(possible_transitions) num_transitions = np.count_nonzero(possible_transitions)
if num_transitions > 1: if num_transitions > 1:
...@@ -386,15 +351,12 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper): ...@@ -386,15 +351,12 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper):
self.skipper.reset_cells() self.skipper.reset_cells()
# TODO: this is clunky..
# for easier access / checking
self.switches = self.skipper.switches self.switches = self.skipper.switches
self.switches_neighbors = self.skipper.switches_neighbors self.switches_neighbors = self.skipper.switches_neighbors
self.decision_cells = self.skipper.decision_cells self.decision_cells = self.skipper.decision_cells
self.skipped_rewards = self.skipper.skipped_rewards self.skipped_rewards = self.skipper.skipped_rewards
# MICHEL: trying to isolate the core part and put it into a separate method.
def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]: def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict) obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict)
return obs, rewards, dones, info return obs, rewards, dones, info
...@@ -409,4 +371,4 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper): ...@@ -409,4 +371,4 @@ class SkipNoChoiceCellsWrapper(RailEnvWrapper):
# resets decision cells, switches, etc. These can change with an env.reset(...)! # resets decision cells, switches, etc. These can change with an env.reset(...)!
# needs to be done after env.reset(). # needs to be done after env.reset().
self.skipper.reset_cells() self.skipper.reset_cells()
return obs, info return obs, info
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment