Newer
Older
Erik Nygren
committed
"""
Collection of environment-specific ObservationBuilder.
"""
import pprint
from collections import deque
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid_utils import coordinate_to_position
class TreeObsForRailEnv(ObservationBuilder):
"""
TreeObsForRailEnv object.
This object returns observation vectors for agents in the RailEnv environment.
The information is local to each agent and exploits the graph structure of the rail
network to simplify the representation of the state of the environment for each agent.
For details about the features in the tree observation see the get() function.
"""
def __init__(self, max_depth, predictor=None):
    """
    Set up the tree observation builder.

    max_depth is the depth of the 4-ary observation tree built for each agent.
    predictor is optional; when it is truthy, get_many() queries it and the
    resulting predictions are used for conflict detection while exploring branches.
    """
    super().__init__()
    self.max_depth = max_depth
    # Number of features stored for every node of the tree.
    self.observation_dim = 11

    # A complete 4-ary tree of depth `max_depth` has 4^0 + 4^1 + ... + 4^max_depth nodes.
    node_count = sum(4 ** level for level in range(self.max_depth + 1))
    self.observation_space = [node_count * self.observation_dim]

    # Per-step lookup tables, rebuilt in get().
    self.location_has_agent = {}
    self.location_has_agent_direction = {}

    self.predictor = predictor
    self.agents_previous_reset = None

    # Labels of the four children of a node, in the order they are stored.
    self.tree_explored_actions = [1, 2, 3, 0]
    self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']

    self.distance_map = None
    self.distance_map_computed = False
def reset(self):
    """
    Decide whether the distance map has to be rebuilt and rebuild it if so.

    The map is recomputed unless either
    - the number of agents and all their targets are unchanged since the previous reset, or
    - this is the first reset and a distance map is already present (e.g. it was loaded).
    """
    agents = self.env.agents
    previous = self.agents_previous_reset

    if previous is not None and len(agents) == len(previous):
        # Same number of agents as last time: only a changed target forces a recomputation.
        needs_recompute = any(current.target != old.target for current, old in zip(agents, previous))
    else:
        needs_recompute = True

    # Don't compute the distance map if it was loaded
    if previous is None and self.distance_map is not None:
        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
        needs_recompute = False

    if needs_recompute:
        self._compute_distance_map()

    self.agents_previous_reset = agents
def _compute_distance_map(self):
    """
    Rebuild self.distance_map, an array indexed by (agent, row, column, orientation).

    Every entry starts at +inf; _distance_map_walker() then fills in the distances for
    each agent's target. self.max_dist receives the value returned by the walker for
    each agent, and self.location_has_target is refreshed with all target cells.
    """
    agents = self.env.agents
    # For testing only --> To assert if a distance map need to be recomputed.
    self.distance_map_computed = True

    agent_count = len(agents)
    map_shape = (agent_count, self.env.height, self.env.width, 4)
    self.distance_map = np.inf * np.ones(shape=map_shape)

    self.max_dist = np.zeros(agent_count)
    farthest_per_agent = []
    for agent_index, agent in enumerate(agents):
        farthest_per_agent.append(self._distance_map_walker(agent.target, agent_index))
    self.max_dist = farthest_per_agent

    # Update local lookup table for all agents' target locations
    self.location_has_target = {tuple(agent.target): 1 for agent in agents}
def _distance_map_walker(self, position, target_nr):
    """
    Breadth-first walk from the target cell `position`, filling self.distance_map[target_nr].

    Returns the largest value of (distance of an expanded node + 1) seen for a node that
    still had predecessors, i.e. the distance of the farthest node reached by the walk.
    """
    target_row, target_col = position[0], position[1]

    # The target itself is at distance 0, whatever the orientation.
    self.distance_map[target_nr, target_row, target_col, :] = 0

    # Seed the queue with the (up to) 4 neighbouring cells that can move into the target.
    frontier = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))

    # The target cell counts as already expanded in every orientation, so the search
    # never restarts from it.
    expanded = {(target_row, target_col, orientation) for orientation in range(4)}

    farthest = 0
    while frontier:
        row, col, orientation, distance = frontier.popleft()
        key = (row, col, orientation)
        if key in expanded:
            continue
        expanded.add(key)

        # Only predecessors whose arrival orientation allows a transition towards
        # `orientation` in this cell are kept by the helper.
        predecessors = self._get_and_update_neighbors((row, col), target_nr, distance, orientation)
        frontier.extend(predecessors)

        if predecessors:
            farthest = max(farthest, distance + 1)

    return farthest
def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
    """
    Collect the (cell, orientation) pairs from which an agent can move into `position`.

    For every such pair the stored distance in self.distance_map[target_nr] is lowered to
    `current_distance + 1` if that is smaller, and a tuple
    (row, col, orientation, distance) is appended to the returned list.

    When `enforce_target_direction` is >= 0 only the neighbour lying opposite to that
    direction is examined; otherwise all four neighbours are.
    """
    if enforce_target_direction >= 0:
        # The agent must land into the current cell with orientation `enforce_target_direction'.
        # This is only possible if the agent has arrived from the cell in the opposite direction!
        candidate_directions = [(enforce_target_direction + 2) % 4]
    else:
        candidate_directions = [0, 1, 2, 3]

    reachable = []
    for neigh_direction in candidate_directions:
        row, col = self._new_position(position, neigh_direction)
        if not (0 <= row < self.env.height and 0 <= col < self.env.width):
            # Neighbour lies outside the grid.
            continue

        # Movement that leads from the neighbour back into `position`.
        movement_into_position = (neigh_direction + 2) % 4

        # Check all possible transitions in the neighbouring cell
        for agent_orientation in range(4):
            if not self.env.rail.get_transition((row, col, agent_orientation), movement_into_position):
                continue

            # TODO: check that it works with deadends! -- still bugged!
            # movement = desired_movement_from_new_cell
            # if isNextCellDeadEnd:
            #     movement = (desired_movement_from_new_cell+2) % 4

            best_distance = min(self.distance_map[target_nr, row, col, agent_orientation],
                                current_distance + 1)
            reachable.append((row, col, agent_orientation, best_distance))
            self.distance_map[target_nr, row, col, agent_orientation] = best_distance

    return reachable
def _new_position(self, position, movement):
    """
    Return the (row, column) reached from `position` after one step along `movement`.

    `movement` is compared against the four Grid4TransitionsEnum compass values;
    None is returned when it matches none of them.
    """
    if movement == Grid4TransitionsEnum.NORTH:
        # One row up.
        return (position[0] - 1, position[1])
    if movement == Grid4TransitionsEnum.EAST:
        # One column to the right.
        return (position[0], position[1] + 1)
    if movement == Grid4TransitionsEnum.SOUTH:
        # One row down.
        return (position[0] + 1, position[1])
    if movement == Grid4TransitionsEnum.WEST:
        # One column to the left.
        return (position[0], position[1] - 1)
    return None
def get_many(self, handles=None):
    """
    Compute the observation of every agent whose handle is listed in `handles`.

    When a predictor is set, its predictions are fetched first and re-indexed per time
    step into self.predicted_pos / self.predicted_dir so that get() can look them up.
    Returns a dict mapping each handle to the result of get(handle); a missing
    `handles` argument is treated as an empty list.
    """
    if handles is None:
        handles = []

    if self.predictor:
        self.max_prediction_depth = 0
        self.predicted_pos = {}
        self.predicted_dir = {}
        self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})

        if self.predictions:
            # The prediction horizon is taken from the entry stored under key 0.
            for t in range(len(self.predictions[0])):
                positions_at_t = [self.predictions[a][t][1:3] for a in handles]
                directions_at_t = [self.predictions[a][t][3] for a in handles]
                self.predicted_pos[t] = coordinate_to_position(self.env.width, positions_at_t)
                self.predicted_dir[t] = directions_at_t
            self.max_prediction_depth = len(self.predicted_pos)

    return {h: self.get(h) for h in handles}
