From a165ac00c1cbc94219574c893df5e942cba6ef2e Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Tue, 8 Oct 2019 09:04:27 -0400 Subject: [PATCH] fixed first tests in malfunction test --- examples/flatland_2_0_example.py | 2 +- flatland/envs/distance_map.py | 1 + flatland/envs/rail_env.py | 5 ++--- tests/test_flatland_malfunction.py | 17 ++++++----------- 4 files changed, 10 insertions(+), 15 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index b9ace9f5..5ece03e9 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -37,7 +37,7 @@ env = RailEnv(width=100, seed=14, # Random seed grid_mode=False, max_rails_between_cities=2, - max_rails_in_city=6, + max_rails_in_city=8, ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=100, diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py index 2bc1a511..c6e73b0b 100644 --- a/flatland/envs/distance_map.py +++ b/flatland/envs/distance_map.py @@ -55,6 +55,7 @@ class DistanceMap: self.env_width = rail.width def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap): + print("computing distance map") self.agents_previous_computation = self.agents self.distance_map = np.inf * np.ones(shape=(len(agents), self.env_height, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 464d34bd..df0b8848 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -9,7 +9,6 @@ from typing import List, NamedTuple, Optional, Dict import msgpack import msgpack_numpy as m import numpy as np - from gym.utils import seeding from flatland.core.env import Environment @@ -187,7 +186,7 @@ class RailEnv(Environment): self.distance_map = DistanceMap(self.agents, self.height, self.width) self.action_space = [1] - + # Stochastic train malfunctioning parameters if stochastic_data is not None: prop_malfunction = stochastic_data['prop_malfunction'] @@ -466,7 +465,7 @@ class RailEnv(Environment): return # Is the agent at the beginning of the cell? Then, it can take an action. - # As long as the agent is malfunctioning or stopped at the beginning of the cell, + # As long as the agent is malfunctioning or stopped at the beginning of the cell, # different actions may be taken! if agent.speed_data['position_fraction'] == 0.0: # No action has been supplied for this agent -> set DO_NOTHING as default diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index e87bd93d..c72fc519 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -70,14 +70,6 @@ def test_malfunction_process(): 'malfunction_rate': 1000, 'min_duration': 3, 'max_duration': 3} - # random.seed(0) - # np.random.seed(0) - - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents - 'malfunction_rate': 70, # Rate of malfunction occurence - 'min_duration': 2, # Minimal duration of malfunction - 'max_duration': 5 # Max duration of malfunction - } rail, rail_map = make_simple_rail2() @@ -90,8 +82,8 @@ def test_malfunction_process(): obs_builder_object=SingleAgentNavigationObs() ) # reset to initialize agents_static - obs = env.reset(False, False, True, random_seed=0) - + obs, info = env.reset(False, False, True, random_seed=0) + print(env.agents[0].malfunction_data) # Check that a initial duration for malfunction was assigned assert env.agents[0].malfunction_data['next_malfunction'] > 0 for agent in env.agents: @@ -100,6 +92,9 @@ def test_malfunction_process(): agent_halts = 0 total_down_time = 0 agent_old_position = env.agents[0].position + + # Move target to unreachable position in order to not interfere with test + env.agents[0].target = (0, 0) for step in range(100): actions = {} @@ -157,6 +152,7 @@ def test_malfunction_process_statistically(): ) # reset to initialize agents_static env.reset(False, False, False, random_seed=0) + env.agents[0].target = (0, 0) nb_malfunction = 0 for step in range(20): @@ -166,7 +162,6 @@ def test_malfunction_process_statistically(): action_dict[agent.handle] = RailEnvActions(np.random.randint(4)) env.step(action_dict) - # check that generation of malfunctions works as expected assert env.agents[0].malfunction_data["nr_malfunctions"] == 4 -- GitLab