Commit d4bdcaf4 authored by Egli Adrian (IT-SCI-API-PFI)

Working version

parent 1865ae1d
-from flatland.evaluators.client import FlatlandRemoteClient
from flatland.core.env_observation_builder import DummyObservationBuilder
from my_observation_builder import CustomObservationBuilder
-import numpy as np
import time
+import numpy as np
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.evaluators.client import FlatlandRemoteClient
#####################################################################
# Instantiate a Remote Client
#####################################################################
from src.extra import Extra
from src.observations import MyTreeObsForRailEnv
remote_client = FlatlandRemoteClient()
#####################################################################
# Define your custom controller
#
@@ -18,11 +20,9 @@ remote_client = FlatlandRemoteClient()
# compute the necessary action for this step for all (or even some)
# of the agents
#####################################################################
-def my_controller(obs, number_of_agents):
-    _action = {}
-    for _idx in range(number_of_agents):
-        _action[_idx] = np.random.randint(0, 5)
-    return _action
+def my_controller(extra: Extra, observation, my_observation_builder):
+    return extra.rl_agent_act(observation, my_observation_builder.max_depth)
#####################################################################
# Instantiate your custom Observation Builder
@@ -31,7 +31,7 @@ def my_controller(obs, number_of_agents):
# the example here :
# https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
#####################################################################
-my_observation_builder = CustomObservationBuilder()
+my_observation_builder = MyTreeObsForRailEnv(max_depth=3)
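# MyTreeObsForRailEnv comes from src/observations.py; the max_depth=3 set here is
# the same depth that my_controller later hands to Extra.rl_agent_act and that
# generate_state passes on to normalize_observation.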
# Or if you want to use your own approach to build the observation from the env_step,
# please feel free to pass a DummyObservationBuilder() object as mentioned below,
@@ -61,9 +61,8 @@ while True:
# over the observation of your choice.
time_start = time.time()
observation, info = remote_client.env_create(
obs_builder_object=my_observation_builder
)
-env_creation_time = time.time() - time_start
if not observation:
#
# If the remote_client returns False on a `env_create` call,
@@ -71,7 +70,7 @@
# evaluated on all the required evaluation environments,
# and hence its safe to break out of the main evaluation loop
break
print("Evaluation Number : {}".format(evaluation_number))
#####################################################################
@@ -106,6 +105,14 @@ while True:
time_taken_by_controller = []
time_taken_per_step = []
steps = 0
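# Wrap the local copy of the evaluation environment in the Extra helper: its
# constructor precomputes the switch lookup tables used by the controller.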
extra = Extra(local_env)
env_creation_time = time.time() - time_start
print("Env Creation Time : ", env_creation_time)
print("Agents : ", extra.env.get_num_agents())
print("w : ", extra.env.width)
print("h : ", extra.env.height)
while True:
#####################################################################
# Evaluation of a single episode
@@ -114,7 +121,7 @@ while True:
# Compute the action for this step by using the previously
# defined controller
time_start = time.time()
-action = my_controller(observation, number_of_agents)
+action = my_controller(extra, observation, my_observation_builder)
time_taken = time.time() - time_start
time_taken_by_controller.append(time_taken)
@@ -129,6 +136,12 @@ while True:
time_taken = time.time() - time_start
time_taken_per_step.append(time_taken)
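# Progress output: count the agents that have already reached their target
# (status DONE or DONE_REMOVED) in this step.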
total_done = 0
for a in range(local_env.get_num_agents()):
x = (local_env.agents[a].status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
total_done += int(x)
print("total_done:", total_done)
if done['__all__']:
print("Reward : ", sum(list(all_rewards.values())))
#
@@ -136,18 +149,19 @@ while True:
# particular Env instantiation is complete, and we can break out
# of this loop, and move onto the next Env evaluation
break
np_time_taken_by_controller = np.array(time_taken_by_controller)
np_time_taken_per_step = np.array(time_taken_per_step)
print("="*100)
print("="*100)
print("=" * 100)
print("=" * 100)
print("Evaluation Number : ", evaluation_number)
print("Current Env Path : ", remote_client.current_env_path)
print("Env Creation Time : ", env_creation_time)
print("Number of Steps : ", steps)
print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), np_time_taken_by_controller.std())
print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(),
np_time_taken_by_controller.std())
print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
print("="*100)
print("=" * 100)
print("Evaluation of all environments complete...")
########################################################################
......
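The module imported above as src.extra, added by this commit, follows in full: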
import numpy as np
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from src.agent.dueling_double_dqn import Agent
from src.observations import normalize_observation
state_size = 179
action_size = 5
print("state_size: ", state_size)
print("action_size: ", action_size)
# Now we load a Double dueling DQN agent
global_rl_agent = Agent(state_size, action_size, "FC", 0)
global_rl_agent.load('./nets/training_best_0.626_agents_5276.pth')
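# Extra wraps the evaluation RailEnv: on construction it scans the grid once for
# switch cells (and the cells leading into them) and then uses the globally loaded
# dueling double DQN policy to decide actions only where a real choice exists.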
class Extra:
global_rl_agent = None
def __init__(self, env: RailEnv):
self.env = env
self.rl_agent = global_rl_agent
self.switches = {}
self.switches_neighbours = {}
self.find_all_cell_where_agent_can_choose()
self.steps_counter = 0
self.debug_render_list = []
self.debug_render_path_list = []
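# rl_agent_act: build a normalized state per agent, query the DQN only for agents
# that are on or directly before a switch, and send MOVE_FORWARD to every other
# active agent (DO_NOTHING if it has no observation). If the deadlock feature in
# the state is positive and the agent is just before a switch, the chosen action
# is overridden with STOP_MOVING.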
def rl_agent_act(self, observation, max_depth, eps=0.0):
self.steps_counter += 1
print(self.steps_counter, self.env.get_num_agents())
agent_obs = [None] * self.env.get_num_agents()
for a in range(self.env.get_num_agents()):
if observation[a]:
agent_obs[a] = self.generate_state(a, observation, max_depth)
action_dict = {}
# estimate whether the agent(s) can freely choose an action
agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
self.required_agent_descision()
for a in range(self.env.get_num_agents()):
if agent_obs[a] is not None:
if agents_can_choose[a]:
act, agent_rnd = self.rl_agent.act(agent_obs[a], eps=eps)
l = len(agent_obs[a])
if agent_obs[a][l - 3] > 0 and agents_near_to_switch_all[a]:
act = RailEnvActions.STOP_MOVING
action_dict.update({a: act})
else:
act = RailEnvActions.MOVE_FORWARD
action_dict.update({a: act})
else:
action_dict.update({a: RailEnvActions.DO_NOTHING})
return action_dict
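# find_all_cell_where_agent_can_choose: one full grid scan that fills
# self.switches (cells where some incoming direction allows more than one
# transition) and self.switches_neighbours (cells from which a switch is reached
# in one step); both map a cell position to the list of relevant directions.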
def find_all_cell_where_agent_can_choose(self):
switches = {}
for h in range(self.env.height):
for w in range(self.env.width):
pos = (h, w)
for dir in range(4):
possible_transitions = self.env.rail.get_transitions(*pos, dir)
num_transitions = np.count_nonzero(possible_transitions)
if num_transitions > 1:
if pos not in switches.keys():
switches.update({pos: [dir]})
else:
switches[pos].append(dir)
switches_neighbours = {}
for h in range(self.env.height):
for w in range(self.env.width):
# look one step forward
for dir in range(4):
pos = (h, w)
possible_transitions = self.env.rail.get_transitions(*pos, dir)
for d in range(4):
if possible_transitions[d] == 1:
new_cell = get_new_position(pos, d)
if new_cell in switches.keys() and pos not in switches.keys():
if pos not in switches_neighbours.keys():
switches_neighbours.update({pos: [dir]})
else:
switches_neighbours[pos].append(dir)
self.switches = switches
self.switches_neighbours = switches_neighbours
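# check_agent_descision: classify a single (position, direction) pair and return
# three flags: standing on a switch, standing one cell before a switch (only when
# the switch ahead does not branch for this direction), and the unrestricted
# "one cell before a switch" flag.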
def check_agent_descision(self, position, direction, switches, switches_neighbours):
agents_on_switch = False
agents_near_to_switch = False
agents_near_to_switch_all = False
if position in switches.keys():
agents_on_switch = direction in switches[position]
if position in switches_neighbours.keys():
new_cell = get_new_position(position, direction)
if new_cell in switches.keys():
if not direction in switches[new_cell]:
agents_near_to_switch = direction in switches_neighbours[position]
else:
agents_near_to_switch = direction in switches_neighbours[position]
agents_near_to_switch_all = direction in switches_neighbours[position]
return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all
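# required_agent_descision: evaluate the flags above for every agent and derive
# agents_can_choose (on a switch or just before one). Agents still in
# READY_TO_DEPART are treated as "near a switch" so the policy also decides
# whether they should enter the grid.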
def required_agent_descision(self):
agents_can_choose = {}
agents_on_switch = {}
agents_near_to_switch = {}
agents_near_to_switch_all = {}
for a in range(self.env.get_num_agents()):
ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all = \
self.check_agent_descision(
self.env.agents[a].position,
self.env.agents[a].direction,
self.switches,
self.switches_neighbours)
agents_on_switch.update({a: ret_agents_on_switch})
ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
agents_near_to_switch.update({a: (ret_agents_near_to_switch or ready_to_depart)})
agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all or ready_to_depart)})
return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all
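# check_deadlock: walk forward from an agent along the unique transitions until
# the next switch; the agent is reported as deadlocked if it runs into an
# occupied cell whose agent is not on a switch and is heading in a different
# direction. With only_next_cell_check=True only the immediate next cell is
# inspected.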
def check_deadlock(self, only_next_cell_check=False, handle=None):
agents_with_deadlock = []
agents = range(self.env.get_num_agents())
if handle is not None:
agents = [handle]
for a in agents:
if self.env.agents[a].status < RailAgentStatus.DONE:
position = self.env.agents[a].position
first_step = True
if position is None:
position = self.env.agents[a].initial_position
first_step = True
direction = self.env.agents[a].direction
while position is not None: # and position != self.env.agents[a].target:
possible_transitions = self.env.rail.get_transitions(*position, direction)
# num_transitions = np.count_nonzero(possible_transitions)
agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = self.check_agent_descision(
position,
direction,
self.switches,
self.switches_neighbours)
if not agents_on_switch or first_step:
first_step = False
new_direction_me = np.argmax(possible_transitions)
new_cell_me = get_new_position(position, new_direction_me)
opp_agent = self.env.agent_positions[new_cell_me]
if opp_agent != -1:
opp_position = self.env.agents[opp_agent].position
opp_direction = self.env.agents[opp_agent].direction
opp_agents_on_switch, opp_agents_near_to_switch, agents_near_to_switch_all = \
self.check_agent_descision(opp_position,
opp_direction,
self.switches,
self.switches_neighbours)
# opp_possible_transitions = self.env.rail.get_transitions(*opp_position, opp_direction)
# opp_num_transitions = np.count_nonzero(opp_possible_transitions)
if not opp_agents_on_switch:
if opp_direction != direction:
agents_with_deadlock.append(a)
position = None
else:
if only_next_cell_check:
position = None
else:
position = new_cell_me
direction = new_direction_me
else:
if only_next_cell_check:
position = None
else:
position = new_cell_me
direction = new_direction_me
else:
if only_next_cell_check:
position = None
else:
position = new_cell_me
direction = new_direction_me
else:
position = None
return agents_with_deadlock
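# generate_state: normalize the tree observation of one agent and append five
# extra features: agent status relative to ACTIVE, whether the start cell is
# free for an agent that is still ready to depart, a deadlock flag, and the
# "on switch" / "near switch" flags (the last four encoded as +/-1).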
def generate_state(self, handle: int, root, max_depth: int):
n_obs = normalize_observation(root[handle], max_depth)
position = self.env.agents[handle].position
direction = self.env.agents[handle].direction
cell_free_4_first_step = -1
deadlock_agents = []
if self.env.agents[handle].status == RailAgentStatus.READY_TO_DEPART:
if self.env.agent_positions[self.env.agents[handle].initial_position] == -1:
cell_free_4_first_step = 1
position = self.env.agents[handle].initial_position
else:
deadlock_agents = self.check_deadlock(only_next_cell_check=False, handle=handle)
agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = self.check_agent_descision(position,
direction,
self.switches,
self.switches_neighbours)
append_obs = [self.env.agents[handle].status - RailAgentStatus.ACTIVE,
cell_free_4_first_step,
2.0 * int(len(deadlock_agents)) - 1.0,
2.0 * int(agents_on_switch) - 1.0,
2.0 * int(agents_near_to_switch) - 1.0]
n_obs = np.append(n_obs, append_obs)
return n_obs
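For quick local debugging outside the AIcrowd evaluator, the helper can be pointed
at any flatland RailEnv instance. A minimal sketch, assuming such an env is already
available; count_switches is only an illustrative name and not part of the submission:

from src.extra import Extra   # note: importing src.extra already loads the DQN checkpoint from ./nets/


def count_switches(env):
    # Extra() scans the grid in its constructor, so the lookup dicts are ready to inspect.
    helper = Extra(env)
    print("switch cells           :", len(helper.switches))
    print("cells next to a switch :", len(helper.switches_neighbours))
    return helper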