def get(self, handle):
    """
    Compute the current tree observation for agent `handle`.

    The result is a flat list: the 11 features of the root node followed by the data of
    the four branches, ordered relative to the agent's orientation as
        [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
    Each branch is organised recursively in the same way:
        [node features] + [left subtree] + [forward subtree] + [right subtree] + [back subtree]

    Each node holds 11 features:
    #1: distance (in cells) at which the agent's own target was met on the branch.
    #2: distance at which another agent's target was met.
    #3: distance at which another agent was met.
    #4: distance of a possible conflict, i.e. another agent is predicted to use the cell at
        (about) the same time.
    #5: distance of a switch that the agent itself cannot use.
    #6: distance from the agent to this node (+inf for a terminal node).
    #7: distance-map value from this node to the agent's target (0 at the target itself).
    #8: number of agents met on the branch heading in the same direction.
    #9: number of agents met on the branch heading in another direction.
    #10: longest remaining malfunction among the agents met on the branch.
    #11: slowest speed among the agents met in the same direction (1 if none).

    Features #1-#5 stay at +inf when nothing was encountered.
    Branches that cannot be taken are padded with -inf.
    The root node is
        [0, 0, 0, 0, 0, 0, distance from agent to target, 0, 0, own malfunction, own speed].

    As a side effect the per-cell agent lookup tables are refreshed and the visited
    cells are stored in self.env.dev_obs_dict[handle].
    """
    agents = self.env.agents

    # Update local lookup table for all agents' positions
    self.location_has_agent = {tuple(agent.position): 1 for agent in agents}
    self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in agents}
    self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in agents}
    self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction']
                                           for agent in agents}

    # Valid indices are 0 .. len(agents) - 1, so `handle == len(agents)` is out of range
    # too. The diagnostic is printed and the lookup below then fails, as it did before.
    if handle >= len(agents):
        print("ERROR: obs _get - handle ", handle, " len(agents)", len(agents))

    agent = agents[handle]  # TODO: handle being treated as index
    possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
    num_transitions = np.count_nonzero(possible_transitions)

    # Root node - current position
    # Here information about the agent itself is stored
    observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0,
                   agent.malfunction_data['malfunction'], agent.speed_data['speed']]

    visited = set()

    # Start from the current orientation, and see which transitions are available;
    # organize them as [left, forward, right, back], relative to the current orientation
    # If only one transition is possible, the tree is oriented with this transition as the forward branch.
    orientation = agent.direction
    if num_transitions == 1:
        orientation = np.argmax(possible_transitions)

    for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
        if possible_transitions[branch_direction]:
            new_cell = self._new_position(agent.position, branch_direction)
            branch_observation, branch_visited = self._explore_branch(handle, new_cell, branch_direction, 1, 1)
            observation = observation + branch_observation
            visited = visited.union(branch_visited)
        else:
            # add cells filled with infinity if no transition is possible
            observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)

    self.env.dev_obs_dict[handle] = visited
    return observation
def _num_cells_to_fill_in(self, remaining_depth):
    """
    Length of the padding needed for a missing subtree with `remaining_depth` levels:
    (4^0 + 4^1 + ... + 4^(remaining_depth - 1)) * observation_dim.
    """
    missing_nodes = sum(4 ** level for level in range(remaining_depth))
    return missing_nodes * self.observation_dim
def _explore_branch(self, handle, position, direction, tot_dist, depth):
    """
    Utility function to compute tree-based observations.

    Starting at `position` and moving along `direction`, the branch is followed cell by
    cell until a switch, the agent's target, a dead-end, an already visited cell or a cell
    without transitions is reached. The features documented in get() are collected on the
    way and form the node for the end of the branch; if the branch ended at a switch or
    dead-end, every possible continuation is explored recursively.

    handle: index of the observing agent in self.env.agents.
    position, direction: first cell of the branch and the direction of travel in it.
    tot_dist: distance (in cells) from the agent to `position`.
    depth: depth of the node being built; nothing is returned beyond self.max_depth.

    Returns (observation, visited): the flat list of features of this node followed by
    its four subtrees, and the set of (row, col, direction) triples that were visited.
    """
    # [Recursive branch opened]
    if depth >= self.max_depth + 1:
        # An empty set, so that the second element is a set on every return path.
        return [], set()

    # Continue along direction until next switch or
    # until no transitions are possible along the current direction (i.e., dead-ends)
    # We treat dead-ends as nodes, instead of going back, to avoid loops
    exploring = True
    last_is_switch = False
    last_is_dead_end = False
    last_is_terminal = False  # wrong cell OR cycle; either way, we don't want the agent to land here
    last_is_target = False

    visited = set()
    agent = self.env.agents[handle]
    time_per_cell = np.reciprocal(agent.speed_data["speed"])

    own_target_encountered = np.inf
    other_agent_encountered = np.inf
    other_target_encountered = np.inf
    potential_conflict = np.inf
    unusable_switch = np.inf
    other_agent_same_direction = 0
    other_agent_opposite_direction = 0
    malfunctioning_agent = 0
    min_fractional_speed = 1.

    while exploring:
        # #############################
        # #############################
        # Modify here to compute any useful data required to build the end node's features. This code is called
        # for each cell visited between the previous branching node and the next switch / target / dead-end.
        if position in self.location_has_agent:
            if tot_dist < other_agent_encountered:
                other_agent_encountered = tot_dist

            # Check if any of the observed agents is malfunctioning, store agent with longest duration left
            if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                malfunctioning_agent = self.location_has_agent_malfunction[position]

            if self.location_has_agent_direction[position] == direction:
                # Accumulate the number of agents on branch with same direction
                other_agent_same_direction += 1

                # Check fractional speed of agents
                current_fractional_speed = self.location_has_agent_speed[position]
                if current_fractional_speed < min_fractional_speed:
                    min_fractional_speed = current_fractional_speed

            if self.location_has_agent_direction[position] != direction:
                # Accumulate the number of agents on branch with other direction
                other_agent_opposite_direction += 1

        # Check number of possible transitions for agent and total number of transitions in cell (type)
        cell_transitions = self.env.rail.get_transitions(*position, direction)
        transition_bit = bin(self.env.rail.get_full_transitions(*position))
        total_transitions = transition_bit.count("1")
        # This exact bit pattern is the cell type that is handled as a crossing below.
        crossing_found = int(transition_bit, 2) == 0b1000010000100001

        # Register possible future conflict
        predicted_time = int(tot_dist * time_per_cell)
        if self.predictor and predicted_time < self.max_prediction_depth:
            int_position = coordinate_to_position(self.env.width, [position])
            if tot_dist < self.max_prediction_depth:
                pre_step = max(0, predicted_time - 1)
                post_step = min(self.max_prediction_depth - 1, predicted_time + 1)

                # Look for conflicting paths at the predicted time first, then one step
                # earlier, then one step later. Only the first of these steps at which
                # another agent is predicted on this cell is examined.
                for step in (predicted_time, pre_step, post_step):
                    if int_position in np.delete(self.predicted_pos[step], handle, 0):
                        if tot_dist < potential_conflict and self._conflict_at_step(
                                step, int_position, direction, cell_transitions):
                            potential_conflict = tot_dist
                        break

        if position in self.location_has_target and position != agent.target:
            if tot_dist < other_target_encountered:
                other_target_encountered = tot_dist

        if position == agent.target and tot_dist < own_target_encountered:
            own_target_encountered = tot_dist

        # #############################
        # #############################
        if (position[0], position[1], direction) in visited:
            last_is_terminal = True
            break
        visited.add((position[0], position[1], direction))

        # If the target node is encountered, pick that as node. Also, no further branching is possible.
        if np.array_equal(position, self.env.agents[handle].target):
            last_is_target = True
            break

        # Check if crossing is found --> Not an unusable switch
        if crossing_found:
            # Treat the crossing as a straight rail cell
            total_transitions = 2
        num_transitions = np.count_nonzero(cell_transitions)

        exploring = False

        # Detect Switches that can only be used by other agents.
        if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
            unusable_switch = tot_dist

        if num_transitions == 1:
            # Check if dead-end, or if we can go forward along direction
            if total_transitions == 1:
                # Dead-end!
                last_is_dead_end = True

            if not last_is_dead_end:
                # Keep walking through the tree along `direction'
                exploring = True
                # convert one-hot encoding to 0,1,2,3
                direction = np.argmax(cell_transitions)
                position = self._new_position(position, direction)
                tot_dist += 1
        elif num_transitions > 0:
            # Switch detected
            last_is_switch = True
            break
        elif num_transitions == 0:
            # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
            print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
                  position[1], direction)
            last_is_terminal = True
            break

    # `position' is either a terminal node or a switch

    # #############################
    # #############################
    # Modify here to append new / different features for each visited cell!
    # Only features #6 and #7 depend on how the branch ended.
    if last_is_target:
        dist_to_node = tot_dist
        dist_min_to_target = 0
    elif last_is_terminal:
        dist_to_node = np.inf
        dist_min_to_target = self.distance_map[handle, position[0], position[1], direction]
    else:
        dist_to_node = tot_dist
        dist_min_to_target = self.distance_map[handle, position[0], position[1], direction]

    observation = [own_target_encountered,
                   other_target_encountered,
                   other_agent_encountered,
                   potential_conflict,
                   unusable_switch,
                   dist_to_node,
                   dist_min_to_target,
                   other_agent_same_direction,
                   other_agent_opposite_direction,
                   malfunctioning_agent,
                   min_fractional_speed
                   ]
    # #############################
    # #############################
    # Start from the current orientation, and see which transitions are available;
    # organize them as [left, forward, right, back], relative to the current orientation
    # Get the possible transitions
    possible_transitions = self.env.rail.get_transitions(*position, direction)
    for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
        if last_is_dead_end and self.env.rail.get_transition((*position, direction),
                                                             (branch_direction + 2) % 4):
            # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
            # it back
            next_direction = (branch_direction + 2) % 4
        elif last_is_switch and possible_transitions[branch_direction]:
            next_direction = branch_direction
        else:
            # no exploring possible, add just cells with infinity
            observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth)
            continue

        new_cell = self._new_position(position, next_direction)
        branch_observation, branch_visited = self._explore_branch(handle,
                                                                  new_cell,
                                                                  next_direction,
                                                                  tot_dist + 1,
                                                                  depth + 1)
        observation = observation + branch_observation
        if len(branch_visited) != 0:
            visited = visited.union(branch_visited)

    return observation, visited

def _conflict_at_step(self, step, int_position, direction, cell_transitions):
    """
    Tell whether an agent predicted on `int_position` at time `step` conflicts with a move
    along `direction` through that cell.

    An agent conflicts when it is predicted to travel in a different direction and the
    cell allows the reverse of that direction, or when it is flagged in self.env.dones.
    """
    predicted_directions = self.predicted_dir[step]
    for ca in np.where(self.predicted_pos[step] == int_position)[0]:
        other_direction = predicted_directions[ca]
        opposing = direction != other_direction and cell_transitions[self._reverse_dir(other_direction)] == 1
        if opposing or self.env.dones[ca]:
            return True
    return False
def util_print_obs_subtree(self, tree):
    """
    Pretty-print (indent of 4) the nested form of a flat tree observation,
    as produced by unfold_observation_tree().
    """
    printer = pprint.PrettyPrinter(indent=4)
    nested = self.unfold_observation_tree(tree)
    printer.pprint(nested)
def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
    """
    Convert a flat tree observation into nested dictionaries.

    The node's own features are stored under the key ''. Each of the four children that
    is long enough to hold at least one node is stored under its label: 'L', 'F', 'R', 'B'
    when `actions_for_display` is true, otherwise the matching entry of
    self.tree_explored_actions. The same labelling is used at every level of the tree.

    `current_depth` is only passed on to the recursive calls.
    Returns None when `tree` is shorter than one node (self.observation_dim values).
    """
    if len(tree) < self.observation_dim:
        return None

    if actions_for_display:
        labels = self.tree_explorted_actions_char
    else:
        labels = self.tree_explored_actions

    unfolded = {'': tree[0:self.observation_dim]}

    # Everything after the node's own features is split evenly between the 4 children.
    child_size = (len(tree) - self.observation_dim) // 4
    for child in range(4):
        start = self.observation_dim + child * child_size
        child_tree = tree[start:start + child_size]
        # Pass the labelling choice down so nested levels use the same kind of label.
        observation_tree = self.unfold_observation_tree(child_tree,
                                                        current_depth=current_depth + 1,
                                                        actions_for_display=actions_for_display)
        if observation_tree is not None:
            unfolded[labels[child]] = observation_tree
    return unfolded
def _set_env(self, env):
    """Attach the environment to this builder and, if one is set, to its predictor."""
    self.env = env
    if not self.predictor:
        return
    self.predictor._set_env(self.env)
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
def _reverse_dir(self, direction):
    """Return the direction opposite to `direction` (0..3) as a plain int."""
    opposite = (direction + 2) % 4
    return int(opposite)
class GlobalObsForRailEnv(ObservationBuilder):
    """
    Gives a global observation of the entire rail environment.

    get() returns four elements:
        - the transition map, an array of shape (env.height, env.width, 16) holding the
          16 transition bits of every cell;
        - an array of shape (env.height, env.width, 8): the first 4 channels hold the one-hot
          direction of the observing agent at its position, the last 4 channels the one-hot
          direction of every other agent at its position;
        - an array of shape (env.height, env.width, 2): channel 0 marks the target of the
          observing agent, channel 1 counts the targets of the other agents;
        - the value returned by _get_one_hot_for_agent_direction() for the observing agent.
    """

    def __init__(self):
        self.observation_space = ()
        super().__init__()

    def _set_env(self, env):
        super()._set_env(env)
        self.observation_space = [4, self.env.height, self.env.width]

    def reset(self):
        """Rebuild the per-cell 16-bit transition map from the environment's rail grid."""
        height, width = self.env.height, self.env.width
        self.rail_obs = np.zeros((height, width, 16))
        for row in range(height):
            for col in range(width):
                # Binary digits of the cell's transitions, left-padded with zeros to 16 bits.
                bits = bin(self.env.rail.get_full_transitions(row, col))[2:].zfill(16)
                self.rail_obs[row, col] = np.array([int(bit) for bit in bits])

    def get(self, handle):
        """Build the global observation for the agent stored at index `handle`."""
        grid_shape = (self.env.height, self.env.width)
        obs_targets = np.zeros(grid_shape + (2,))
        obs_agents_state = np.zeros(grid_shape + (8,))

        agents = self.env.agents
        agent = agents[handle]

        # Own direction (one-hot) and own target.
        own_direction = np.zeros(4)
        own_direction[agent.direction] = 1
        obs_agents_state[agent.position][:4] = own_direction
        obs_targets[agent.target][0] += 1

        # Directions and targets of all the other agents.
        for index, other in enumerate(agents):
            if index == handle:  # TODO: handle used as index...?
                continue
            obs_agents_state[other.position][4 + other.direction] = 1
            obs_targets[other.target][1] += 1

        direction = self._get_one_hot_for_agent_direction(agent)
        return self.rail_obs, obs_agents_state, obs_targets, direction
class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
    """
    Gives a global observation of the entire rail environment, expressed relative to the
    direction of the observing agent.

    get() returns four elements:
        - the transition map of shape (env.height, env.width, 16), with its 16 channels
          rotated by 4 * direction and the array flipped depending on the agent's direction;
        - an array of shape (env.height, env.width, 5): channel 0 marks the position of the
          observing agent, channels 1-4 hold the direction of every other agent relative to
          the observing agent's direction, at that agent's position;
        - an array of shape (env.height, env.width, 2): channel 0 marks the target of the
          observing agent, channel 1 counts the targets of the other agents;
        - the value returned by _get_one_hot_for_agent_direction() for the observing agent.
    """

    def __init__(self):
        self.observation_space = ()
        super().__init__()

    def _set_env(self, env):
        super()._set_env(env)
        self.observation_space = [4, self.env.height, self.env.width]

    def reset(self):
        """Rebuild the per-cell 16-bit transition map from the environment's rail grid."""
        height, width = self.env.height, self.env.width
        self.rail_obs = np.zeros((height, width, 16))
        for row in range(height):
            for col in range(width):
                # Binary digits of the cell's transitions, left-padded with zeros to 16 bits.
                bits = bin(self.env.rail.get_full_transitions(row, col))[2:].zfill(16)
                self.rail_obs[row, col] = np.array([int(bit) for bit in bits])

    def get(self, handle):
        """Build the direction-dependent global observation for the agent at index `handle`."""
        grid_shape = (self.env.height, self.env.width)
        obs_targets = np.zeros(grid_shape + (2,))
        obs_agents_state = np.zeros(grid_shape + (5,))

        agents = self.env.agents
        agent = agents[handle]
        heading = agent.direction

        # Rotate the 16 transition channels by 4 * heading (cyclic lookup through a doubled index).
        doubled_channels = np.tile(np.arange(16), 2)
        first_channel = heading * 4
        rail_obs = self.rail_obs[:, :, doubled_channels[first_channel: first_channel + 16]]

        if heading == 1:
            rail_obs = np.flip(rail_obs, axis=1)
        elif heading == 2:
            # NOTE(review): no axis is given here, so every axis is reversed - confirm intended.
            rail_obs = np.flip(rail_obs)
        elif heading == 3:
            rail_obs = np.flip(rail_obs, axis=0)

        obs_agents_state[agent.position][0] = 1
        obs_targets[agent.target][0] += 1

        # Cyclic lookup used to express the other agents' directions relative to `heading`.
        doubled_directions = np.tile(np.arange(4), 2)
        for index, other in enumerate(agents):
            if index == handle:  # TODO: handle used as index...?
                continue
            relative_direction = doubled_directions[4 + (other.direction - heading)]
            obs_agents_state[other.position][1 + relative_direction] = 1
            obs_targets[other.target][1] += 1

        direction = self._get_one_hot_for_agent_direction(agent)
        return rail_obs, obs_agents_state, obs_targets, direction
