diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py index 87451ecef1f7a07c475c02a3be5915a5cb408ce9..4cdf63b0b32d431efabc05ed4d593aa746a44b40 100644 --- a/examples/introduction_flatland_2_1.py +++ b/examples/introduction_flatland_2_1.py @@ -1,3 +1,5 @@ +import numpy as np + # In Flatland you can use custom observation builders and predicitors # Observation builders generate the observation needed by the controller # Preditctors can be used to do short time prediction which can help in avoiding conflicts in the network @@ -58,9 +60,9 @@ schedule_generator = sparse_schedule_generator(speed_ration_map) # We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions # during an episode. -stochastic_data = {'malfunction_rate': 5, # Rate of malfunction occurence of single agent - 'min_duration': 3, # Minimal duration of malfunction - 'max_duration': 20 # Max duration of malfunction +stochastic_data = {'malfunction_rate': 100, # Rate of malfunction occurence of single agent + 'min_duration': 15, # Minimal duration of malfunction + 'max_duration': 50 # Max duration of malfunction } # Custom observation builder without predictor @@ -107,7 +109,7 @@ class RandomAgent: :param state: input is the observation of the agent :return: returns an action """ - return 2 # np.random.choice(np.arange(self.action_size)) + return np.random.choice([1, 2, 3, 4]) # [Left, Forward, Right, Stop] def step(self, memories): """ @@ -251,8 +253,8 @@ for step in range(500): next_obs, all_rewards, done, _ = env.step(action_dict) - env_renderer.render_env(show=False, show_observations=False, show_predictions=False) - env_renderer.gl.save_image('./misc/Fames2/flatland_frame_{:04d}.png'.format(step)) + env_renderer.render_env(show=True, show_observations=False, show_predictions=False) + # env_renderer.gl.save_image('./misc/Fames2/flatland_frame_{:04d}.png'.format(step)) frame_step += 1 # Update replay buffer and train agent for a in range(env.get_num_agents()): @@ -262,5 +264,4 @@ for step in range(500): observations = next_obs.copy() if done['__all__']: break - print('Episode: Steps {}\t Score = {}'.format(step, score)) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 2c322c7f0ce33a5cb2dc0f197d8d8a85e0e04407..f284f3ac38b480d8feb6c0b4944cf8831d2a70d2 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -431,9 +431,20 @@ class RailEnv(Environment): breaking_agent_idx = self.np_random.choice(self.active_agents) breaking_agent = self.agents[breaking_agent_idx] - # Only break agents that are not broken yet - # TODO: Do we want to guarantee that we have the desired rate or are we happy with lower rates? - if breaking_agent.malfunction_data['malfunction'] < 1: + # We assume that less then half of the active agents should be broken at MOST. + # Therefore we only try that many times before ignoring the malfunction + + tries = 0 + max_tries = 0.5 * len(self.active_agents) + + # Look for a functioning active agent + while breaking_agent.malfunction_data['malfunction'] > 0 and tries < max_tries: + breaking_agent_idx = self.np_random.choice(self.active_agents) + breaking_agent = self.agents[breaking_agent_idx] + tries += 1 + + # If we did not manage to find a functioning agent among the active ones skip this malfunction + if tries < max_tries: # Because we update agents in the same step as we break them we add one to the duration of the # malfunction num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, @@ -443,6 +454,10 @@ class RailEnv(Environment): breaking_agent.malfunction_data['fixed'] = False breaking_agent.malfunction_data['nr_malfunctions'] += 1 + return + + return + def step(self, action_dict_: Dict[int, RailEnvActions]): """ Updates rewards for the agents at a step. diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index da710ad375019adbcf2a0bf83f49ded6b457b93c..e18685f6c68f757f50075ebf206423b562514dd6 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -140,17 +140,17 @@ def test_malfunction_process_statistically(): env.agents[0].target = (0, 0) # Next line only for test generation - # agent_malfunction_list = [[] for i in range(10)] - agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1], - [0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5], - [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + #agent_malfunction_list = [[] for i in range(10)] + agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1], + [0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0], + [0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 5, 4], + [0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5], + [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3], + [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -158,10 +158,10 @@ def test_malfunction_process_statistically(): # We randomly select an action action_dict[agent_idx] = RailEnvActions(np.random.randint(4)) # For generating tests only: - # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction']) + #agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction']) assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step] env.step(action_dict) - # print(agent_malfunction_list) + #print(agent_malfunction_list) def test_malfunction_before_entry():