Commit f636be2d authored by Erik Nygren

Merge branch 'stochasticbreaking' into 'master'

Stochasticbreaking

See merge request flatland/flatland!143
parents fc34b470 b1b2c42e
import random
import time

import numpy as np

from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool

random.seed(1)
np.random.seed(1)


class SingleAgentNavigationObs(TreeObsForRailEnv):
    """
    We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
    the minimum distances from each grid node to each agent's target.

    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) lead to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation
    vector will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__(max_depth=0)
        self.observation_space = [3]

    def reset(self):
        # Recompute the distance map, if the environment has changed.
        super().reset()

    def get(self, handle):
        agent = self.env.agents[handle]
        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation.
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            observation = [0, 1, 0]
        else:
            min_distances = []
            for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                if possible_transitions[direction]:
                    new_position = self._new_position(agent.position, direction)
                    min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
                else:
                    min_distances.append(np.inf)

            observation = [0, 0, 0]
            observation[np.argmin(min_distances)] = 1

        return observation


env = RailEnv(width=14,
              height=14,
              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
              number_of_agents=2,
              obs_builder_object=SingleAgentNavigationObs())

obs = env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
env_renderer.render_env(show=True, frames=True, show_observations=False)

for step in range(100):
    actions = {}
    for i in range(len(obs)):
        # Observation index 0/1/2 (left/forward/right) maps to actions 1/2/3 (move left/forward/right)
        actions[i] = np.argmax(obs[i]) + 1

    if step % 5 == 0:
        print("Agent halts")
        actions[0] = 4  # Halt

    obs, all_rewards, done, _ = env.step(actions)
    if env.agents[0].malfunction_data['malfunction'] > 0:
        print("Agent 0 malfunction:", env.agents[0].malfunction_data['malfunction'])

    env_renderer.render_env(show=True, frames=True, show_observations=False)
    time.sleep(0.5)
    if done["__all__"]:
        break
env_renderer.close_window()
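Note that this example never passes stochastic_data to RailEnv, so with the defaults introduced in this merge no malfunction ever occurs and the printout above stays silent. A minimal sketch of how to switch breakdowns on, reusing the constructor call from this example; the parameter values are illustrative, not prescribed by the commit:

# Illustrative malfunction parameters: every train may break down ('prop_malfunction' = 1.0),
# breakdowns occur on average every 30 stops, and each lasts between 3 and 20 steps.
stochastic_data = {'prop_malfunction': 1.0,
                   'malfunction_rate': 30,
                   'min_duration': 3,
                   'max_duration': 20}

env = RailEnv(width=14,
              height=14,
              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0),
              number_of_agents=2,
              obs_builder_object=SingleAgentNavigationObs(),
              stochastic_data=stochastic_data)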
@@ -57,7 +57,7 @@ class RandomAgent:

 # Initialize the agent with the parameters corresponding to the environment and observation_builder
-agent = RandomAgent(218, 4)
+agent = RandomAgent(218, 5)
 n_trials = 5

 # Empty dictionary for all agent action

@@ -77,12 +77,11 @@ for trials in range(1, n_trials + 1):
     score = 0
     # Run episode
-    for step in range(100):
+    for step in range(500):
         # Choose an action for each agent in the environment
         for a in range(env.get_num_agents()):
             action = agent.act(obs[a])
             action_dict.update({a: action})

         # Environment step which returns the observations for all agents, their corresponding
         # reward and whether they are done
         next_obs, all_rewards, done, _ = env.step(action_dict)

@@ -92,7 +91,6 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
             score += all_rewards[a]

         obs = next_obs.copy()
         if done['__all__']:
             break
...
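The two parameter changes above lengthen each training episode (100 to 500 steps) and grow the agent's action space from 4 to 5. The extra action matters for this merge, since stopping is what can trigger a malfunction. A hedged sketch of the action values, assuming the RailEnvActions enum in this version of flatland:

from flatland.envs.rail_env import RailEnvActions

# The five actions an agent can take (values as defined in this version of flatland):
# 0 = DO_NOTHING, 1 = MOVE_LEFT, 2 = MOVE_FORWARD, 3 = MOVE_RIGHT, 4 = STOP_MOVING.
# With action_size=5 the random agent can also sample STOP_MOVING; note the earlier
# example uses actions[0] = 4 to halt a train.
for action in RailEnvActions:
    print(int(action), action.name)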
@@ -15,6 +15,7 @@ class EnvAgentStatic(object):
     direction = attrib()
     target = attrib()
     moving = attrib(default=False)

     # speed_data: speed is added to position_fraction on each moving step, until position_fraction >= 1.0,
     # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
     # cell if speed=1, as default)

@@ -22,6 +23,12 @@ class EnvAgentStatic(object):
     speed_data = attrib(
         default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))

+    # if 'malfunction' > 0, the agent's actions are ignored for that many steps;
+    # 'next_malfunction' counts down the steps until the next breakdown
+    malfunction_data = attrib(
+        default=Factory(
+            lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0})))

     @classmethod
     def from_lists(cls, positions, directions, targets, speeds=None):
         """ Create a list of EnvAgentStatics from lists of positions, directions and targets

@@ -31,7 +38,22 @@ class EnvAgentStatic(object):
             speed_datas.append({'position_fraction': 0.0,
                                 'speed': speeds[i] if speeds is not None else 1.0,
                                 'transition_action_on_cellexit': 0})
-        return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas)))
+
+        # TODO: on initialization, all agents are set as non-broken. Perhaps it may be desirable to set
+        # some as broken?
+        malfunction_datas = []
+        for i in range(len(positions)):
+            malfunction_datas.append({'malfunction': 0,
+                                      'malfunction_rate': 0,
+                                      'next_malfunction': 0,
+                                      'nr_malfunctions': 0})
+
+        return list(starmap(EnvAgentStatic, zip(positions,
+                                                directions,
+                                                targets,
+                                                [False] * len(positions),
+                                                speed_datas,
+                                                malfunction_datas)))

     def to_list(self):

@@ -45,7 +67,7 @@ class EnvAgentStatic(object):
         if type(lTarget) is np.ndarray:
             lTarget = lTarget.tolist()
-        return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data]
+        return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.malfunction_data]

 @attrs

@@ -63,7 +85,7 @@ class EnvAgent(EnvAgentStatic):
     def to_list(self):
         return [
             self.position, self.direction, self.target, self.handle,
-            self.old_direction, self.old_position, self.moving, self.speed_data]
+            self.old_direction, self.old_position, self.moving, self.speed_data, self.malfunction_data]

     @classmethod
     def from_static(cls, oStatic):
...
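For reference, a sketch of how the four malfunction_data fields read, derived from the code added in this merge (the interpretation is mine, not the author's documentation):

# Interpretation of the malfunction_data dictionary attached to each agent
# (derived from the code in this merge request):
malfunction_data = {
    'malfunction': 0,       # remaining steps the agent stays broken; while > 0, its actions are ignored
    'malfunction_rate': 0,  # mean of the exponential distribution of time between breakdowns (0 = never breaks)
    'next_malfunction': 0,  # countdown of steps until the next breakdown is triggered
    'nr_malfunctions': 0,   # total number of breakdowns suffered so far
}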
@@ -75,6 +75,17 @@ class RailEnv(Environment):
         - stop_penalty = 0  # penalty for stopping a moving agent
         - start_penalty = 0  # penalty for starting a stopped agent

+    Stochastic malfunctioning of trains:
+    Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an
+    invalid action or cell is selected).
+
+    Every time an agent stops, it has a certain probability of malfunctioning. Malfunctions of trains follow a
+    Poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes, to keep
+    complexity manageable.
+
+    TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init().
+    For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.
     """
     def __init__(self,

@@ -83,7 +94,8 @@ class RailEnv(Environment):
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
-                 max_episode_steps=None
+                 max_episode_steps=None,
+                 stochastic_data=None
                  ):
         """
         Environment init.
@@ -146,6 +158,29 @@ class RailEnv(Environment):
         self.action_space = [1]
         self.observation_space = self.obs_builder.observation_space  # updated on resets?

+        # Stochastic train malfunctioning parameters
+        if stochastic_data is not None:
+            prop_malfunction = stochastic_data['prop_malfunction']
+            mean_malfunction_rate = stochastic_data['malfunction_rate']
+            malfunction_min_duration = stochastic_data['min_duration']
+            malfunction_max_duration = stochastic_data['max_duration']
+        else:
+            prop_malfunction = 0.
+            mean_malfunction_rate = 0.
+            malfunction_min_duration = 0.
+            malfunction_max_duration = 0.
+
+        # Proportion of malfunctioning trains
+        self.proportion_malfunctioning_trains = prop_malfunction
+
+        # Mean malfunction rate, in number of stops
+        self.mean_malfunction_rate = mean_malfunction_rate
+
+        # Uniform distribution parameters for malfunction duration
+        self.min_number_of_steps_broken = malfunction_min_duration
+        self.max_number_of_steps_broken = malfunction_max_duration
+
+        # Reset environment
         self.reset()
         self.num_resets = 0  # yes, set it to zero again!
@@ -195,7 +230,15 @@ class RailEnv(Environment):
         for i_agent in range(self.get_num_agents()):
             agent = self.agents[i_agent]
+
+            # A proportion of the agents in the environment will receive a positive malfunction rate
+            if np.random.random() < self.proportion_malfunctioning_trains:
+                agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate
+
             agent.speed_data['position_fraction'] = 0.0
+            agent.malfunction_data['malfunction'] = 0
+
+            self._agent_malfunction(agent)

         self.num_resets += 1
         self._elapsed_steps = 0
@@ -210,6 +253,30 @@ class RailEnv(Environment):
         # Return the new observation vectors for each agent
         return self._get_observations()

+    def _agent_malfunction(self, agent):
+        # Decrease the counter for the next malfunction event
+        agent.malfunction_data['next_malfunction'] -= 1
+
+        # Only agents that have a positive malfunction rate and are not currently broken are considered
+        if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']:
+
+            # If the counter has reached zero, the agent breaks down:
+            # set the next malfunction time and the duration of the current malfunction
+            if agent.malfunction_data['next_malfunction'] <= 0:
+                # Increase the number of malfunctions
+                agent.malfunction_data['nr_malfunctions'] += 1
+
+                # Next malfunction, in number of stops
+                next_breakdown = int(
+                    np.random.exponential(scale=agent.malfunction_data['malfunction_rate']))
+                agent.malfunction_data['next_malfunction'] = next_breakdown
+
+                # Duration of the current malfunction
+                num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
+                                                     self.max_number_of_steps_broken + 1) + 1
+                agent.malfunction_data['malfunction'] = num_broken_steps
     def step(self, action_dict_):
         self._elapsed_steps += 1

@@ -238,12 +305,29 @@ class RailEnv(Environment):
             agent = self.agents[i_agent]
             agent.old_direction = agent.direction
             agent.old_position = agent.position
+
+            # Check whether the agent breaks down at this step
+            self._agent_malfunction(agent)
+
             if self.dones[i_agent]:  # this agent has already completed...
                 continue

-            if i_agent not in action_dict:  # no action has been supplied for this agent
+            # No action has been supplied for this agent
+            if i_agent not in action_dict:
                 action_dict[i_agent] = RailEnvActions.DO_NOTHING

+            # The train is broken
+            if agent.malfunction_data['malfunction'] > 0:
+                agent.malfunction_data['malfunction'] -= 1
+
+                # Broken agents are stopped
+                self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
+                self.agents[i_agent].moving = False
+                action_dict[i_agent] = RailEnvActions.DO_NOTHING
+
+                # Nothing left to do with a broken agent
+                continue

             if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
                 print('ERROR: illegal action=', action_dict[i_agent],
                       'for agent with index=', i_agent,

@@ -329,7 +413,7 @@ class RailEnv(Environment):
                 agent.direction = new_direction
                 agent.speed_data['position_fraction'] = 0.0
             else:
-                # If the agent cannot move due to any reason, we set its state to not moving.
+                # If the agent cannot move due to any reason, we set its state to not moving
                 agent.moving = False

             if np.equal(agent.position, agent.target).all():
...
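The sampling in _agent_malfunction can be reproduced in isolation. One subtlety worth noting: np.random.randint(min, max + 1) + 1 draws uniformly from [min_duration, max_duration] and then adds one, so the effective breakdown duration is min_duration + 1 to max_duration + 1 steps. A standalone sketch with illustrative parameter values:

import numpy as np

# Standalone reproduction of the malfunction sampling above (values illustrative)
malfunction_rate = 30  # mean number of steps between breakdowns
min_duration = 3       # configured minimum breakdown duration
max_duration = 20      # configured maximum breakdown duration

# Steps until the next breakdown: exponential with mean 'malfunction_rate'
next_malfunction = int(np.random.exponential(scale=malfunction_rate))

# Breakdown duration: uniform over [min_duration, max_duration], plus one extra step
num_broken_steps = np.random.randint(min_duration, max_duration + 1) + 1

print(next_malfunction, num_broken_steps)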
import numpy as np

from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv


class SingleAgentNavigationObs(TreeObsForRailEnv):
    """
    We derive our observation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
    the minimum distances from each grid node to each agent's target.

    We then build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) lead to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation
    vector will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__(max_depth=0)
        self.observation_space = [3]

    def reset(self):
        # Recompute the distance map, if the environment has changed.
        super().reset()

    def get(self, handle):
        agent = self.env.agents[handle]
        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation.
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            observation = [0, 1, 0]
        else:
            min_distances = []
            for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                if possible_transitions[direction]:
                    new_position = self._new_position(agent.position, direction)
                    min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
                else:
                    min_distances.append(np.inf)

            observation = [0, 0, 0]
            observation[np.argmin(min_distances)] = 1

        return observation


def test_malfunction_process():
    # Set a fixed malfunction duration for this test
    stochastic_data = {'prop_malfunction': 1.,
                       'malfunction_rate': 1000,
                       'min_duration': 3,
                       'max_duration': 3}
    np.random.seed(5)

    env = RailEnv(width=20,
                  height=20,
                  rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
                                                        seed=0),
                  number_of_agents=2,
                  obs_builder_object=SingleAgentNavigationObs(),
                  stochastic_data=stochastic_data)
    obs = env.reset()

    # Check that an initial time until the next malfunction was assigned
    assert env.agents[0].malfunction_data['next_malfunction'] > 0

    agent_halts = 0
    total_down_time = 0
    agent_old_position = env.agents[0].position
    for step in range(100):
        actions = {}
        for i in range(len(obs)):
            actions[i] = np.argmax(obs[i]) + 1

        if step % 5 == 0:
            # Force agent 0 to break down at the next step
            env.agents[0].malfunction_data['malfunction'] = -1
            env.agents[0].malfunction_data['next_malfunction'] = 0
            agent_halts += 1

        obs, all_rewards, done, _ = env.step(actions)

        agent_malfunctioning = env.agents[0].malfunction_data['malfunction'] > 0
        if agent_malfunctioning:
            # Check that the agent does not move while malfunctioning
            assert agent_old_position == env.agents[0].position

        agent_old_position = env.agents[0].position
        total_down_time += env.agents[0].malfunction_data['malfunction']

    # Check that the expected number of malfunctions occurred
    assert env.agents[0].malfunction_data['nr_malfunctions'] == 21

    # Check that 20 stops were performed
    assert agent_halts == 20

    # Check that down time accumulated while the agent was malfunctioning
    assert total_down_time > 0