Skip to content
Snippets Groups Projects
Commit 5d186dc5 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

still buggy with treeeobservation

parent 7ecdafb9
No related branches found
No related tags found
No related merge requests found
......@@ -28,13 +28,13 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=100,
height=20,
rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map (where train stations are)
env = RailEnv(width=40,
height=40,
rail_generator=sparse_rail_generator(num_cities=8, # Number of cities in map (where train stations are)
seed=1, # Random seed
grid_mode=True,
grid_mode=False,
max_inter_city_rails=2,
max_tracks_in_city=8,
max_tracks_in_city=4,
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=20,
......
......@@ -161,13 +161,20 @@ class TreeObsForRailEnv(ObservationBuilder):
# Update local lookup table for all agents' positions
self.location_has_agent = dict()
self.location_has_agent_direction = dict()
for agent in self.env.agents:
if tuple(agent.position) in self.location_has_agent:
self.location_has_agent[tuple(agent.position)] = self.location_has_agent[tuple(agent.position)] + 1
else:
self.location_has_agent[tuple(agent.position)] = 1
# TODO: Update this to handle number of agents at same location
self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
if (agent.position, agent.direction) in self.location_has_agent_direction:
self.location_has_agent_direction[(agent.position, agent.direction)] = \
self.location_has_agent_direction[(agent.position, agent.direction)] + 1
else:
self.location_has_agent_direction[(agent.position, agent.direction)] = 1
self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents}
self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in
self.env.agents}
......@@ -264,20 +271,28 @@ class TreeObsForRailEnv(ObservationBuilder):
if self.location_has_agent_malfunction[position] > malfunctioning_agent:
malfunctioning_agent = self.location_has_agent_malfunction[position]
if self.location_has_agent_direction[position] == direction:
if (agent.position, agent.direction) in self.location_has_agent_direction:
# Cummulate the number of agents on branch with same direction
other_agent_same_direction += 1
other_agent_same_direction += self.location_has_agent_direction[(agent.position, agent.direction)]
# Check fractional speed of agents
current_fractional_speed = self.location_has_agent_speed[position]
if current_fractional_speed < min_fractional_speed:
min_fractional_speed = current_fractional_speed
if self.location_has_agent_direction[position] != direction:
# Cummulate the number of agents on branch with other direction
other_agent_opposite_direction += 1
# Other direction agents
# TODO: This does not work as expected yet
other_agent_opposite_direction += self.location_has_agent[position] - \
self.location_has_agent_direction[
(agent.position, agent.direction)]
else:
# If no agent in the same direction was found all agents in that position are other direction
other_agent_opposite_direction += self.location_has_agent[position]
print("went in here")
if other_agent_opposite_direction > 0:
print("Other agents", other_agent_opposite_direction)
# Check number of possible transitions for agent and total number of transitions in cell (type)
# Check number of possible transitions for agent and total number of transitions in cell (type)
cell_transitions = self.env.rail.get_transitions(*position, direction)
transition_bit = bin(self.env.rail.get_full_transitions(*position))
total_transitions = transition_bit.count("1")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment