Commit 1a61fc91 authored by u214892's avatar u214892
Browse files

#188 finally

parent 406cdac8
Pipeline #2345 passed with stages
in 31 minutes and 52 seconds
......@@ -189,15 +189,15 @@ class TreeObsForRailEnv(ObservationBuilder):
agent = self.env.agents[handle] # TODO: handle being treated as index
if agent.status == RailAgentStatus.READY_TO_DEPART:
_agent_initial_position = agent.initial_position
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
_agent_initial_position = agent.position
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
_agent_initial_position = agent.target
agent_virtual_position = agent.target
else:
return None
possible_transitions = self.env.rail.get_transitions(*_agent_initial_position, agent.direction)
possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Here information about the agent itself is stored
......@@ -207,7 +207,7 @@ class TreeObsForRailEnv(ObservationBuilder):
dist_other_agent_encountered=0, dist_potential_conflict=0,
dist_unusable_switch=0, dist_to_next_branch=0,
dist_min_to_target=distance_map[
(handle, *_agent_initial_position,
(handle, *agent_virtual_position,
agent.direction)],
num_agents_same_direction=0, num_agents_opposite_direction=0,
num_agents_malfunctioning=agent.malfunction_data['malfunction'],
......@@ -228,7 +228,7 @@ class TreeObsForRailEnv(ObservationBuilder):
for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
if possible_transitions[branch_direction]:
new_cell = get_new_position(_agent_initial_position, branch_direction)
new_cell = get_new_position(agent_virtual_position, branch_direction)
branch_observation, branch_visited = \
self._explore_branch(handle, new_cell, branch_direction, 1, 1)
......@@ -562,11 +562,11 @@ class GlobalObsForRailEnv(ObservationBuilder):
agent = self.env.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART:
_agent_initial_position = agent.initial_position
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
_agent_initial_position = agent.position
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
_agent_initial_position = agent.target
agent_virtual_position = agent.target
else:
return None
......@@ -578,7 +578,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
for c in range(self.env.width):
obs_agents_state[(r, c)][4] = 0
obs_agents_state[_agent_initial_position][0] = agent.direction
obs_agents_state[agent_virtual_position][0] = agent.direction
obs_targets[agent.target][0] = 1
for i in range(len(self.env.agents)):
......
......@@ -52,10 +52,10 @@ class DummyPredictorForRailEnv(PredictionBuilder):
# TODO make this generic
continue
action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
_agent_initial_position = agent.position
_agent_initial_direction = agent.direction
agent_virtual_position = agent.position
agent_virtual_direction = agent.direction
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
for index in range(1, self.max_depth + 1):
action_done = False
# if we're at the target, stop moving...
......@@ -77,8 +77,8 @@ class DummyPredictorForRailEnv(PredictionBuilder):
if not action_done:
raise Exception("Cannot move further. Something is wrong")
prediction_dict[agent.handle] = prediction
agent.position = _agent_initial_position
agent.direction = _agent_initial_direction
agent.position = agent_virtual_position
agent.direction = agent_virtual_direction
return prediction_dict
......@@ -128,20 +128,20 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
for agent in agents:
if agent.status == RailAgentStatus.READY_TO_DEPART:
_agent_initial_position = agent.initial_position
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
_agent_initial_position = agent.position
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
_agent_initial_position = agent.target
agent_virtual_position = agent.target
else:
prediction_dict[agent.handle] = None
continue
_agent_initial_direction = agent.direction
agent_virtual_direction = agent.direction
agent_speed = agent.speed_data["speed"]
times_per_cell = int(np.reciprocal(agent_speed))
prediction = np.zeros(shape=(self.max_depth + 1, 5))
prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
shortest_path = shortest_paths[agent.handle]
......@@ -149,8 +149,8 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
if shortest_path:
shortest_path = shortest_path[1:]
new_direction = _agent_initial_direction
new_position = _agent_initial_position
new_direction = agent_virtual_direction
new_position = agent_virtual_position
visited = OrderedSet()
for index in range(1, self.max_depth + 1):
# if we're at the target or not moving, stop moving until max_depth is reached
......
......@@ -31,15 +31,15 @@ class SingleAgentNavigationObs(ObservationBuilder):
agent = self.env.agents[handle]
if agent.status == RailAgentStatus.READY_TO_DEPART:
_agent_initial_position = agent.initial_position
agent_virtual_position = agent.initial_position
elif agent.status == RailAgentStatus.ACTIVE:
_agent_initial_position = agent.position
agent_virtual_position = agent.position
elif agent.status == RailAgentStatus.DONE:
_agent_initial_position = agent.target
agent_virtual_position = agent.target
else:
return None
possible_transitions = self.env.rail.get_transitions(*_agent_initial_position, agent.direction)
possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
......@@ -51,7 +51,7 @@ class SingleAgentNavigationObs(ObservationBuilder):
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = get_new_position(_agent_initial_position, direction)
new_position = get_new_position(agent_virtual_position, direction)
min_distances.append(
self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment