Skip to content
Snippets Groups Projects
Commit ce138664 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

updated poisson process for malfunction of agents

parent 1868a59a
No related branches found
No related tags found
No related merge requests found
......@@ -3,11 +3,8 @@ import time
import numpy as np
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid_utils import coordinate_to_position
from flatland.envs.generators import random_rail_generator, complex_rail_generator
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
......@@ -77,7 +74,8 @@ for step in range(100):
actions[0] = 4 # Halt
obs, all_rewards, done, _ = env.step(actions)
print("Agent 0 broken-ness: ", env.agents[0].broken_data['broken'])
if env.agents[0].broken_data['broken'] > 0:
print("Agent 0 broken-ness: ", env.agents[0].broken_data['broken'])
env_renderer.render_env(show=True, frames=True, show_observations=False)
time.sleep(0.5)
......
......@@ -16,9 +16,9 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor
LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
env = RailEnv(width=50,
height=50,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
rail_generator=complex_rail_generator(nr_start_goal=20, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
obs_builder_object=TreeObservation,
number_of_agents=5)
number_of_agents=20)
env_renderer = RenderTool(env, gl="PILSVG", )
......@@ -75,6 +75,7 @@ for trials in range(1, n_trials + 1):
score = 0
# Run episode
mean_malfunction_interval = []
for step in range(100):
# Chose an action for each agent in the environment
for a in range(env.get_num_agents()):
......@@ -84,7 +85,7 @@ for trials in range(1, n_trials + 1):
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=True, show_predictions=True)
env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
# Update replay buffer and train agent
for a in range(env.get_num_agents()):
......@@ -94,4 +95,5 @@ for trials in range(1, n_trials + 1):
obs = next_obs.copy()
if done['__all__']:
break
print(np.mean(mean_malfunction_interval))
print('Episode Nr. {}\t Score = {}'.format(trials, score))
......@@ -25,8 +25,8 @@ class EnvAgentStatic(object):
# if broken>0, the agent's actions are ignored for 'broken' steps
# number of time the agent had to stop, since the last time it broke down
broken_data = attrib(
default=Factory(lambda: dict({'broken': 0, 'number_of_halts': 0})))
malfunction_data = attrib(
default=Factory(lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0})))
@classmethod
def from_lists(cls, positions, directions, targets, speeds=None):
......@@ -42,8 +42,9 @@ class EnvAgentStatic(object):
# some as broken?
broken_datas = []
for i in range(len(positions)):
broken_datas.append({'broken': 0,
'number_of_halts': 0})
broken_datas.append({'malfunction': 0,
'malfunction_rate': 0,
'next_malfunction': 0})
return list(starmap(EnvAgentStatic, zip(positions,
directions,
......@@ -64,7 +65,7 @@ class EnvAgentStatic(object):
if type(lTarget) is np.ndarray:
lTarget = lTarget.tolist()
return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.broken_data]
return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.malfunction_data]
@attrs
......@@ -82,7 +83,7 @@ class EnvAgent(EnvAgentStatic):
def to_list(self):
return [
self.position, self.direction, self.target, self.handle,
self.old_direction, self.old_position, self.moving, self.speed_data, self.broken_data]
self.old_direction, self.old_position, self.moving, self.speed_data, self.malfunction_data]
@classmethod
def from_static(cls, oStatic):
......
......@@ -75,16 +75,13 @@ class RailEnv(Environment):
- stop_penalty = 0 # penalty for stopping a moving agent
- start_penalty = 0 # penalty for starting a stopped agent
Stochastic breaking of trains:
Trains in RailEnv can break down if they are halted too often (either by their own choice or because an invalid
Stochastic malfunctioning of trains:
Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid
action or cell is selected).
Every time an agent stops, an agent has a certain probability of breaking. The probability is the product of 2
distributions: the first distribution selects the average number of trains that will break during an episode
(e.g., max(1, 10% of the trains) ). The second distribution is a Poisson distribution with mean set to the average
number of stops at which a train breaks.
If a random number in [0,1] is lower than the product of the 2 distributions, the train breaks.
A broken train samples a random number of steps it will stay broken for, during which all its actions are ignored.
Every time an agent stops, an agent has a certain probability of malfunctioning. Malfunctions of trains follow a
Poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes to keep
complexity manageable.
TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init().
For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility.
......@@ -160,20 +157,20 @@ class RailEnv(Environment):
self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets?
# Stochastic train malfunctioning parameters
self.proportion_malfunctioning_trains = 0.1 # percentage of malfunctioning trains
self.mean_malfunction_rate = 5 # Average malfunction in number of stops
# Uniform distribution parameters for malfunction duration
self.min_number_of_steps_broken = 4
self.max_number_of_steps_broken = 10
# Reset environment
self.reset()
self.num_resets = 0 # yes, set it to zero again!
self.valid_positions = None
# Stochastic train breaking parameters
self.min_average_broken_trains = 1
self.average_proportion_of_broken_trains = 0.1 # ~10% of the trains can be expected to break down in an episode
self.mean_number_halts_to_break = 3
# Uniform distribution
self.min_number_of_steps_broken = 4
self.max_number_of_steps_broken = 8
# no more agent_handles
def get_agent_handles(self):
return range(self.get_num_agents())
......@@ -218,9 +215,12 @@ class RailEnv(Environment):
for i_agent in range(self.get_num_agents()):
agent = self.agents[i_agent]
# A proportion of agents in the environment will receive a positive malfunction rate
if np.random.random() >= self.proportion_malfunctioning_trains:
agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate
agent.speed_data['position_fraction'] = 0.0
agent.broken_data['broken'] = 0
agent.broken_data['number_of_halts'] = 0
agent.malfunction_data['malfunction'] = 0
self.num_resets += 1
self._elapsed_steps = 0
......@@ -236,24 +236,26 @@ class RailEnv(Environment):
return self._get_observations()
def _agent_stopped(self, i_agent):
self.agents[i_agent].broken_data['number_of_halts'] += 1
# Make sure agent is stopped
self.agents[i_agent].moving = False
# Only agents that have a positive rate for malfunctions are considered
if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0:
def poisson_pdf(x, mean):
return np.power(mean, x) * np.exp(-mean) / np.prod(range(2, x))
# Decrease counter for next event
self.agents[i_agent].malfunction_data['next_malfunction'] -= 1
p1_prob_train_i_breaks = max(self.min_average_broken_trains / len(self.agents),
self.average_proportion_of_broken_trains)
p2_prob_train_breaks_at_halt_j = poisson_pdf(self.agents[i_agent].broken_data['number_of_halts'],
self.mean_number_halts_to_break)
# If counter has come to zero, set next malfunction time and duration of current malfunction
s1 = np.random.random()
s2 = np.random.random()
if self.agents[i_agent].malfunction_data['next_malfunction'] <= 0:
# Next malfunction in number of stops
self.agents[i_agent].malfunction_data['next_malfunction'] = int(np.random.exponential(
scale=self.agents[i_agent].malfunction_data['malfunction_rate']))
if s1 * s2 <= p1_prob_train_i_breaks * p2_prob_train_breaks_at_halt_j:
# +1 because the counter is decreased at the beginning of step()
num_broken_steps = np.random.randint(self.min_number_of_steps_broken, self.max_number_of_steps_broken+1) + 1
self.agents[i_agent].broken_data['broken'] = num_broken_steps
self.agents[i_agent].broken_data['number_of_halts'] = 0
# Duration of current malfunction
num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1) + 1
self.agents[i_agent].malfunction_data['malfunction'] = num_broken_steps
def step(self, action_dict_):
self._elapsed_steps += 1
......@@ -284,8 +286,8 @@ class RailEnv(Environment):
agent.old_direction = agent.direction
agent.old_position = agent.position
if agent.broken_data['broken'] > 0:
agent.broken_data['broken'] -= 1
if agent.malfunction_data['malfunction'] > 0:
agent.malfunction_data['malfunction'] -= 1
if self.dones[i_agent]: # this agent has already completed...
continue
......@@ -295,7 +297,7 @@ class RailEnv(Environment):
action_dict[i_agent] = RailEnvActions.DO_NOTHING
# The train is broken
if agent.broken_data['broken'] > 0:
if agent.malfunction_data['malfunction'] > 0:
action_dict[i_agent] = RailEnvActions.DO_NOTHING
if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment