Skip to content
Snippets Groups Projects
Commit b09e71e7 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files
parents 8417e530 91ff7d01
No related branches found
No related tags found
No related merge requests found
...@@ -69,6 +69,31 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p ...@@ -69,6 +69,31 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
return rail_trans.is_valid(new_trans) return rail_trans.is_valid(new_trans)
def position_to_coordinate(width, position):
"""
:param width:
:param position:
:return:
"""
coords = ()
for p in position:
coords = coords + ((int(p) % width, int(p) // width),) # changed x_dim to y_dim
return coords
def coordinate_to_position(width, coords):
"""
:param width:
:param coords:
:return:
"""
position = []
for t in coords:
position.append((t[1] * width + t[0]))
return np.array(position)
class AStarNode(): class AStarNode():
"""A node class for A* Pathfinding""" """A node class for A* Pathfinding"""
......
...@@ -173,10 +173,13 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -173,10 +173,13 @@ class TreeObsForRailEnv(ObservationBuilder):
in the `handles' list. in the `handles' list.
""" """
self.predictions = []
if self.predictor: if self.predictor:
for a in range(len(handles)): self.predictions = self.predictor.get()
self.predictions.append(self.predictor.get(a)) pred_pos = np.concatenate([[x[:, 1:3]] for x in list(self.predictions.values())], axis=0)
pred_pos = list(map(list, zip(*pred_pos)))
pred_dir = [x[:, 2] for x in list(self.predictions.values())]
observations = {} observations = {}
for h in handles: for h in handles:
observations[h] = self.get(h) observations[h] = self.get(h)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment