Commit d4bdcaf4 authored by Egli Adrian (IT-SCI-API-PFI)

Working version

parent 1865ae1d
Tags: submission-v0.1
from flatland.evaluators.client import FlatlandRemoteClient
from flatland.core.env_observation_builder import DummyObservationBuilder
from my_observation_builder import CustomObservationBuilder
import numpy as np
import time
from flatland.envs.agent_utils import RailAgentStatus
from src.extra import Extra
from src.observations import MyTreeObsForRailEnv

#####################################################################
# Instantiate a Remote Client
#####################################################################
remote_client = FlatlandRemoteClient()

#####################################################################
# Define your custom controller
#
# It computes the necessary action for this step for all (or even some)
# of the agents.
#####################################################################
def my_controller(extra: Extra, observation, my_observation_builder):
    # Delegate the per-step decision to the pre-trained RL policy wrapped in Extra.
    return extra.rl_agent_act(observation, my_observation_builder.max_depth)

#####################################################################
# Instantiate your custom Observation Builder
#
# For a custom Observation Builder, see the example here:
# https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
#####################################################################
my_observation_builder = MyTreeObsForRailEnv(max_depth=3)

# Or, if you want to use your own approach to build the observation from the
# env_step, feel free to pass a DummyObservationBuilder() object instead, as
# sketched below.
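#
# A minimal sketch of that alternative (assumption: you then reconstruct the
# observation yourself from the local environment; DummyObservationBuilder is
# the stock flatland no-op builder imported above):
#
#   my_observation_builder = DummyObservationBuilder()
#   # The observation returned by env_create()/env_step() is then only a
#   # trivial placeholder, and the controller has to read whatever it needs
#   # directly from the local environment handle.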

while True:
    # ...
    # over the observation of your choice.
    time_start = time.time()
    observation, info = remote_client.env_create(
        obs_builder_object=my_observation_builder
    )
    if not observation:
        #
        # If the remote_client returns False on a `env_create` call,
        # the agent has already been evaluated on all the required
        # evaluation environments, and hence it is safe to break out
        # of the main evaluation loop.
        break
print("Evaluation Number : {}".format(evaluation_number))
    #####################################################################
    # ...
    time_taken_by_controller = []
    time_taken_per_step = []
    steps = 0

    # local_env is the local copy of the evaluation environment, set up earlier
    # in the starter-kit code omitted from this hunk.
    extra = Extra(local_env)
    env_creation_time = time.time() - time_start
    print("Env Creation Time : ", env_creation_time)
    print("Agents : ", extra.env.get_num_agents())
    print("w : ", extra.env.width)
    print("h : ", extra.env.height)

    while True:
        #####################################################################
        # Evaluation of a single episode
        #
        # ...
        #####################################################################
        # Compute the action for this step by using the previously
        # defined controller
        time_start = time.time()
        action = my_controller(extra, observation, my_observation_builder)
        time_taken = time.time() - time_start
        time_taken_by_controller.append(time_taken)

        # ...
        time_taken = time.time() - time_start
        time_taken_per_step.append(time_taken)

        # Count how many agents have already reached their target.
        total_done = 0
        for a in range(local_env.get_num_agents()):
            x = (local_env.agents[a].status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
            total_done += int(x)
        print("total_done:", total_done)

        if done['__all__']:
            print("Reward : ", sum(list(all_rewards.values())))
            #
            # When done['__all__'] is True, the evaluation of this particular
            # Env instantiation is complete, and we can break out of this loop
            # and move on to the next Env evaluation.
            break

    np_time_taken_by_controller = np.array(time_taken_by_controller)
    np_time_taken_per_step = np.array(time_taken_per_step)

    print("=" * 100)
    print("=" * 100)
    print("Evaluation Number : ", evaluation_number)
    print("Current Env Path : ", remote_client.current_env_path)
    print("Env Creation Time : ", env_creation_time)
    print("Number of Steps : ", steps)
    print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(),
          np_time_taken_by_controller.std())
    print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
    print("=" * 100)

print("Evaluation of all environments complete...")
########################################################################
# ...


# ---- src/extra.py (new file) ----
import numpy as np

from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.utils.rendertools import RenderTool, AgentRenderVariant

from src.agent.dueling_double_dqn import Agent
from src.observations import normalize_observation

state_size = 179
action_size = 5

print("state_size: ", state_size)
print("action_size: ", action_size)

# Load the trained Dueling Double DQN agent once at module level, so that every
# Extra instance shares the same network weights.
global_rl_agent = Agent(state_size, action_size, "FC", 0)
global_rl_agent.load('./nets/training_best_0.626_agents_5276.pth')
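
# Note (assumption): state_size (179) must equal the length of the vector
# returned by Extra.generate_state() below, i.e. the normalized tree
# observation for max_depth=3 plus the 5 extra scalar features appended there.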


class Extra:
    global_rl_agent = None

    def __init__(self, env: RailEnv):
        self.env = env
        self.rl_agent = global_rl_agent
        self.switches = {}
        self.switches_neighbours = {}
        self.find_all_cell_where_agent_can_choose()
        self.steps_counter = 0

        self.debug_render_list = []
        self.debug_render_path_list = []

    def rl_agent_act(self, observation, max_depth, eps=0.0):
        self.steps_counter += 1
        print(self.steps_counter, self.env.get_num_agents())

        # Build the network input for every agent that received an observation.
        agent_obs = [None] * self.env.get_num_agents()
        for a in range(self.env.get_num_agents()):
            if observation[a]:
                agent_obs[a] = self.generate_state(a, observation, max_depth)

        action_dict = {}
        # Estimate whether the agent(s) can freely choose an action.
        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
            self.required_agent_descision()
        for a in range(self.env.get_num_agents()):
            if agent_obs[a] is not None:
                if agents_can_choose[a]:
                    act, agent_rnd = self.rl_agent.act(agent_obs[a], eps=eps)
                    l = len(agent_obs[a])
                    # agent_obs[a][l - 3] is the deadlock flag appended in
                    # generate_state(): if a deadlock is detected ahead and the
                    # agent is right before a switch, stop instead.
                    if agent_obs[a][l - 3] > 0 and agents_near_to_switch_all[a]:
                        act = RailEnvActions.STOP_MOVING
                    action_dict.update({a: act})
                else:
                    # No decision required here: keep moving forward.
                    act = RailEnvActions.MOVE_FORWARD
                    action_dict.update({a: act})
            else:
                action_dict.update({a: RailEnvActions.DO_NOTHING})

        return action_dict

    def find_all_cell_where_agent_can_choose(self):
        # Collect all cells that are switches: for some incoming direction the
        # rail allows more than one outgoing transition.
        switches = {}
        for h in range(self.env.height):
            for w in range(self.env.width):
                pos = (h, w)
                for dir in range(4):
                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
                    num_transitions = np.count_nonzero(possible_transitions)
                    if num_transitions > 1:
                        if pos not in switches.keys():
                            switches.update({pos: [dir]})
                        else:
                            switches[pos].append(dir)

        # Collect all cells that lie directly in front of a switch
        # (look one step forward).
        switches_neighbours = {}
        for h in range(self.env.height):
            for w in range(self.env.width):
                for dir in range(4):
                    pos = (h, w)
                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
                    for d in range(4):
                        if possible_transitions[d] == 1:
                            new_cell = get_new_position(pos, d)
                            if new_cell in switches.keys() and pos not in switches.keys():
                                if pos not in switches_neighbours.keys():
                                    switches_neighbours.update({pos: [dir]})
                                else:
                                    switches_neighbours[pos].append(dir)

        self.switches = switches
        self.switches_neighbours = switches_neighbours
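
    # For reference, a sketch of the transition lookup used above (coordinates
    # and values are purely illustrative): rail.get_transitions(row, col,
    # direction) returns a 4-tuple of 0/1 flags, one per outgoing compass
    # direction (N, E, S, W).
    #
    #   possible_transitions = self.env.rail.get_transitions(3, 7, 0)  # e.g. (1, 1, 0, 0)
    #   np.count_nonzero(possible_transitions)  # -> 2, so (3, 7) is a switch for direction 0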

    def check_agent_descision(self, position, direction, switches, switches_neighbours):
        # Classify the given cell/direction against the pre-computed switch maps.
        agents_on_switch = False
        agents_near_to_switch = False
        agents_near_to_switch_all = False
        if position in switches.keys():
            agents_on_switch = direction in switches[position]

        if position in switches_neighbours.keys():
            new_cell = get_new_position(position, direction)
            if new_cell in switches.keys():
                if direction not in switches[new_cell]:
                    agents_near_to_switch = direction in switches_neighbours[position]
            else:
                agents_near_to_switch = direction in switches_neighbours[position]

            agents_near_to_switch_all = direction in switches_neighbours[position]

        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all

    def required_agent_descision(self):
        # Decide, for every agent, whether it actually has to take a decision in
        # this step (it is on or right before a switch, or still waiting to depart).
        agents_can_choose = {}
        agents_on_switch = {}
        agents_near_to_switch = {}
        agents_near_to_switch_all = {}
        for a in range(self.env.get_num_agents()):
            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all = \
                self.check_agent_descision(
                    self.env.agents[a].position,
                    self.env.agents[a].direction,
                    self.switches,
                    self.switches_neighbours)
            agents_on_switch.update({a: ret_agents_on_switch})
            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
            agents_near_to_switch.update({a: (ret_agents_near_to_switch or ready_to_depart)})

            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})

            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all or ready_to_depart)})

        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all

    def check_deadlock(self, only_next_cell_check=False, handle=None):
        # Follow each agent's path (always taking the first allowed transition)
        # until the next decision cell. If the next cell is occupied by an agent
        # heading in a different direction that is not on a switch either, the
        # two agents cannot pass each other and a (potential) deadlock is reported.
        agents_with_deadlock = []
        agents = range(self.env.get_num_agents())
        if handle is not None:
            agents = [handle]
        for a in agents:
            if self.env.agents[a].status < RailAgentStatus.DONE:
                position = self.env.agents[a].position
                first_step = True
                if position is None:
                    position = self.env.agents[a].initial_position
                    first_step = True
                direction = self.env.agents[a].direction
                while position is not None:  # and position != self.env.agents[a].target:
                    possible_transitions = self.env.rail.get_transitions(*position, direction)
                    # num_transitions = np.count_nonzero(possible_transitions)
                    agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
                        self.check_agent_descision(
                            position,
                            direction,
                            self.switches,
                            self.switches_neighbours)
                    if not agents_on_switch or first_step:
                        first_step = False
                        new_direction_me = np.argmax(possible_transitions)
                        new_cell_me = get_new_position(position, new_direction_me)
                        opp_agent = self.env.agent_positions[new_cell_me]
                        if opp_agent != -1:
                            opp_position = self.env.agents[opp_agent].position
                            opp_direction = self.env.agents[opp_agent].direction
                            opp_agents_on_switch, opp_agents_near_to_switch, agents_near_to_switch_all = \
                                self.check_agent_descision(opp_position,
                                                           opp_direction,
                                                           self.switches,
                                                           self.switches_neighbours)
                            # opp_possible_transitions = self.env.rail.get_transitions(*opp_position, opp_direction)
                            # opp_num_transitions = np.count_nonzero(opp_possible_transitions)
                            if not opp_agents_on_switch:
                                if opp_direction != direction:
                                    agents_with_deadlock.append(a)
                                    position = None
                                else:
                                    if only_next_cell_check:
                                        position = None
                                    else:
                                        position = new_cell_me
                                        direction = new_direction_me
                            else:
                                if only_next_cell_check:
                                    position = None
                                else:
                                    position = new_cell_me
                                    direction = new_direction_me
                        else:
                            if only_next_cell_check:
                                position = None
                            else:
                                position = new_cell_me
                                direction = new_direction_me
                    else:
                        position = None

        return agents_with_deadlock

    def generate_state(self, handle: int, root, max_depth: int):
        # Start from the normalized tree observation and append a handful of
        # hand-crafted scalar features.
        n_obs = normalize_observation(root[handle], max_depth)
        position = self.env.agents[handle].position
        direction = self.env.agents[handle].direction

        cell_free_4_first_step = -1
        deadlock_agents = []
        if self.env.agents[handle].status == RailAgentStatus.READY_TO_DEPART:
            # Agent has not entered the grid yet: check whether its start cell is free.
            if self.env.agent_positions[self.env.agents[handle].initial_position] == -1:
                cell_free_4_first_step = 1
            position = self.env.agents[handle].initial_position
        else:
            deadlock_agents = self.check_deadlock(only_next_cell_check=False, handle=handle)

        agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
            self.check_agent_descision(position,
                                       direction,
                                       self.switches,
                                       self.switches_neighbours)

        # Five extra features: agent status, start-cell-free flag, deadlock flag,
        # on-switch flag and near-switch flag (booleans mapped to {-1, +1}).
        append_obs = [self.env.agents[handle].status - RailAgentStatus.ACTIVE,
                      cell_free_4_first_step,
                      2.0 * int(len(deadlock_agents)) - 1.0,
                      2.0 * int(agents_on_switch) - 1.0,
                      2.0 * int(agents_near_to_switch) - 1.0]
        n_obs = np.append(n_obs, append_obs)

        return n_obs
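

if __name__ == "__main__":
    # Minimal local smoke test -- a sketch only, not part of the evaluation
    # flow. It assumes flatland-rl 2.1+ (RailEnv.reset returning (obs, info)),
    # that the checkpoint loaded at module level is available, and that
    # MyTreeObsForRailEnv from src/observations.py works as a regular
    # observation builder; the generator parameters below are arbitrary.
    from flatland.envs.rail_generators import sparse_rail_generator
    from flatland.envs.schedule_generators import sparse_schedule_generator
    from src.observations import MyTreeObsForRailEnv

    obs_builder = MyTreeObsForRailEnv(max_depth=3)
    env = RailEnv(width=30, height=30,
                  rail_generator=sparse_rail_generator(max_num_cities=3, seed=1),
                  schedule_generator=sparse_schedule_generator(),
                  number_of_agents=5,
                  obs_builder_object=obs_builder)
    observation, info = env.reset()

    extra = Extra(env)
    # One decision step: feed the tree observations to the pre-trained policy
    # and apply the resulting action dictionary to the environment.
    action = extra.rl_agent_act(observation, obs_builder.max_depth)
    observation, all_rewards, done, info = env.step(action)
    print("chosen actions:", action)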