diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py index 7a89add9d3628bc2b509c862c86c2cb9110ce66d..1e80021c4c3c2f06e6d5c897692b814a5d063025 100644 --- a/examples/introduction_flatland_2_1.py +++ b/examples/introduction_flatland_2_1.py @@ -1,6 +1,8 @@ # In Flatland you can use custom observation builders and predicitors # Observation builders generate the observation needed by the controller # Preditctors can be used to do short time prediction which can help in avoiding conflicts in the network +import time + from flatland.envs.observations import GlobalObsForRailEnv # First of all we import the Flatland rail environment from flatland.envs.rail_env import RailEnv @@ -26,8 +28,8 @@ from flatland.utils.rendertools import RenderTool, AgentRenderVariant # Here we use the sparse_rail_generator with the following parameters width = 100 # With of map -height = 100 # Height of ap -nr_trains = 10 # Number of trains that have an assigned task in the env +height = 100 # Height of map +nr_trains = 50 # Number of trains that have an assigned task in the env cities_in_map = 20 # Number of cities where agents can start or end seed = 14 # Random seed grid_distribution_of_cities = False # Type of city distribution, if False cities are randomly placed @@ -151,14 +153,14 @@ for agent_idx, agent in enumerate(env.agents): # If multiple agents want to enter the same cell at the same time the lower index agent will enter first. # Let's check if there are any agents with the same start location -agents_with_same_start = [] +agents_with_same_start = set() print("\n The following agents have the same initial position:") print("=====================================================") for agent_idx, agent in enumerate(env.agents): for agent_2_idx, agent2 in enumerate(env.agents): if agent_idx != agent_2_idx and agent.initial_position == agent2.initial_position: print("Agent {} as the same initial position as agent {}".format(agent_idx, agent_2_idx)) - agents_with_same_start.append(agent_idx) + agents_with_same_start.add(agent_idx) # Lets try to enter with all of these agents at the same time action_dict = dict() @@ -246,8 +248,11 @@ for step in range(500): # Environment step which returns the observations for all agents, their corresponding # reward and whether their are done + start_time = time.time() next_obs, all_rewards, done, _ = env.step(action_dict) - env_renderer.render_env(show=True, show_observations=False, show_predictions=False) + end_time = time.time() + print(end_time - start_time) + # env_renderer.render_env(show=True, show_observations=False, show_predictions=False) frame_step += 1 # Update replay buffer and train agent for a in range(env.get_num_agents()): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 87285812b877416e64aea4370b96ab649df7d6a2..d576641632bfe1d0ae20baf485afdb2aca1d640b 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -308,7 +308,9 @@ class RailEnv(Environment): # A proportion of agent in the environment will receive a positive malfunction rate if self.np_random.rand() < self.proportion_malfunctioning_trains: agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate - + next_breakdown = int( + self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate'])) + agent.malfunction_data['next_malfunction'] = next_breakdown agent.malfunction_data['malfunction'] = 0 initial_malfunction = self._agent_malfunction(i_agent) @@ -346,7 +348,7 @@ class RailEnv(Environment): """ agent = self.agents[i_agent] - # Decrease counter for next event only if agent is currently not broken + # Decrease counter for next event only if agent is currently not broken and agent has a malfunction rate if agent.malfunction_data['malfunction_rate'] >= 1 and agent.malfunction_data['next_malfunction'] > 0 and \ agent.malfunction_data['malfunction'] < 1: agent.malfunction_data['next_malfunction'] -= 1 diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 8008e6e2ea4f5aabe98da7a5bff833714361c66a..e9b5a15dade4fd30d6718886eed02a483172f159 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -126,7 +126,7 @@ def test_malfunction_process(): env.agents[0].malfunction_data['nr_malfunctions']) # Check that 20 stops where performed - assert agent_halts == 20 + assert agent_halts == 21 # Check that malfunctioning data was standing around assert total_down_time > 0 @@ -155,16 +155,16 @@ def test_malfunction_process_statistically(): env.agents[0].target = (0, 0) nb_malfunction = 0 - agent_malfunction_list = [[6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0], - [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1], - [6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3], - [6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0], - [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6], - [6, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6], - [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2], - [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5], - [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0], - [6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3]] + agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4], + [0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4], + [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0], + [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3], + [0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5], + [0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1], + [6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0]] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -175,6 +175,7 @@ def test_malfunction_process_statistically(): # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction']) assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step] env.step(action_dict) + # print(agent_malfunction_list) def test_malfunction_before_entry(): @@ -230,14 +231,13 @@ def test_malfunction_before_entry(): assert env.agents[8].malfunction_data['malfunction'] == 2 assert env.agents[9].malfunction_data['malfunction'] == 2 - #for a in range(env.get_num_agents()): + # for a in range(env.get_num_agents()): # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a, # env.agents[a].malfunction_data[ # 'malfunction'])) def test_initial_malfunction(): - stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 100, # Rate of malfunction occurence 'min_duration': 2, # Minimal duration of malfunction @@ -410,7 +410,6 @@ def test_initial_malfunction_do_nothing(): rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),