diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..d6fd6f21b4f8d95e972cf75f62bebcdc4537a139
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,2 @@
+# Default ignored files
+/workspace.xml
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..20fc29e7a85e1af4cdab6fc32d2197bfd9cc1d27
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000000000000000000000000000000000000..5417b684e7619b0b44d2f9d8be364ac8fb576783
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6" project-jdk-type="Python SDK" />
+</project>
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000000000000000000000000000000000000..c9b6e2db97b91b3f14bb3798b10f26f334334bb6
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/neurips2020-flatland-starter-kit.iml" filepath="$PROJECT_DIR$/.idea/neurips2020-flatland-starter-kit.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/neurips2020-flatland-starter-kit.iml b/.idea/neurips2020-flatland-starter-kit.iml
new file mode 100644
index 0000000000000000000000000000000000000000..951c9286734f053453803b04b5335bb575715344
--- /dev/null
+++ b/.idea/neurips2020-flatland-starter-kit.iml
@@ -0,0 +1,11 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="TestRunnerService">
+    <option name="PROJECT_TEST_RUNNER" value="pytest" />
+  </component>
+</module>
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000000000000000000000000000000000000..9661ac713428efbad557d3ba3a62216b5bb7d226
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/checkpoints/ppo/README.md b/checkpoints/ppo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fce8cd3428ab5c13ea082dd922054080beae4822
--- /dev/null
+++ b/checkpoints/ppo/README.md
@@ -0,0 +1 @@
+PPO checkpoints will be saved here
diff --git a/checkpoints/ppo/model_checkpoint.meta b/checkpoints/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..86356a4182e2dfe3f589ef620f88beca40694952
Binary files /dev/null and b/checkpoints/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/ppo/model_checkpoint.optimizer b/checkpoints/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..bc423f8c53440c69332bef7afc73b23fe4908c61
Binary files /dev/null and b/checkpoints/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/ppo/model_checkpoint.policy b/checkpoints/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..a739edb31c3fbf01d2658a714222d2973246416c
Binary files /dev/null and b/checkpoints/ppo/model_checkpoint.policy differ
diff --git a/docker_run.sh b/docker_run.sh
index eeec29823b7603c361b65c0752f62a3b328d31c9..f14996e5e254c6d266e2f4d0bb47033b8547aaec 100755
--- a/docker_run.sh
+++ b/docker_run.sh
@@ -1,18 +1,18 @@
-#!/bin/bash
-
-
-if [ -e environ_secret.sh ]
-then
-    echo "Note: Gathering environment variables from environ_secret.sh"
-    source environ_secret.sh
-else
-    echo "Note: Gathering environment variables from environ.sh"
-    source environ.sh
-fi
-
-# Expected Env variables : in environ.sh
-sudo docker run \
-    --net=host \
-    -v ./scratch/test-envs:/flatland_envs:z \
-    -it ${IMAGE_NAME}:${IMAGE_TAG} \
-    /home/aicrowd/run.sh
+#!/bin/bash
+
+
+if [ -e environ_secret.sh ]
+then
+    echo "Note: Gathering environment variables from environ_secret.sh"
+    source environ_secret.sh
+else
+    echo "Note: Gathering environment variables from environ.sh"
+    source environ.sh
+fi
+
+# Expected Env variables : in environ.sh
+sudo docker run \
+    --net=host \
+    -v ./scratch/test-envs:/flatland_envs:z \
+    -it ${IMAGE_NAME}:${IMAGE_TAG} \
+    /home/aicrowd/run.sh
diff --git a/dump.rdb b/dump.rdb
new file mode 100644
index 0000000000000000000000000000000000000000..d719ed7cce7a692fb2775ab881f5020877164240
Binary files /dev/null and b/dump.rdb differ
diff --git a/nets/training_5500.pth.local b/nets/training_5500.pth.local
new file mode 100644
index 0000000000000000000000000000000000000000..e77bb01ae7aef933e59799528c325a412b520290
Binary files /dev/null and b/nets/training_5500.pth.local differ
diff --git a/nets/training_5500.pth.target b/nets/training_5500.pth.target
new file mode 100644
index 0000000000000000000000000000000000000000..a086633e9a959abb573f4231dc76d9a8123f4861
Binary files /dev/null and b/nets/training_5500.pth.target differ
diff --git a/nets/training_best_0.626_agents_5276.pth.local b/nets/training_best_0.626_agents_5276.pth.local
new file mode 100644
index 0000000000000000000000000000000000000000..0a080fcc30deae97f34610670bce980599e648d6
Binary files /dev/null and b/nets/training_best_0.626_agents_5276.pth.local differ
diff --git a/nets/training_best_0.626_agents_5276.pth.target b/nets/training_best_0.626_agents_5276.pth.target
new file mode 100644
index 0000000000000000000000000000000000000000..58f0d0bcaca15481c42523a21cda75cb226a782e
Binary files /dev/null and b/nets/training_best_0.626_agents_5276.pth.target differ
diff --git a/reinforcement_learning/multi_policy.py b/reinforcement_learning/multi_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..765bcf599f681f8b7a3dca311f223c5eed85e42d
--- /dev/null
+++ b/reinforcement_learning/multi_policy.py
@@ -0,0 +1,64 @@
+import numpy as np
+from flatland.envs.rail_env import RailEnvActions
+
+from reinforcement_learning.policy import Policy
+from reinforcement_learning.ppo.ppo_agent import PPOAgent
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
+from utils.extra import ExtraPolicy
+
+
+class MultiPolicy(Policy):
+    def __init__(self, state_size, action_size, n_agents, env):
+        self.state_size = state_size
+        self.action_size = action_size
+        self.memory = []
+        self.loss = 0
+        self.extra_policy = ExtraPolicy(state_size, action_size)
+        self.ppo_policy = PPOAgent(state_size + action_size, action_size, n_agents, env)  # state is extended by a one-hot of the extra policy's action
+
+    def load(self, filename):
+        self.ppo_policy.load(filename)
+        self.extra_policy.load(filename)
+
+    def save(self, filename):
+        self.ppo_policy.save(filename)
+        self.extra_policy.save(filename)
+
+    def step(self, handle, state, action, reward, next_state, done):
+        action_extra_state = self.extra_policy.act(handle, state, 0.0)
+        action_extra_next_state = self.extra_policy.act(handle, next_state, 0.0)
+
+        extended_state = np.copy(state)
+        for action_itr in np.arange(self.action_size):
+            extended_state = np.append(extended_state, [int(action_extra_state == action_itr)])
+        extended_next_state = np.copy(next_state)
+        for action_itr in np.arange(self.action_size):
+            extended_next_state = np.append(extended_next_state, [int(action_extra_next_state == action_itr)])
+
+        self.extra_policy.step(handle, state, action, reward, next_state, done)
+        self.ppo_policy.step(handle, extended_state, action, reward, extended_next_state, done)
+
+    def act(self, handle, state, eps=0.):
+        action_extra_state = self.extra_policy.act(handle, state, 0.0)
+        extended_state = np.copy(state)
+        for action_itr in np.arange(self.action_size):
+            extended_state = np.append(extended_state, [int(action_extra_state == action_itr)])
+        action_ppo = self.ppo_policy.act(handle, extended_state, eps)
+        self.loss = self.ppo_policy.loss
+        return action_ppo
+
+    def reset(self):
+        self.ppo_policy.reset()
+        self.extra_policy.reset()
+
+    def test(self):
+        self.ppo_policy.test()
+        self.extra_policy.test()
+
+    def start_step(self):
+        self.extra_policy.start_step()
+        self.ppo_policy.start_step()
+
+    def end_step(self):
+        self.extra_policy.end_step()
+        self.ppo_policy.end_step()
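MultiPolicy extends the PPO input by appending a one-hot encoding of the extra policy's action, one element at a time via np.append. A vectorized equivalent of that pattern, as a sketch only (extend_state is a hypothetical helper name, not part of the patch):

    import numpy as np

    def extend_state(state, action, action_size):
        # concatenate the raw state with a one-hot encoding of the given action
        one_hot = np.zeros(action_size)
        one_hot[int(action)] = 1.0
        return np.concatenate([state, one_hot])

np.concatenate builds the extended vector in one allocation instead of reallocating on every np.append call.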
diff --git a/reinforcement_learning/sequential_agent_training.py b/reinforcement_learning/sequential_agent_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca19d1fcbbb4e3508a16b847d4b4cfcefc6aad98
--- /dev/null
+++ b/reinforcement_learning/sequential_agent_training.py
@@ -0,0 +1,78 @@
+import sys
+import numpy as np
+
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
+from flatland.utils.rendertools import RenderTool
+from pathlib import Path
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from reinforcement_learning.ordered_policy import OrderedPolicy
+
+np.random.seed(2)
+
+x_dim = 20  # np.random.randint(8, 20)
+y_dim = 20  # np.random.randint(8, 20)
+n_agents = 10  # np.random.randint(3, 8)
+n_goals = n_agents + np.random.randint(0, 3)
+min_dist = int(0.75 * min(x_dim, y_dim))
+
+env = RailEnv(width=x_dim,
+              height=y_dim,
+              rail_generator=complex_rail_generator(
+                  nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
+                  max_dist=99999,
+                  seed=0
+              ),
+              schedule_generator=complex_schedule_generator(),
+              obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()),
+              number_of_agents=n_agents)
+env.reset(True, True)
+
+tree_depth = 1
+observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
+env_renderer = RenderTool(env, gl="PGL", )
+handle = env.get_agent_handles()
+n_episodes = 1
+max_steps = 100 * (env.height + env.width)
+record_images = False
+policy = OrderedPolicy()
+action_dict = dict()
+
+for trials in range(1, n_episodes + 1):
+
+    # Reset environment
+    obs, info = env.reset(True, True)
+    done = env.dones
+    env_renderer.reset()
+    frame_step = 0
+
+    # Run episode
+    for step in range(max_steps):
+        env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
+
+        if record_images:
+            env_renderer.gl.save_image("./Images/flatland_frame_{:04d}.bmp".format(frame_step))
+            frame_step += 1
+
+        # Action
+        acting_agent = 0  # only the first agent that is not yet done takes a real action
+        for a in range(env.get_num_agents()):
+            if done[a]:
+                acting_agent += 1
+            if a == acting_agent:
+                action = policy.act(obs[a])
+            else:
+                action = 4  # RailEnvActions.STOP_MOVING
+            action_dict.update({a: action})
+
+        # Environment step
+        obs, all_rewards, done, _ = env.step(action_dict)
+
+        if done['__all__']:
+            break
diff --git a/run.sh b/run.sh
index 953c1660c6abafcc0a474c526ef7ffcedef6a5d8..5ead27c555cb94a1aff3d199fe1ddc55147fbd7e 100755
--- a/run.sh
+++ b/run.sh
@@ -1,3 +1,5 @@
 #!/bin/bash
 
+# manually install submodules.
+
 python ./run.py
diff --git a/run_fast_methods.py b/run_fast_methods.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbb5ea10f31983ce6aa0210a9cb3d531c4142659
--- /dev/null
+++ b/run_fast_methods.py
@@ -0,0 +1,26 @@
+from time import time
+
+import numpy as np
+from flatland.envs.rail_env import fast_isclose
+
+
+def print_timing(label, start_time, end_time):
+    print("{:>10.4f}ms".format(1000 * (end_time - start_time)) + "\t" + label)
+
+
+def check_isclose(nbr=100000):
+    s = time()
+    for x in range(nbr):
+        fast_isclose(x, 0.0, rtol=1e-03)
+    e = time()
+    print_timing("fast_isclose", start_time=s, end_time=e)
+
+    s = time()
+    for x in range(nbr):
+        np.isclose(x, 0.0, rtol=1e-03)
+    e = time()
+    print_timing("np.isclose", start_time=s, end_time=e)
+
+
+if __name__ == "__main__":
+    check_isclose()
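check_isclose above times the two comparison functions with a hand-written loop; the same measurement can be done with the standard library's timeit, shown here as a sketch only (numbers vary by machine and flatland version):

    import timeit

    n = 100000
    t_fast = timeit.timeit("fast_isclose(1.0, 0.0, rtol=1e-03)",
                           setup="from flatland.envs.rail_env import fast_isclose", number=n)
    t_np = timeit.timeit("np.isclose(1.0, 0.0, rtol=1e-03)",
                         setup="import numpy as np", number=n)
    print("fast_isclose: {:.4f}s, np.isclose: {:.4f}s".format(t_fast, t_np))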
diff --git a/utils/extra.py b/utils/extra.py
new file mode 100644
index 0000000000000000000000000000000000000000..89ed0bb9ea2b7a993fe9eecd9332f0910294cbce
--- /dev/null
+++ b/utils/extra.py
@@ -0,0 +1,363 @@
+import numpy as np
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import RailEnvActions, fast_argmax, fast_count_nonzero
+
+from reinforcement_learning.policy import Policy
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent, DeadlockAvoidanceShortestDistanceWalker
+
+
+class ExtraPolicy(Policy):
+    def __init__(self, state_size, action_size):
+        self.state_size = state_size
+        self.action_size = action_size
+        self.memory = []
+        self.loss = 0
+
+    def load(self, filename):
+        pass
+
+    def save(self, filename):
+        pass
+
+    def step(self, handle, state, action, reward, next_state, done):
+        pass
+
+    def act(self, handle, state, eps=0.):
+        a = 0
+        b = 4
+        action = RailEnvActions.STOP_MOVING
+        if state[2] == 1 and state[10 + a] == 0:
+            action = RailEnvActions.MOVE_LEFT
+        elif state[3] == 1 and state[11 + a] == 0:
+            action = RailEnvActions.MOVE_FORWARD
+        elif state[4] == 1 and state[12 + a] == 0:
+            action = RailEnvActions.MOVE_RIGHT
+        elif state[5] == 1 and state[13 + a] == 0:
+            action = RailEnvActions.MOVE_FORWARD
+
+        elif state[6] == 1 and state[10 + b] == 0:
+            action = RailEnvActions.MOVE_LEFT
+        elif state[7] == 1 and state[11 + b] == 0:
+            action = RailEnvActions.MOVE_FORWARD
+        elif state[8] == 1 and state[12 + b] == 0:
+            action = RailEnvActions.MOVE_RIGHT
+        elif state[9] == 1 and state[13 + b] == 0:
+            action = RailEnvActions.MOVE_FORWARD
+
+        return action
+
+    def test(self):
+        pass
+
+
+class Extra(ObservationBuilder):
+
+    def __init__(self, max_depth):
+        self.max_depth = max_depth
+        self.observation_dim = 31
+
+    def build_data(self):
+        self.dead_lock_avoidance_agent = None
+        if self.env is not None:
+            self.env.dev_obs_dict = {}
+            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, None, None)
+
+        self.switches = {}
+        self.switches_neighbours = {}
+        self.debug_render_list = []
+        self.debug_render_path_list = []
+        if self.env is not None:
+            self.find_all_cell_where_agent_can_choose()
+
+    def find_all_cell_where_agent_can_choose(self):
+
+        switches = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                pos = (h, w)
+                for dir in range(4):
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    num_transitions = fast_count_nonzero(possible_transitions)
+                    if num_transitions > 1:
+                        if pos not in switches.keys():
+                            switches.update({pos: [dir]})
+                        else:
+                            switches[pos].append(dir)
+
+        switches_neighbours = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                # look one step forward
+                for dir in range(4):
+                    pos = (h, w)
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    for d in range(4):
+                        if possible_transitions[d] == 1:
+                            new_cell = get_new_position(pos, d)
+                            if new_cell in switches.keys() and pos not in switches.keys():
+                                if pos not in switches_neighbours.keys():
+                                    switches_neighbours.update({pos: [dir]})
+                                else:
+                                    switches_neighbours[pos].append(dir)
+
+        self.switches = switches
+        self.switches_neighbours = switches_neighbours
+
+    def check_agent_descision(self, position, direction):
+        switches = self.switches
+        switches_neighbours = self.switches_neighbours
+        agents_on_switch = False
+        agents_on_switch_all = False
+        agents_near_to_switch = False
+        agents_near_to_switch_all = False
+        if position in switches.keys():
+            agents_on_switch = direction in switches[position]
+            agents_on_switch_all = True
+
+        if position in switches_neighbours.keys():
+            new_cell = get_new_position(position, direction)
+            if new_cell in switches.keys():
+                if direction not in switches[new_cell]:
+                    agents_near_to_switch = direction in switches_neighbours[position]
+            else:
+                agents_near_to_switch = direction in switches_neighbours[position]
+
+            agents_near_to_switch_all = direction in switches_neighbours[position]
+
+        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
+
+    def required_agent_descision(self):
+        agents_can_choose = {}
+        agents_on_switch = {}
+        agents_on_switch_all = {}
+        agents_near_to_switch = {}
+        agents_near_to_switch_all = {}
+        for a in range(self.env.get_num_agents()):
+            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
+                self.check_agent_descision(
+                    self.env.agents[a].position,
+                    self.env.agents[a].direction)
+            agents_on_switch.update({a: ret_agents_on_switch})
+            agents_on_switch_all.update({a: ret_agents_on_switch_all})
+            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
+            agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
+
+            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
+
+            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
+
+        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
+
+    def debug_render(self, env_renderer):
+        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all = \
+            self.required_agent_descision()
+        self.env.dev_obs_dict = {}
+        for a in range(max(3, self.env.get_num_agents())):
+            self.env.dev_obs_dict.update({a: []})
+
+        selected_agent = None
+        if agents_can_choose[0]:
+            if self.env.agents[0].position is not None:
+                self.debug_render_list.append(self.env.agents[0].position)
+            else:
+                self.debug_render_list.append(self.env.agents[0].initial_position)
+
+        if self.env.agents[0].position is not None:
+            self.debug_render_path_list.append(self.env.agents[0].position)
+        else:
+            self.debug_render_path_list.append(self.env.agents[0].initial_position)
+
+        env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000")
+        env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600")
+        env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666")
+        env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")
+
+        self.env.dev_obs_dict[0] = self.debug_render_list
+        self.env.dev_obs_dict[1] = self.switches.keys()
+        self.env.dev_obs_dict[2] = self.switches_neighbours.keys()
+        self.env.dev_obs_dict[3] = self.debug_render_path_list
+
+    def reset(self):
+        self.build_data()
+        return
+
+    def _check_dead_lock_at_branching_position(self, handle, new_position, branch_direction):
+        _, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
+        opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
+        local_walker = DeadlockAvoidanceShortestDistanceWalker(
+            self.env,
+            self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,
+            self.dead_lock_avoidance_agent.shortest_distance_walker.switches)
+        local_walker.walk_to_target(handle, new_position, branch_direction)
+        shortest_distance_agent_map, _ = local_walker.getData()  # read the branch walk just computed, not the stale global walker
+        my_shortest_path_to_check = shortest_distance_agent_map[handle]
+        next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
+                                                                           opp_agents,
+                                                                           full_shortest_distance_agent_map)
+        return next_step_ok
+
+    def _explore(self, handle, new_position, new_direction, depth=0):
+
+        has_opp_agent = 0
+        has_same_agent = 0
+        has_switch = 0
+        visited = []
+
+        # stop exploring (max_depth reached)
+        if depth >= self.max_depth:
+            return has_opp_agent, has_same_agent, has_switch, visited
+
+        # max_explore_steps = 100
+        cnt = 0
+        while cnt < 100:
+            cnt += 1
+
+            visited.append(new_position)
+            opp_a = self.env.agent_positions[new_position]
+            if opp_a != -1 and opp_a != handle:
+                if self.env.agents[opp_a].direction != new_direction:
+                    # opp agent found
+                    has_opp_agent = 1
+                    return has_opp_agent, has_same_agent, has_switch, visited
+                else:
+                    has_same_agent = 1
+                    return has_opp_agent, has_same_agent, has_switch, visited
+
+            # check whether the current cell is a decision point (switch or switch neighbour)
+            agents_on_switch, \
+            agents_near_to_switch, \
+            agents_near_to_switch_all, \
+            agents_on_switch_all = \
+                self.check_agent_descision(new_position, new_direction)
+            if agents_near_to_switch:
+                return has_opp_agent, has_same_agent, has_switch, visited
+
+            possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
+            if agents_on_switch:
+                f = 0
+                for dir_loop in range(4):
+                    if possible_transitions[dir_loop] == 1:
+                        f += 1
+                        hoa, hsa, hs, v = self._explore(handle,
+                                                        get_new_position(new_position, dir_loop),
+                                                        dir_loop,
+                                                        depth + 1)
+                        visited.append(v)
+                        has_opp_agent += hoa
+                        has_same_agent += hsa
+                        has_switch += hs
+                f = max(f, 1.0)
+                return has_opp_agent / f, has_same_agent / f, has_switch / f, visited
+            else:
+                new_direction = fast_argmax(possible_transitions)
+                new_position = get_new_position(new_position, new_direction)
+
+        return has_opp_agent, has_same_agent, has_switch, visited
+
+    def get(self, handle):
+
+        if handle == 0:
+            self.dead_lock_avoidance_agent.start_step()
+
+        # all values are [0,1]
+        # observation[0]  : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
+        # observation[1]  : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
+        # observation[2]  : 1 path towards target (direction 2) / otherwise 0 -> path is longer or there is no path
+        # observation[3]  : 1 path towards target (direction 3) / otherwise 0 -> path is longer or there is no path
+        # observation[4]  : int(agent.status == RailAgentStatus.READY_TO_DEPART)
+        # observation[5]  : int(agent.status == RailAgentStatus.ACTIVE)
+        # observation[6]  : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED)
+        # observation[7]  : current agent is located at a switch, where it can take a routing decision
+        # observation[8]  : current agent is located at a cell, where it has to take a stop-or-go decision
+        # observation[9]  : current agent is located one step before/after a switch
+        # observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0)
+        # observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1)
+        # observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2)
+        # observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3)
+        # observation[14] : If there is a path with step (direction 0) and there is an agent with opposite direction -> 1
+        # observation[15] : If there is a path with step (direction 1) and there is an agent with opposite direction -> 1
+        # observation[16] : If there is a path with step (direction 2) and there is an agent with opposite direction -> 1
+        # observation[17] : If there is a path with step (direction 3) and there is an agent with opposite direction -> 1
+        # observation[18] : If there is a path with step (direction 0) and there is an agent with same direction -> 1
+        # observation[19] : If there is a path with step (direction 1) and there is an agent with same direction -> 1
+        # observation[20] : If there is a path with step (direction 2) and there is an agent with same direction -> 1
+        # observation[21] : If there is a path with step (direction 3) and there is an agent with same direction -> 1
+        # observation[22] : If there is a switch on the path which the agent can not use (direction 0) -> 1
+        # observation[23] : If there is a switch on the path which the agent can not use (direction 1) -> 1
+        # observation[24] : If there is a switch on the path which the agent can not use (direction 2) -> 1
+        # observation[25] : If there is a switch on the path which the agent can not use (direction 3) -> 1
+        # observation[26] : 1 if the shortest path walk (direction 0) shows no deadlock signal (next step is ok)
+        # observation[27] : 1 if the shortest path walk (direction 1) shows no deadlock signal (next step is ok)
+        # observation[28] : 1 if the shortest path walk (direction 2) shows no deadlock signal (next step is ok)
+        # observation[29] : 1 if the shortest path walk (direction 3) shows no deadlock signal (next step is ok)
+        # observation[30] : Is there a deadlock signal at the current position -> 1
+
+        observation = np.zeros(self.observation_dim)
+        visited = []
+        agent = self.env.agents[handle]
+
+        agent_done = False
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+            observation[4] = 1
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+            observation[5] = 1
+        else:
+            observation[6] = 1
+            agent_virtual_position = (-1, -1)
+            agent_done = True
+
+        if not agent_done:
+            visited.append(agent_virtual_position)
+            distance_map = self.env.distance_map.get()
+            current_cell_dist = distance_map[handle,
+                                             agent_virtual_position[0], agent_virtual_position[1],
+                                             agent.direction]
+            possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
+            orientation = agent.direction
+            if fast_count_nonzero(possible_transitions) == 1:
+                orientation = fast_argmax(possible_transitions)
+
+            for dir_loop, branch_direction in enumerate([(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]):
+                if possible_transitions[branch_direction]:
+                    new_position = get_new_position(agent_virtual_position, branch_direction)
+                    new_cell_dist = distance_map[handle,
+                                                 new_position[0], new_position[1],
+                                                 branch_direction]
+                    if not (np.isinf(new_cell_dist) and np.isinf(current_cell_dist)):
+                        observation[dir_loop] = int(new_cell_dist < current_cell_dist)
+
+                    has_opp_agent, has_same_agent, has_switch, v = self._explore(handle, new_position, branch_direction)
+                    visited.append(v)
+
+                    observation[10 + dir_loop] = 1
+                    observation[14 + dir_loop] = has_opp_agent
+                    observation[18 + dir_loop] = has_same_agent
+                    observation[22 + dir_loop] = has_switch
+
+                    next_step_ok = self._check_dead_lock_at_branching_position(handle, new_position, branch_direction)
+                    if next_step_ok:
+                        observation[26 + dir_loop] = 1
+
+        agents_on_switch, \
+        agents_near_to_switch, \
+        agents_near_to_switch_all, \
+        agents_on_switch_all = \
+            self.check_agent_descision(agent_virtual_position, agent.direction)
+        observation[7] = int(agents_on_switch)
+        observation[8] = int(agents_near_to_switch)
+        observation[9] = int(agents_near_to_switch_all)
+
+        observation[30] = int(self.dead_lock_avoidance_agent.act(handle, None, 0) == RailEnvActions.STOP_MOVING)
+
+        self.env.dev_obs_dict.update({handle: visited})
+
+        return observation
+
+    @staticmethod
+    def agent_can_choose(observation):
+        return observation[7] == 1 or observation[8] == 1
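Extra is a drop-in ObservationBuilder for a RailEnv. A minimal sketch of wiring it into an environment; the grid size and generator parameters below are placeholder values, not taken from this patch:

    from flatland.envs.rail_env import RailEnv
    from flatland.envs.rail_generators import complex_rail_generator
    from flatland.envs.schedule_generators import complex_schedule_generator
    from utils.extra import Extra

    obs_builder = Extra(max_depth=2)
    env = RailEnv(width=30, height=30,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2, min_dist=8, seed=0),
                  schedule_generator=complex_schedule_generator(),
                  obs_builder_object=obs_builder,
                  number_of_agents=5)
    obs, info = env.reset(True, True)
    # obs[handle] is the 31-dimensional vector documented in Extra.get()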
diff --git a/utils/shortest_Distance_walker.py b/utils/shortest_Distance_walker.py
new file mode 100644
index 0000000000000000000000000000000000000000..62b686fff0f61a13c48220565553d7e63067739a
--- /dev/null
+++ b/utils/shortest_Distance_walker.py
@@ -0,0 +1,87 @@
+import numpy as np
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
+
+
+class ShortestDistanceWalker:
+    def __init__(self, env: RailEnv):
+        self.env = env
+
+    def walk(self, handle, position, direction):
+        possible_transitions = self.env.rail.get_transitions(*position, direction)
+        num_transitions = fast_count_nonzero(possible_transitions)
+        if num_transitions == 1:
+            new_direction = fast_argmax(possible_transitions)
+            new_position = get_new_position(position, new_direction)
+
+            dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]
+            return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD, possible_transitions
+        else:
+            min_distances = []
+            positions = []
+            directions = []
+            for new_direction in [(direction + i) % 4 for i in range(-1, 2)]:
+                if possible_transitions[new_direction]:
+                    new_position = get_new_position(position, new_direction)
+                    min_distances.append(
+                        self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction])
+                    positions.append(new_position)
+                    directions.append(new_direction)
+                else:
+                    min_distances.append(np.inf)
+                    positions.append(None)
+                    directions.append(None)
+
+            a = self.get_action(handle, min_distances)
+            return positions[a], directions[a], min_distances[a], a + 1, possible_transitions  # index 0/1/2 maps to MOVE_LEFT/MOVE_FORWARD/MOVE_RIGHT
+
+    def get_action(self, handle, min_distances):
+        return np.argmin(min_distances)
+
+    def callback(self, handle, agent, position, direction, action, possible_transitions):
+        pass
+
+    def get_agent_position_and_direction(self, handle):
+        agent = self.env.agents[handle]
+        if agent.position is not None:
+            position = agent.position
+        else:
+            position = agent.initial_position
+        direction = agent.direction
+        return position, direction
+
+    def walk_to_target(self, handle, position=None, direction=None, max_step=500):
+        if position is None and direction is None:
+            position, direction = self.get_agent_position_and_direction(handle)
+        elif position is None:
+            position, _ = self.get_agent_position_and_direction(handle)
+        elif direction is None:
+            _, direction = self.get_agent_position_and_direction(handle)
+
+        agent = self.env.agents[handle]
+        step = 0
+        while (position != agent.target) and (step < max_step):
+            position, direction, dist, action, possible_transitions = self.walk(handle, position, direction)
+            if position is None:
+                break
+            self.callback(handle, agent, position, direction, action, possible_transitions)
+            step += 1
+
+    def callback_one_step(self, handle, agent, position, direction, action, possible_transitions):
+        pass
+
+    def walk_one_step(self, handle):
+        agent = self.env.agents[handle]
+        if agent.position is not None:
+            position = agent.position
+        else:
+            position = agent.initial_position
+        direction = agent.direction
+        possible_transitions = (0, 1, 0, 0)
+        if position != agent.target:
+            new_position, new_direction, dist, action, possible_transitions = self.walk(handle, position, direction)
+            if new_position is None:
+                return position, direction, RailEnvActions.STOP_MOVING, possible_transitions
+            self.callback_one_step(handle, agent, new_position, new_direction, action, possible_transitions)
+            return new_position, new_direction, action, possible_transitions
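ShortestDistanceWalker is designed to be subclassed: walk_to_target() invokes callback() on every cell it visits along the distance-map-optimal route. A minimal sketch that records an agent's shortest path (PathRecorder is a hypothetical name, not part of the patch):

    from utils.shortest_Distance_walker import ShortestDistanceWalker

    class PathRecorder(ShortestDistanceWalker):
        def __init__(self, env):
            super().__init__(env)
            self.path = []

        def callback(self, handle, agent, position, direction, action, possible_transitions):
            # collect every cell visited on the walk towards the agent's target
            self.path.append(position)

    # usage: walker = PathRecorder(env); walker.walk_to_target(handle); walker.path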