diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 6a0e595bbd2e47202a9fc2e78c64f438f8190684..975bfe708b4e05eb17714c28d5117c6417526c84 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -64,7 +64,8 @@ class EnvAgentStatic(object): malfunction_datas.append({'malfunction': 0, 'malfunction_rate': schedule.agent_malfunction_rates[i] if schedule.agent_malfunction_rates is not None else 0., 'next_malfunction': 0, - 'nr_malfunctions': 0}) + 'nr_malfunctions': 0, + 'fixed':False}) return list(starmap(EnvAgentStatic, zip(schedule.agent_positions, schedule.agent_directions, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 3c4a0487da4a543b51841a65411216f5f92acb33..2271a62c20dd73e8c84b1362cbb44e04e08e645f 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -400,13 +400,16 @@ class RailEnv(Environment): return True # Restart fixed agents - - - if agent.malfunction_data['malfunction'] < 1 and agent.malfunction_data['next_malfunction'] > 0: + if agent.malfunction_data['malfunction'] < 1 and agent.malfunction_data['next_malfunction'] > 0 and not agent.malfunction_data['fixed']: agent.malfunction_data['next_malfunction'] -= 1 + agent.malfunction_data['fixed'] = True if 'moving_before_malfunction' in agent.malfunction_data: self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction'] return False + # Agent has been running smoothly + elif agent.malfunction_data['malfunction'] < 1 and agent.malfunction_data['next_malfunction'] > 0: + agent.malfunction_data['next_malfunction'] -= 1 + return False # Break agents that have next_malfunction if agent.malfunction_data['malfunction'] < 1 and agent.malfunction_data['next_malfunction'] < 1: @@ -429,7 +432,7 @@ class RailEnv(Environment): if agent.malfunction_data['next_malfunction'] > 0 and agent.malfunction_data['malfunction'] < 1: agent.malfunction_data['next_malfunction'] -= 1 - return False + def step(self, action_dict_: Dict[int, RailEnvActions]):