Skip to content
Snippets Groups Projects
Commit 2df4a434 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

fixed bug in tree observation.

parent 6d7ff5cd
No related branches found
No related tags found
No related merge requests found
import numpy as np
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
......@@ -39,7 +39,7 @@ env = RailEnv(width=40,
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=20,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=GlobalObsForRailEnv(),
obs_builder_object=TreeObservation,
remove_agents_at_target=True
)
......
......@@ -174,7 +174,6 @@ class TreeObsForRailEnv(ObservationBuilder):
else:
self.location_has_agent_direction[(agent.position, agent.direction)] = 1
self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents}
self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in
self.env.agents}
......@@ -271,9 +270,9 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.location_has_agent_malfunction[position] > malfunctioning_agent:
malfunctioning_agent = self.location_has_agent_malfunction[position]
if (agent.position, agent.direction) in self.location_has_agent_direction:
if (position, direction) in self.location_has_agent_direction:
# Cummulate the number of agents on branch with same direction
other_agent_same_direction += self.location_has_agent_direction[(agent.position, agent.direction)]
other_agent_same_direction += self.location_has_agent_direction[(position, direction)]
# Check fractional speed of agents
current_fractional_speed = self.location_has_agent_speed[position]
......@@ -284,13 +283,11 @@ class TreeObsForRailEnv(ObservationBuilder):
# TODO: This does not work as expected yet
other_agent_opposite_direction += self.location_has_agent[position] - \
self.location_has_agent_direction[
(agent.position, agent.direction)]
(position, direction)]
else:
# If no agent in the same direction was found all agents in that position are other direction
other_agent_opposite_direction += self.location_has_agent[position]
print("went in here")
if other_agent_opposite_direction > 0:
print("Other agents", other_agent_opposite_direction)
# Check number of possible transitions for agent and total number of transitions in cell (type)
cell_transitions = self.env.rail.get_transitions(*position, direction)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment