From f310811b344ba2c7cd3e5321d39ea87232702f39 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 17 Nov 2020 13:08:07 +0100 Subject: [PATCH] ShortestPathWalkerHeuristicPolicy --- utils/shortest_path_walker_heuristic_agent.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 utils/shortest_path_walker_heuristic_agent.py diff --git a/utils/shortest_path_walker_heuristic_agent.py b/utils/shortest_path_walker_heuristic_agent.py new file mode 100644 index 0000000..eaa71e9 --- /dev/null +++ b/utils/shortest_path_walker_heuristic_agent.py @@ -0,0 +1,57 @@ +import numpy as np +from flatland.envs.rail_env import RailEnvActions + +from reinforcement_learning.policy import Policy + + +class ShortestPathWalkerHeuristicPolicy(Policy): + def step(self, state, action, reward, next_state, done): + pass + + def act(self, node, eps=0.): + + left_node = node.childs.get('L') + forward_node = node.childs.get('F') + right_node = node.childs.get('R') + + dist_map = np.zeros(5) + dist_map[RailEnvActions.DO_NOTHING] = np.inf + dist_map[RailEnvActions.STOP_MOVING] = 100000 + # left + if left_node == -np.inf: + dist_map[RailEnvActions.MOVE_LEFT] = np.inf + else: + if left_node.num_agents_opposite_direction == 0: + dist_map[RailEnvActions.MOVE_LEFT] = left_node.dist_min_to_target + else: + dist_map[RailEnvActions.MOVE_LEFT] = np.inf + # forward + if forward_node == -np.inf: + dist_map[RailEnvActions.MOVE_FORWARD] = np.inf + else: + if forward_node.num_agents_opposite_direction == 0: + dist_map[RailEnvActions.MOVE_FORWARD] = forward_node.dist_min_to_target + else: + dist_map[RailEnvActions.MOVE_FORWARD] = np.inf + # right + if right_node == -np.inf: + dist_map[RailEnvActions.MOVE_RIGHT] = np.inf + else: + if right_node.num_agents_opposite_direction == 0: + dist_map[RailEnvActions.MOVE_RIGHT] = right_node.dist_min_to_target + else: + dist_map[RailEnvActions.MOVE_RIGHT] = np.inf + return np.argmin(dist_map) + + def save(self, filename): + pass + + def load(self, filename): + pass + + +policy = ShortestPathWalkerHeuristicPolicy() + + +def normalize_observation(observation, tree_depth: int, observation_radius=0): + return observation -- GitLab