diff --git a/nets/training_5500.pth.local b/nets/training_5500.pth.local
new file mode 100644
index 0000000000000000000000000000000000000000..e77bb01ae7aef933e59799528c325a412b520290
Binary files /dev/null and b/nets/training_5500.pth.local differ
diff --git a/nets/training_5500.pth.target b/nets/training_5500.pth.target
new file mode 100644
index 0000000000000000000000000000000000000000..a086633e9a959abb573f4231dc76d9a8123f4861
Binary files /dev/null and b/nets/training_5500.pth.target differ
diff --git a/nets/training_best_0.626_agents_5276.pth.local b/nets/training_best_0.626_agents_5276.pth.local
new file mode 100644
index 0000000000000000000000000000000000000000..0a080fcc30deae97f34610670bce980599e648d6
Binary files /dev/null and b/nets/training_best_0.626_agents_5276.pth.local differ
diff --git a/nets/training_best_0.626_agents_5276.pth.target b/nets/training_best_0.626_agents_5276.pth.target
new file mode 100644
index 0000000000000000000000000000000000000000..58f0d0bcaca15481c42523a21cda75cb226a782e
Binary files /dev/null and b/nets/training_best_0.626_agents_5276.pth.target differ
diff --git a/run.py b/run.py
index 5c5bb9a020297404cc25e31dcf0325d052723db0..855dac19dcfdd6d4a971a8e74ee36c57591b45fe 100644
--- a/run.py
+++ b/run.py
@@ -1,16 +1,18 @@
-from flatland.evaluators.client import FlatlandRemoteClient
-from flatland.core.env_observation_builder import DummyObservationBuilder
-from my_observation_builder import CustomObservationBuilder
-import numpy as np
 import time
-
+import numpy as np
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.evaluators.client import FlatlandRemoteClient
 #####################################################################
 # Instantiate a Remote Client
 #####################################################################
+from src.extra import Extra
+from src.observations import MyTreeObsForRailEnv
+
 remote_client = FlatlandRemoteClient()
+
 #####################################################################
 # Define your custom controller
 #
@@ -18,11 +20,9 @@ remote_client = FlatlandRemoteClient()
 # compute the necessary action for this step for all (or even some)
 # of the agents
 #####################################################################
-def my_controller(obs, number_of_agents):
-    _action = {}
-    for _idx in range(number_of_agents):
-        _action[_idx] = np.random.randint(0, 5)
-    return _action
+def my_controller(extra: Extra, observation, my_observation_builder):
+    return extra.rl_agent_act(observation, my_observation_builder.max_depth)
+
 #####################################################################
 # Instantiate your custom Observation Builder
@@ -31,7 +31,7 @@ def my_controller(obs, number_of_agents):
 # the example here :
 # https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
 #####################################################################
-my_observation_builder = CustomObservationBuilder()
+my_observation_builder = MyTreeObsForRailEnv(max_depth=3)
 # Or if you want to use your own approach to build the observation from the env_step,
 # please feel free to pass a DummyObservationBuilder() object as mentioned below,
@@ -61,9 +61,8 @@ while True:
     # over the observation of your choice.
     time_start = time.time()
     observation, info = remote_client.env_create(
-                    obs_builder_object=my_observation_builder
-                )
-    env_creation_time = time.time() - time_start
+        obs_builder_object=my_observation_builder
+    )
     if not observation:
         #
         # If the remote_client returns False on a `env_create` call,
@@ -71,7 +70,7 @@
         # evaluated on all the required evaluation environments,
         # and hence its safe to break out of the main evaluation loop
         break
-
+
     print("Evaluation Number : {}".format(evaluation_number))
     #####################################################################
@@ -106,6 +105,14 @@ while True:
     time_taken_by_controller = []
     time_taken_per_step = []
     steps = 0
+
+    extra = Extra(local_env)
+    env_creation_time = time.time() - time_start
+    print("Env Creation Time : ", env_creation_time)
+    print("Agents : ", extra.env.get_num_agents())
+    print("w : ", extra.env.width)
+    print("h : ", extra.env.height)
+
     while True:
         #####################################################################
         # Evaluation of a single episode
         #
         #####################################################################
@@ -114,7 +121,7 @@
         # Compute the action for this step by using the previously
         # defined controller
         time_start = time.time()
-        action = my_controller(observation, number_of_agents)
+        action = my_controller(extra, observation, my_observation_builder)
         time_taken = time.time() - time_start
         time_taken_by_controller.append(time_taken)
@@ -129,6 +136,12 @@
         time_taken = time.time() - time_start
         time_taken_per_step.append(time_taken)
+        total_done = 0
+        for a in range(local_env.get_num_agents()):
+            x = (local_env.agents[a].status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
+            total_done += int(x)
+        print("total_done:", total_done)
+
         if done['__all__']:
             print("Reward : ", sum(list(all_rewards.values())))
             #
@@ -136,18 +149,19 @@
             # particular Env instantiation is complete, and we can break out
             # of this loop, and move onto the next Env evaluation
             break
-
+
     np_time_taken_by_controller = np.array(time_taken_by_controller)
     np_time_taken_per_step = np.array(time_taken_per_step)
-    print("="*100)
-    print("="*100)
+    print("=" * 100)
+    print("=" * 100)
     print("Evaluation Number : ", evaluation_number)
     print("Current Env Path : ", remote_client.current_env_path)
     print("Env Creation Time : ", env_creation_time)
     print("Number of Steps : ", steps)
-    print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), np_time_taken_by_controller.std())
+    print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(),
+          np_time_taken_by_controller.std())
     print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
-    print("="*100)
+    print("=" * 100)
 print("Evaluation of all environments complete...")
########################################################################
diff --git a/src/extra.py b/src/extra.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6e237102cc64dd475e0b25a01a518cfca8149dd
--- /dev/null
+++ b/src/extra.py
@@ -0,0 +1,230 @@
+import numpy as np
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_env import RailEnvActions
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+
+from src.agent.dueling_double_dqn import Agent
+from src.observations import normalize_observation
+
+state_size = 179
+action_size = 5
+print("state_size: ", state_size)
+print("action_size: ", action_size)
+
+# Now we load a Double dueling DQN agent
+global_rl_agent = Agent(state_size, action_size, "FC", 0)
+global_rl_agent.load('./nets/training_best_0.626_agents_5276.pth')
+
+
+class Extra:
+    global_rl_agent = None
+
+    def __init__(self, env: RailEnv):
+        self.env = env
+        self.rl_agent = global_rl_agent
+        self.switches = {}
+        self.switches_neighbours = {}
+        self.find_all_cell_where_agent_can_choose()
+        self.steps_counter = 0
+
+        self.debug_render_list = []
+        self.debug_render_path_list = []
+
+    def rl_agent_act(self, observation, max_depth, eps=0.0):
+
+        self.steps_counter += 1
+        print(self.steps_counter, self.env.get_num_agents())
+
+        agent_obs = [None] * self.env.get_num_agents()
+        for a in range(self.env.get_num_agents()):
+            if observation[a]:
+                agent_obs[a] = self.generate_state(a, observation, max_depth)
+
+        action_dict = {}
+        # estimate whether the agent(s) can freely choose an action
+        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
+            self.required_agent_descision()
+
+        for a in range(self.env.get_num_agents()):
+            if agent_obs[a] is not None:
+                if agents_can_choose[a]:
+                    act, agent_rnd = self.rl_agent.act(agent_obs[a], eps=eps)
+
+                    l = len(agent_obs[a])
+                    if agent_obs[a][l - 3] > 0 and agents_near_to_switch_all[a]:
+                        act = RailEnvActions.STOP_MOVING
+
+                    action_dict.update({a: act})
+                else:
+                    act = RailEnvActions.MOVE_FORWARD
+                    action_dict.update({a: act})
+            else:
+                action_dict.update({a: RailEnvActions.DO_NOTHING})
+        return action_dict
+
+    def find_all_cell_where_agent_can_choose(self):
+        switches = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                pos = (h, w)
+                for dir in range(4):
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    num_transitions = np.count_nonzero(possible_transitions)
+                    if num_transitions > 1:
+                        if pos not in switches.keys():
+                            switches.update({pos: [dir]})
+                        else:
+                            switches[pos].append(dir)
+
+        switches_neighbours = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                # look one step forward
+                for dir in range(4):
+                    pos = (h, w)
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    for d in range(4):
+                        if possible_transitions[d] == 1:
+                            new_cell = get_new_position(pos, d)
+                            if new_cell in switches.keys() and pos not in switches.keys():
+                                if pos not in switches_neighbours.keys():
+                                    switches_neighbours.update({pos: [dir]})
+                                else:
+                                    switches_neighbours[pos].append(dir)
+
+        self.switches = switches
+        self.switches_neighbours = switches_neighbours
+
+    def check_agent_descision(self, position, direction, switches, switches_neighbours):
+        agents_on_switch = False
+        agents_near_to_switch = False
+        agents_near_to_switch_all = False
+        if position in switches.keys():
+            agents_on_switch = direction in switches[position]
+
+        if position in switches_neighbours.keys():
+            new_cell = get_new_position(position, direction)
+            if new_cell in switches.keys():
+                if not direction in switches[new_cell]:
+                    agents_near_to_switch = direction in switches_neighbours[position]
+            else:
+                agents_near_to_switch = direction in switches_neighbours[position]
+
+            agents_near_to_switch_all = direction in switches_neighbours[position]
+
+        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all
+
+    def required_agent_descision(self):
+        agents_can_choose = {}
+        agents_on_switch = {}
+        agents_near_to_switch = {}
+        agents_near_to_switch_all = {}
+        for a in range(self.env.get_num_agents()):
+            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all = \
+                self.check_agent_descision(
+                    self.env.agents[a].position,
+                    self.env.agents[a].direction,
+                    self.switches,
+                    self.switches_neighbours)
+            agents_on_switch.update({a: ret_agents_on_switch})
+            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
+            agents_near_to_switch.update({a: (ret_agents_near_to_switch or ready_to_depart)})
+
+            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
+
+            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all or ready_to_depart)})
+
+        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all
+
+    def check_deadlock(self, only_next_cell_check=False, handle=None):
+        agents_with_deadlock = []
+        agents = range(self.env.get_num_agents())
+        if handle is not None:
+            agents = [handle]
+        for a in agents:
+            if self.env.agents[a].status < RailAgentStatus.DONE:
+                position = self.env.agents[a].position
+                first_step = True
+                if position is None:
+                    position = self.env.agents[a].initial_position
+                    first_step = True
+                direction = self.env.agents[a].direction
+                while position is not None:  # and position != self.env.agents[a].target:
+                    possible_transitions = self.env.rail.get_transitions(*position, direction)
+                    # num_transitions = np.count_nonzero(possible_transitions)
+                    agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = self.check_agent_descision(
+                        position,
+                        direction,
+                        self.switches,
+                        self.switches_neighbours)
+
+                    if not agents_on_switch or first_step:
+                        first_step = False
+                        new_direction_me = np.argmax(possible_transitions)
+                        new_cell_me = get_new_position(position, new_direction_me)
+                        opp_agent = self.env.agent_positions[new_cell_me]
+                        if opp_agent != -1:
+                            opp_position = self.env.agents[opp_agent].position
+                            opp_direction = self.env.agents[opp_agent].direction
+                            opp_agents_on_switch, opp_agents_near_to_switch, agents_near_to_switch_all = \
+                                self.check_agent_descision(opp_position,
+                                                           opp_direction,
+                                                           self.switches,
+                                                           self.switches_neighbours)
+
+                            # opp_possible_transitions = self.env.rail.get_transitions(*opp_position, opp_direction)
+                            # opp_num_transitions = np.count_nonzero(opp_possible_transitions)
+                            if not opp_agents_on_switch:
+                                if opp_direction != direction:
+                                    agents_with_deadlock.append(a)
+                                    position = None
+                                else:
+                                    if only_next_cell_check:
+                                        position = None
+                                    else:
+                                        position = new_cell_me
+                                        direction = new_direction_me
+                            else:
+                                if only_next_cell_check:
+                                    position = None
+                                else:
+                                    position = new_cell_me
+                                    direction = new_direction_me
+                        else:
+                            if only_next_cell_check:
+                                position = None
+                            else:
+                                position = new_cell_me
+                                direction = new_direction_me
+                    else:
+                        position = None
+
+        return agents_with_deadlock
+
+    def generate_state(self, handle: int, root, max_depth: int):
+        n_obs = normalize_observation(root[handle], max_depth)
+
+        position = self.env.agents[handle].position
+        direction = self.env.agents[handle].direction
+        cell_free_4_first_step = -1
+        deadlock_agents = []
+        if self.env.agents[handle].status == RailAgentStatus.READY_TO_DEPART:
+            if self.env.agent_positions[self.env.agents[handle].initial_position] == -1:
+                cell_free_4_first_step = 1
+                position = self.env.agents[handle].initial_position
+        else:
+            deadlock_agents = self.check_deadlock(only_next_cell_check=False, handle=handle)
+        agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = self.check_agent_descision(position,
+                                                                                                        direction,
+                                                                                                        self.switches,
+                                                                                                        self.switches_neighbours)
+
+        append_obs = [self.env.agents[handle].status - RailAgentStatus.ACTIVE,
+                      cell_free_4_first_step,
+                      2.0 * int(len(deadlock_agents)) - 1.0,
+                      2.0 * int(agents_on_switch) - 1.0,
+                      2.0 * int(agents_near_to_switch) - 1.0]
+        n_obs = np.append(n_obs, append_obs)
+
+        return n_obs
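
Note on the appended features (not part of the patch above): generate_state extends the normalized tree observation with five scalars, in this order: the agent status relative to RailAgentStatus.ACTIVE, the cell-free flag for the first step, a deadlock flag, an on-switch flag, and a near-switch flag. The check agent_obs[a][l - 3] > 0 in rl_agent_act therefore reads the deadlock flag (third element from the end) and forces STOP_MOVING for agents that are close to a switch. A small sketch of that index mapping; the helper name is illustrative only:

    # Tail layout of the state vector built by Extra.generate_state:
    #   [..., status - ACTIVE, cell_free_4_first_step, deadlock_flag, on_switch_flag, near_switch_flag]
    #        index -5          index -4                index -3       index -2        index -1
    def predicted_deadlock(state_vector):
        # True when the appended deadlock feature (index -3) signals a blocked path ahead
        return state_vector[len(state_vector) - 3] > 0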
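
For completeness, a minimal local sketch (also not part of the patch) of how the new pieces fit together, mirroring the wiring in run.py. It assumes a flatland 2.x RailEnv instance env that was constructed with obs_builder_object=MyTreeObsForRailEnv(max_depth=3); the function name run_local_episode and the reset() call returning (observation, info) are assumptions about the local setup, not code from this repository.

    # Hypothetical local smoke test for the Extra controller (assumptions noted above).
    from src.extra import Extra

    def run_local_episode(env, max_depth=3, max_steps=500):
        observation, info = env.reset()   # assumed flatland 2.x: reset() returns (obs, info)
        extra = Extra(env)                # pre-computes switch and switch-neighbour cells
        episode_reward = 0.0
        for _ in range(max_steps):
            action = extra.rl_agent_act(observation, max_depth)
            observation, all_rewards, done, info = env.step(action)
            episode_reward += sum(all_rewards.values())
            if done['__all__']:
                break
        return episode_reward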