Skip to content
Snippets Groups Projects
Commit 2ec6d6f4 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

added test for malfunction and updated malfunction initialization

parent ca327a1a
No related branches found
No related tags found
No related merge requests found
......@@ -75,13 +75,11 @@ for trials in range(1, n_trials + 1):
score = 0
# Run episode
mean_malfunction_interval = []
for step in range(100):
# Chose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = agent.act(obs[a])
action_dict.update({a: action})
# Environment step which returns the observations for all agents, their corresponding
# reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict)
......@@ -95,5 +93,4 @@ for trials in range(1, n_trials + 1):
obs = next_obs.copy()
if done['__all__']:
break
print(np.mean(mean_malfunction_interval))
print('Episode Nr. {}\t Score = {}'.format(trials, score))
......@@ -26,7 +26,8 @@ class EnvAgentStatic(object):
# if broken>0, the agent's actions are ignored for 'broken' steps
# number of time the agent had to stop, since the last time it broke down
malfunction_data = attrib(
default=Factory(lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0})))
default=Factory(
lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0})))
@classmethod
def from_lists(cls, positions, directions, targets, speeds=None):
......@@ -40,18 +41,19 @@ class EnvAgentStatic(object):
# TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set
# some as broken?
broken_datas = []
malfunction_datas = []
for i in range(len(positions)):
broken_datas.append({'malfunction': 0,
malfunction_datas.append({'malfunction': 0,
'malfunction_rate': 0,
'next_malfunction': 0})
'next_malfunction': 0,
'nr_malfunctions': 0})
return list(starmap(EnvAgentStatic, zip(positions,
directions,
targets,
[False] * len(positions),
speed_datas,
broken_datas)))
malfunction_datas)))
def to_list(self):
......
......@@ -94,7 +94,8 @@ class RailEnv(Environment):
rail_generator=random_rail_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2),
max_episode_steps=None
max_episode_steps=None,
stochastic_data=None
):
"""
Environment init.
......@@ -158,12 +159,26 @@ class RailEnv(Environment):
self.observation_space = self.obs_builder.observation_space # updated on resets?
# Stochastic train malfunctioning parameters
self.proportion_malfunctioning_trains = 0.1 # percentage of malfunctioning trains
self.mean_malfunction_rate = 5 # Average malfunction in number of stops
if stochastic_data is not None:
prop_malfunction = stochastic_data['prop_malfunction']
mean_malfunction_rate = stochastic_data['malfunction_rate']
malfunction_min_duration = stochastic_data['min_duration']
malfunction_max_duration = stochastic_data['max_duration']
else:
prop_malfunction = 0.
mean_malfunction_rate = 0.
malfunction_min_duration = 0.
malfunction_max_duration = 0.
# percentage of malfunctioning trains
self.proportion_malfunctioning_trains = prop_malfunction
# Mean malfunction in number of stops
self.mean_malfunction_rate = mean_malfunction_rate
# Uniform distribution parameters for malfunction duration
self.min_number_of_steps_broken = 4
self.max_number_of_steps_broken = 10
self.min_number_of_steps_broken = malfunction_min_duration
self.max_number_of_steps_broken = malfunction_max_duration
# Rest environment
self.reset()
......@@ -217,8 +232,9 @@ class RailEnv(Environment):
agent = self.agents[i_agent]
# A proportion of agent in the environment will receive a positive malfunction rate
if np.random.random() >= self.proportion_malfunctioning_trains:
if np.random.random() < self.proportion_malfunctioning_trains:
agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate
agent.speed_data['position_fraction'] = 0.0
agent.malfunction_data['malfunction'] = 0
......@@ -236,21 +252,23 @@ class RailEnv(Environment):
return self._get_observations()
def _agent_stopped(self, i_agent):
# Make sure agent is stopped
self.agents[i_agent].moving = False
# Decrease counter for next event
self.agents[i_agent].malfunction_data['next_malfunction'] -= 1
# Only agents that have a positive rate for malfunctions are considered
if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0:
# Decrease counter for next event
self.agents[i_agent].malfunction_data['next_malfunction'] -= 1
# If counter has come to zero, set next malfunction time and duration of current malfunction
if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0 >= self.agents[i_agent].malfunction_data[
'malfunction']:
# If counter has come to zero --> Agent has malfunction
# set next malfunction time and duration of current malfunction
if self.agents[i_agent].malfunction_data['next_malfunction'] <= 0:
# Increase number of malfunctions
self.agents[i_agent].malfunction_data['nr_malfunctions'] += 1
# Next malfunction in number of stops
self.agents[i_agent].malfunction_data['next_malfunction'] = int(np.random.exponential(
scale=self.agents[i_agent].malfunction_data['malfunction_rate']))
next_breakdown = int(
np.random.exponential(scale=self.agents[i_agent].malfunction_data['malfunction_rate']))
self.agents[i_agent].malfunction_data['next_malfunction'] = next_breakdown
# Duration of current malfunction
num_broken_steps = np.random.randint(self.min_number_of_steps_broken,
......@@ -286,9 +304,6 @@ class RailEnv(Environment):
agent.old_direction = agent.direction
agent.old_position = agent.position
if agent.malfunction_data['malfunction'] > 0:
agent.malfunction_data['malfunction'] -= 1
if self.dones[i_agent]: # this agent has already completed...
continue
......@@ -298,8 +313,16 @@ class RailEnv(Environment):
# The train is broken
if agent.malfunction_data['malfunction'] > 0:
agent.malfunction_data['malfunction'] -= 1
# Broken agents are stopped
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
self.agents[i_agent].moving = False
action_dict[i_agent] = RailEnvActions.DO_NOTHING
# Nothing left to do with broken agent
continue
if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
print('ERROR: illegal action=', action_dict[i_agent],
'for agent with index=', i_agent,
......
import numpy as np
from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
class SingleAgentNavigationObs(TreeObsForRailEnv):
"""
We derive our bbservation builder from TreeObsForRailEnv, to exploit the existing implementation to compute
the minimum distances from each grid node to each agent's target.
We then build a representation vector with 3 binary components, indicating which of the 3 available directions
for each agent (Left, Forward, Right) lead to the shortest path to its target.
E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
will be [1, 0, 0].
"""
def __init__(self):
super().__init__(max_depth=0)
self.observation_space = [3]
def reset(self):
# Recompute the distance map, if the environment has changed.
super().reset()
def get(self, handle):
agent = self.env.agents[handle]
possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
num_transitions = np.count_nonzero(possible_transitions)
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right], relative to the current orientation
# If only one transition is possible, the forward branch is aligned with it.
if num_transitions == 1:
observation = [0, 1, 0]
else:
min_distances = []
for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
if possible_transitions[direction]:
new_position = self._new_position(agent.position, direction)
min_distances.append(self.distance_map[handle, new_position[0], new_position[1], direction])
else:
min_distances.append(np.inf)
observation = [0, 0, 0]
observation[np.argmin(min_distances)] = 1
return observation
def test_malfunction_process():
stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 5,
'min_duration': 3,
'max_duration': 10}
np.random.seed(5)
env = RailEnv(width=14,
height=14,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
seed=0),
number_of_agents=2,
obs_builder_object=SingleAgentNavigationObs(),
stochastic_data=stochastic_data)
obs = env.reset()
agent_halts = 0
for step in range(100):
actions = {}
for i in range(len(obs)):
actions[i] = np.argmax(obs[i]) + 1
if step % 5 == 0:
actions[0] = 4
agent_halts += 1
obs, all_rewards, done, _ = env.step(actions)
if done["__all__"]:
break
# Check that the agents breaks twice
assert env.agents[0].malfunction_data['nr_malfunctions'] == 2
# Check that 7 stops where performed
assert agent_halts == 7
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment