diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 862774319ce58c2625b227b89b77940c12016e89..1805e8c6b01a9fb88db082dc0e7de7909800c0b6 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -300,7 +300,7 @@ class RailEnv(Environment): agent = self.agents[i_agent] # Decrease counter for next event - if agent.malfunction_data['malfunction_rate'] > 0: + if agent.malfunction_data['malfunction_rate'] > 0 and agent.malfunction_data['next_malfunction'] > 0: agent.malfunction_data['next_malfunction'] -= 1 # Only agents that have a positive rate for malfunctions and are not currently broken are considered @@ -468,7 +468,6 @@ class RailEnv(Environment): _action_stored = True if not _action_stored: - # If the agent cannot move due to an invalid transition, we set its state to not moving self.rewards_dict[i_agent] += self.invalid_action_penalty self.rewards_dict[i_agent] += self.stop_penalty diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 1b3c6adead4d0d82fd676efcc051fc66b4486ef8..99c83e3b6d87470eb237afd0752231b7a378c758 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -419,3 +419,44 @@ def test_initial_malfunction_do_nothing(): ) run_replay_config(env, [replay_config]) + + +def test_initial_nextmalfunction_not_below_zero(): + random.seed(0) + np.random.seed(0) + + stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents + 'malfunction_rate': 0.5, # Rate of malfunction occurence + 'min_duration': 5, # Minimal duration of malfunction + 'max_duration': 5 # Max duration of malfunction + } + + speed_ration_map = {1.: 1., # Fast passenger train + 1. / 2.: 0., # Fast freight train + 1. / 3.: 0., # Slow commuter train + 1. / 4.: 0.} # Slow freight train + + env = RailEnv(width=25, + height=30, + rail_generator=sparse_rail_generator(num_cities=5, + # Number of cities in map (where train stations are) + num_intersections=4, + # Number of intersections (no start / target) + num_trainstations=25, # Number of possible start/targets on map + min_node_dist=6, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=3, + # Number of connections to other cities/intersections + seed=215545, # Random seed + grid_mode=True, + enhance_intersection=False + ), + schedule_generator=sparse_schedule_generator(speed_ration_map), + number_of_agents=1, + stochastic_data=stochastic_data, # Malfunction data generator + ) + agent = env.agents[0] + env.step({}) + # was next_malfunction was -1 befor the bugfix https://gitlab.aicrowd.com/flatland/flatland/issues/186 + assert agent.malfunction_data['next_malfunction'] >= 0, \ + "next_malfunction should be >=0, found {}".format(agent.malfunction_data['next_malfunction'])