class LocalObsForRailEnv(ObservationBuilder):
    """
    Gives a local observation of the rail environment around the agent.

    get() returns four elements:
        - the transition map of the cells in view, of shape (view_height, 2 * view_width + 1, 16),
          assuming 16 bits encoding of transitions;
        - an array of shape (view_height, 2 * view_width + 1, 2): channel 0 marks the agent's own
          target, channel 1 the targets of other agents, when they are in view;
        - an array of shape (view_height, 2 * view_width + 1, 4) with the one-hot direction of the
          other agents that are in view, at their position;
        - a 4 elements array with the one-hot encoding of the agent's direction.

    Use the parameters view_width and view_height to define the rectangular view of the agent.
    The center parameter moves the agent along the height axis of this rectangle. If it is 0 the
    agent only has observation in front of it.
    """

    def __init__(self, view_width, view_height, center):
        super(LocalObsForRailEnv, self).__init__()
        self.view_width = view_width
        self.view_height = view_height
        self.center = center
        self.max_padding = max(self.view_width, self.view_height - self.center)

    def reset(self):
        """Rebuild the per-cell 16-bit transition map from the environment's rail grid."""
        self.max_padding = max(self.view_width, self.view_height)
        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
        for row in range(self.env.height):
            for col in range(self.env.width):
                # Binary digits of the cell's transitions, left-padded with zeros to 16 bits.
                bits = bin(self.env.rail.get_full_transitions(row, col))[2:].zfill(16)
                self.rail_obs[row, col] = np.array([int(bit) for bit in bits])

    def get(self, handle):
        """
        Build the local observation for the agent stored at index `handle`.

        The cells in view are also stored as a set in self.env.dev_obs_dict[handle].
        """
        agents = self.env.agents
        agent = agents[handle]

        # Cells in view (grid coordinates) and their coordinates inside the view rectangle.
        visited, rel_coords = self.field_of_view(agent.position, agent.direction)

        # Add the visible cells to the observed cells
        self.env.dev_obs_dict[handle] = set(visited)

        # Locate observed agents and their corresponding targets
        view_shape = (self.view_height, 2 * self.view_width + 1)
        local_rail_obs = np.zeros(view_shape + (16,))
        obs_map_state = np.zeros(view_shape + (2,))
        obs_other_agents_state = np.zeros(view_shape + (4,))

        for pos, (rel_row, rel_col) in zip(visited, rel_coords):
            local_rail_obs[rel_row, rel_col, :] = self.rail_obs[pos[0], pos[1], :]

            if pos == agent.target:
                obs_map_state[rel_row, rel_col, 0] = 1
            elif any(pos == other.target for other in agents):
                obs_map_state[rel_row, rel_col, 1] = 1

            if pos != agent.position:
                for other in agents:
                    if pos == other.position:
                        obs_other_agents_state[rel_row, rel_col, :] = np.identity(4)[other.direction]

        direction = np.identity(4)[agent.direction]
        return local_rail_obs, obs_map_state, obs_other_agents_state, direction

    def get_many(self, handles=None):
        """
        Called whenever an observation has to be computed for the `env' environment, for each agent with handle
        in the `handles' list. A missing `handles` argument is treated as an empty list.
        """
        if handles is None:
            # The default used to be iterated directly, which raised a TypeError.
            handles = []
        return {h: self.get(h) for h in handles}

    def field_of_view(self, position, direction, state=None):
        """
        Compute the local field of view for an agent at `position` heading in `direction`.

        Returns (visible, rel_coords): the grid cells of the view rectangle that lie inside
        the environment, and for each of them its (row, column) inside the rectangle.

        When `state` is given, an all-zero array of shape
        (view_height, 2 * view_width + 1, 16) is returned instead.
        """
        collect_data = state is not None
        if collect_data:
            # NOTE(review): nothing is copied from `state` into this array - the copying code
            # was commented out in the original - so it stays all zeros. Confirm intended.
            temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16))

        # Corner of the view rectangle from which the cells are enumerated.
        if direction == 0:
            origin = (position[0] + self.center, position[1] - self.view_width)
        elif direction == 1:
            origin = (position[0] - self.view_width, position[1] - self.center)
        elif direction == 2:
            origin = (position[0] - self.center, position[1] + self.view_width)
        else:
            origin = (position[0] + self.view_width, position[1] + self.center)

        visible = []
        rel_coords = []
        for h in range(self.view_height):
            for w in range(2 * self.view_width + 1):
                # Map the view coordinates (h, w) to grid coordinates for the current heading.
                if direction == 0:
                    cell = (origin[0] - h, origin[1] + w)
                elif direction == 1:
                    cell = (origin[0] + w, origin[1] + h)
                elif direction == 2:
                    cell = (origin[0] + h, origin[1] - w)
                else:
                    cell = (origin[0] - w, origin[1] - h)

                # Keep only the cells that lie inside the environment.
                if 0 <= cell[0] < self.env.height and 0 <= cell[1] < self.env.width:
                    visible.append(cell)
                    rel_coords.append((h, w))

        if collect_data:
            return temp_visible_data
        return visible, rel_coords