diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index b9ace9f537aea14c233d55aeaf814e9620a8fbd0..5ece03e9c56d672b76a453e0036f6b89c3a6ee77 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -37,7 +37,7 @@ env = RailEnv(width=100, seed=14, # Random seed grid_mode=False, max_rails_between_cities=2, - max_rails_in_city=6, + max_rails_in_city=8, ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=100, diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py index 2bc1a5117794959cca82d2edad821cb629397f78..c6e73b0bdbe752b8d5df9c4a0697bb621e5276ec 100644 --- a/flatland/envs/distance_map.py +++ b/flatland/envs/distance_map.py @@ -55,6 +55,7 @@ class DistanceMap: self.env_width = rail.width def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap): + print("computing distance map") self.agents_previous_computation = self.agents self.distance_map = np.inf * np.ones(shape=(len(agents), self.env_height, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 464d34bdae435208e336b31ebd2f997ce95cdf35..df0b88485b8e92a11ee780ba73634d7a704b6864 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -9,7 +9,6 @@ from typing import List, NamedTuple, Optional, Dict import msgpack import msgpack_numpy as m import numpy as np - from gym.utils import seeding from flatland.core.env import Environment @@ -187,7 +186,7 @@ class RailEnv(Environment): self.distance_map = DistanceMap(self.agents, self.height, self.width) self.action_space = [1] - + # Stochastic train malfunctioning parameters if stochastic_data is not None: prop_malfunction = stochastic_data['prop_malfunction'] @@ -466,7 +465,7 @@ class RailEnv(Environment): return # Is the agent at the beginning of the cell? Then, it can take an action. - # As long as the agent is malfunctioning or stopped at the beginning of the cell, + # As long as the agent is malfunctioning or stopped at the beginning of the cell, # different actions may be taken! if agent.speed_data['position_fraction'] == 0.0: # No action has been supplied for this agent -> set DO_NOTHING as default diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index e87bd93d546ab2b1436f49b15a803b5dac16d0b1..c72fc5190c106b0d28a8b2df84b7c3009d7404c9 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -70,14 +70,6 @@ def test_malfunction_process(): 'malfunction_rate': 1000, 'min_duration': 3, 'max_duration': 3} - # random.seed(0) - # np.random.seed(0) - - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents - 'malfunction_rate': 70, # Rate of malfunction occurence - 'min_duration': 2, # Minimal duration of malfunction - 'max_duration': 5 # Max duration of malfunction - } rail, rail_map = make_simple_rail2() @@ -90,8 +82,8 @@ def test_malfunction_process(): obs_builder_object=SingleAgentNavigationObs() ) # reset to initialize agents_static - obs = env.reset(False, False, True, random_seed=0) - + obs, info = env.reset(False, False, True, random_seed=0) + print(env.agents[0].malfunction_data) # Check that a initial duration for malfunction was assigned assert env.agents[0].malfunction_data['next_malfunction'] > 0 for agent in env.agents: @@ -100,6 +92,9 @@ def test_malfunction_process(): agent_halts = 0 total_down_time = 0 agent_old_position = env.agents[0].position + + # Move target to unreachable position in order to not interfere with test + env.agents[0].target = (0, 0) for step in range(100): actions = {} @@ -157,6 +152,7 @@ def test_malfunction_process_statistically(): ) # reset to initialize agents_static env.reset(False, False, False, random_seed=0) + env.agents[0].target = (0, 0) nb_malfunction = 0 for step in range(20): @@ -166,7 +162,6 @@ def test_malfunction_process_statistically(): action_dict[agent.handle] = RailEnvActions(np.random.randint(4)) env.step(action_dict) - # check that generation of malfunctions works as expected assert env.agents[0].malfunction_data["nr_malfunctions"] == 4