diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 4f016f2b3d4ee3e1918fbf47c9492f2c84228999..2281282977d8c9d972f13526efa9d96abaf84a52 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -238,7 +238,7 @@ class RailEnv(Environment): agent.speed_data['position_fraction'] = 0.0 agent.malfunction_data['malfunction'] = 0 - self._agent_stopped(i_agent) + self._agent_malfunction(agent) self.num_resets += 1 self._elapsed_steps = 0 @@ -253,29 +253,29 @@ class RailEnv(Environment): # Return the new observation vectors for each agent return self._get_observations() - def _agent_stopped(self, i_agent): + def _agent_malfunction(self, agent): # Decrease counter for next event - self.agents[i_agent].malfunction_data['next_malfunction'] -= 1 + agent.malfunction_data['next_malfunction'] -= 1 - # Only agents that have a positive rate for malfunctions are considered - if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0 >= self.agents[i_agent].malfunction_data[ + # Only agents that have a positive rate for malfunctions and are not currently broken are considered + if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data[ 'malfunction']: # If counter has come to zero --> Agent has malfunction # set next malfunction time and duration of current malfunction - if self.agents[i_agent].malfunction_data['next_malfunction'] <= 0: + if agent.malfunction_data['next_malfunction'] <= 0: # Increase number of malfunctions - self.agents[i_agent].malfunction_data['nr_malfunctions'] += 1 + agent.malfunction_data['nr_malfunctions'] += 1 # Next malfunction in number of stops next_breakdown = int( - np.random.exponential(scale=self.agents[i_agent].malfunction_data['malfunction_rate'])) - self.agents[i_agent].malfunction_data['next_malfunction'] = next_breakdown + np.random.exponential(scale=agent.malfunction_data['malfunction_rate'])) + agent.malfunction_data['next_malfunction'] = next_breakdown # Duration of current malfunction num_broken_steps = np.random.randint(self.min_number_of_steps_broken, self.max_number_of_steps_broken + 1) + 1 - self.agents[i_agent].malfunction_data['malfunction'] = num_broken_steps + agent.malfunction_data['malfunction'] = num_broken_steps def step(self, action_dict_): self._elapsed_steps += 1 @@ -306,6 +306,9 @@ class RailEnv(Environment): agent.old_direction = agent.direction agent.old_position = agent.position + # Check if agent breaks at this step + self._agent_malfunction(agent) + if self.dones[i_agent]: # this agent has already completed... continue @@ -341,7 +344,6 @@ class RailEnv(Environment): # Only allow halting an agent on entering new cells. agent.moving = False self.rewards_dict[i_agent] += stop_penalty - self._agent_stopped(i_agent) if not agent.moving and not (action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): # Allow agent to start with any forward or direction action @@ -385,8 +387,6 @@ class RailEnv(Environment): self.rewards_dict[i_agent] += invalid_action_penalty self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += stop_penalty - if agent.moving: - self._agent_stopped(i_agent) agent.moving = False continue else: @@ -394,8 +394,6 @@ class RailEnv(Environment): self.rewards_dict[i_agent] += invalid_action_penalty self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] self.rewards_dict[i_agent] += stop_penalty - if agent.moving: - self._agent_stopped(i_agent) agent.moving = False continue @@ -416,14 +414,11 @@ class RailEnv(Environment): agent.speed_data['position_fraction'] = 0.0 else: # If the agent cannot move due to any reason, we set its state to not moving - if agent.moving: - self._agent_stopped(i_agent) agent.moving = False if np.equal(agent.position, agent.target).all(): self.dones[i_agent] = True agent.moving = False - # Do not call self._agent_stopped, as the agent has terminated its task else: self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 2e82d212687b7ef21832ea778ec2dd1552db6b36..67dcd25c0769e542fd9a03502c2a8c1b29333b2b 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -53,13 +53,13 @@ class SingleAgentNavigationObs(TreeObsForRailEnv): def test_malfunction_process(): # Set fixed malfunction duration for this test stochastic_data = {'prop_malfunction': 1., - 'malfunction_rate': 5, + 'malfunction_rate': 1000, 'min_duration': 3, 'max_duration': 3} np.random.seed(5) - env = RailEnv(width=14, - height=14, + env = RailEnv(width=20, + height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=0), number_of_agents=2, @@ -82,17 +82,17 @@ def test_malfunction_process(): if step % 5 == 0: # Stop the agent and set it to be malfunctioning - actions[0] = 4 + env.agents[0].malfunction_data['malfunction'] = -1 env.agents[0].malfunction_data['next_malfunction'] = 0 agent_halts += 1 + obs, all_rewards, done, _ = env.step(actions) + if env.agents[0].malfunction_data['malfunction'] > 0: agent_malfunctioning = True else: agent_malfunctioning = False - obs, all_rewards, done, _ = env.step(actions) - if agent_malfunctioning: # Check that agent is not moving while malfunctioning assert agent_old_position == env.agents[0].position @@ -101,7 +101,7 @@ def test_malfunction_process(): total_down_time += env.agents[0].malfunction_data['malfunction'] # Check that the appropriate number of malfunctions is achieved - assert env.agents[0].malfunction_data['nr_malfunctions'] == 5 + assert env.agents[0].malfunction_data['nr_malfunctions'] == 21 # Check that 20 stops where performed assert agent_halts == 20