Skip to content
Snippets Groups Projects
Commit f1647114 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

updating tests to new malfunction generation

parent c58b1e34
No related branches found
No related tags found
No related merge requests found
......@@ -63,6 +63,7 @@ class EnvAgentStatic(object):
for i in range(len(schedule.agent_positions)):
malfunction_datas.append({'malfunction': 0,
'nr_malfunctions': 0,
'moving_before_malfunction': False,
'fixed': True})
return list(starmap(EnvAgentStatic, zip(schedule.agent_positions,
......
"""
Definition of the RailEnv environment.
"""
import random
# TODO: _ this is a global method --> utils or remove later
from enum import IntEnum
from typing import List, NamedTuple, Optional, Dict
......@@ -8,7 +9,6 @@ from typing import List, NamedTuple, Optional, Dict
import msgpack
import msgpack_numpy as m
import numpy as np
import random
from gym.utils import seeding
from flatland.core.env import Environment
......@@ -211,7 +211,6 @@ class RailEnv(Environment):
# Uniform distribution parameters for malfunction duration
self.min_number_of_steps_broken = malfunction_min_duration
self.max_number_of_steps_broken = malfunction_max_duration
# Reset environment
self.valid_positions = None
......@@ -336,8 +335,8 @@ class RailEnv(Environment):
if agents_hints and 'city_orientations' in agents_hints:
ratio_nr_agents_to_nr_cities = self.get_num_agents() / len(agents_hints['city_orientations'])
self._max_episode_steps = self.compute_max_episode_steps(
width=self.width, height=self.height,
ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities)
width=self.width, height=self.height,
ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities)
else:
self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
......@@ -401,9 +400,6 @@ class RailEnv(Environment):
self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction']
return False
def _malfunction(self, rate) -> bool:
"""
Malfunction generator that breaks agents at a given rate. It does randomly chose agent to break during the run
......@@ -411,16 +407,13 @@ class RailEnv(Environment):
"""
if np.random.random() < self._malfunction_prob(rate):
breaking_agent = random.choice(self.agents)
while breaking_agent.status == RailAgentStatus.DONE_REMOVED:
breaking_agent = random.choice(self.agents)
num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1)
breaking_agent.malfunction_data['malfunction'] = num_broken_steps
breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving
breaking_agent.malfunction_data['fixed'] = False
if breaking_agent.malfunction_data['malfunction'] < 1:
num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1)
breaking_agent.malfunction_data['malfunction'] = num_broken_steps
breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving
breaking_agent.malfunction_data['fixed'] = False
breaking_agent.malfunction_data['nr_malfunctions'] += 1
def step(self, action_dict_: Dict[int, RailEnvActions]):
"""
......@@ -437,10 +430,10 @@ class RailEnv(Environment):
if self.dones["__all__"]:
self.rewards_dict = {}
info_dict = {
"action_required" : {},
"malfunction" : {},
"speed" : {},
"status" : {},
"action_required": {},
"malfunction": {},
"speed": {},
"status": {},
}
for i_agent, agent in enumerate(self.agents):
self.rewards_dict[i_agent] = self.global_reward
......@@ -454,12 +447,12 @@ class RailEnv(Environment):
# Reset the step rewards
self.rewards_dict = dict()
info_dict = {
"action_required" : {},
"malfunction" : {},
"speed" : {},
"status" : {},
"action_required": {},
"malfunction": {},
"speed": {},
"status": {},
}
have_all_agents_ended = True # boolean flag to check if all agents are done
have_all_agents_ended = True # boolean flag to check if all agents are done
# Evoke the malfunction generator
self._malfunction(self.mean_malfunction_rate)
......@@ -476,8 +469,8 @@ class RailEnv(Environment):
# Build info dict
info_dict["action_required"][i_agent] = \
(agent.status == RailAgentStatus.READY_TO_DEPART or (
agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
rtol=1e-03)))
agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
rtol=1e-03)))
info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
info_dict["speed"][i_agent] = agent.speed_data['speed']
info_dict["status"][i_agent] = agent.status
......
......@@ -79,7 +79,7 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, se
speeds = [1.0] * len(agents_position)
return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None)
agent_targets=agents_target, agent_speeds=speeds)
return generator
......@@ -165,7 +165,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see
speeds = [1.0] * len(agents_position)
return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None)
agent_targets=agents_target, agent_speeds=speeds)
return generator
......@@ -199,12 +199,12 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] =
valid_positions.append((r, c))
if len(valid_positions) == 0:
return Schedule(agent_positions=[], agent_directions=[],
agent_targets=[], agent_speeds=[], agent_malfunction_rates=None)
agent_targets=[], agent_speeds=[])
if len(valid_positions) < num_agents:
warnings.warn("schedule_generators: len(valid_positions) < num_agents")
return Schedule(agent_positions=[], agent_directions=[],
agent_targets=[], agent_speeds=[], agent_malfunction_rates=None)
agent_targets=[], agent_speeds=[])
agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)]
agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)]
......@@ -263,7 +263,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] =
agents_speed = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed)
return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None)
agent_targets=agents_target, agent_speeds=agents_speed)
return generator
......@@ -304,12 +304,9 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
agents_target = [a.target for a in agents_static]
if len(data['agents_static'][0]) > 5:
agents_speed = [a.speed_data['speed'] for a in agents_static]
agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents_static]
else:
agents_speed = None
agents_malfunction = None
return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=agents_speed,
agent_malfunction_rates=agents_malfunction)
agent_targets=agents_target, agent_speeds=agents_speed)
return generator
......@@ -6,5 +6,4 @@ from flatland.core.grid.grid_utils import IntVector2DArray
Schedule = NamedTuple('Schedule', [('agent_positions', IntVector2DArray),
('agent_directions', List[Grid4TransitionsEnum]),
('agent_targets', IntVector2DArray),
('agent_speeds', List[float]),
('agent_malfunction_rates', List[int])])
('agent_speeds', List[float])])
......@@ -66,8 +66,7 @@ class SingleAgentNavigationObs(ObservationBuilder):
def test_malfunction_process():
# Set fixed malfunction duration for this test
stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 1000,
stochastic_data = {'malfunction_rate': 1,
'min_duration': 3,
'max_duration': 3}
......@@ -84,11 +83,6 @@ def test_malfunction_process():
# reset to initialize agents_static
obs, info = env.reset(False, False, True, random_seed=10)
# Check that a initial duration for malfunction was assigned
assert env.agents[0].malfunction_data['next_malfunction'] > 0
for agent in env.agents:
agent.status = RailAgentStatus.ACTIVE
agent_halts = 0
total_down_time = 0
agent_old_position = env.agents[0].position
......@@ -101,12 +95,6 @@ def test_malfunction_process():
for i in range(len(obs)):
actions[i] = np.argmax(obs[i]) + 1
if step % 5 == 0:
# Stop the agent and set it to be malfunctioning
env.agents[0].malfunction_data['malfunction'] = -1
env.agents[0].malfunction_data['next_malfunction'] = 0
agent_halts += 1
obs, all_rewards, done, _ = env.step(actions)
if env.agents[0].malfunction_data['malfunction'] > 0:
......@@ -122,12 +110,9 @@ def test_malfunction_process():
total_down_time += env.agents[0].malfunction_data['malfunction']
# Check that the appropriate number of malfunctions is achieved
assert env.agents[0].malfunction_data['nr_malfunctions'] == 20, "Actual {}".format(
assert env.agents[0].malfunction_data['nr_malfunctions'] == 30, "Actual {}".format(
env.agents[0].malfunction_data['nr_malfunctions'])
# Check that 20 stops where performed
assert agent_halts == 20
# Check that malfunctioning data was standing around
assert total_down_time > 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment