From bf30a586585af953e6864363b1f0f604d1c5cc98 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sun, 1 Sep 2019 16:11:28 -0400 Subject: [PATCH] updated city placement in sparse_rail_generator Updated flatland 2.0 example added new feature: Malfunctioning agents will automatically resume operation when they are fixed --- examples/flatland_2_0_example.py | 8 ++++---- flatland/envs/rail_env.py | 24 ++++++++++++++++-------- flatland/envs/rail_generators.py | 9 ++------- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index a1ad9a85..da8cacef 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -13,7 +13,7 @@ np.random.seed(1) # Training on simple small tasks is the best way to get familiar with the environment # Use a the malfunction generator to break agents from time to time -stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents +stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 30, # Rate of malfunction occurence 'min_duration': 3, # Minimal duration of malfunction 'max_duration': 20 # Max duration of malfunction @@ -31,11 +31,11 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=25, # Number of cities in map (where train stations are) - num_intersections=0, # Number of intersections (no start / target) + num_intersections=10, # Number of intersections (no start / target) num_trainstations=50, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes - node_radius=2, # Proximity of stations to city center - num_neighb=3, # Number of connections to other cities/intersections + node_radius=4, # Proximity of stations to city center + num_neighb=4, # Number of connections to other cities/intersections seed=15, # Random seed grid_mode=True, enhance_intersection=False diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 8486fc7f..4926fb1f 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -266,7 +266,8 @@ class RailEnv(Environment): def _agent_malfunction(self, agent): # Decrease counter for next event - agent.malfunction_data['next_malfunction'] -= 1 + if agent.malfunction_data['malfunction_rate'] > 0: + agent.malfunction_data['next_malfunction'] -= 1 # Only agents that have a positive rate for malfunctions and are not currently broken are considered if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']: @@ -328,15 +329,22 @@ class RailEnv(Environment): # The train is broken if agent.malfunction_data['malfunction'] > 0: - agent.malfunction_data['malfunction'] -= 1 + # Last step of malfunction --> Agent starts moving again + if agent.malfunction_data['malfunction'] < 2: + agent.malfunction_data['malfunction'] -= 1 + self.agents[i_agent].moving = True + action_dict[i_agent] = RailEnvActions.DO_NOTHING - # Broken agents are stopped - self.rewards_dict[i_agent] += step_penalty # * agent.speed_data['speed'] - self.agents[i_agent].moving = False - action_dict[i_agent] = RailEnvActions.DO_NOTHING + else: + agent.malfunction_data['malfunction'] -= 1 - # Nothing left to do with broken agent - continue + # Broken agents are stopped + self.rewards_dict[i_agent] += step_penalty # * agent.speed_data['speed'] + self.agents[i_agent].moving = False + action_dict[i_agent] = RailEnvActions.DO_NOTHING + + # Nothing left to do with broken agent + continue if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions): print('ERROR: illegal action=', action_dict[i_agent], diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 39796515..c2359346 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -563,7 +563,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 node_positions = [] city_positions = [] intersection_positions = [] - # Evenly distribute cities and intersections if grid_mode: tot_num_node = num_intersections + num_cities @@ -572,10 +571,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row)) x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) - - fraction = 0 - city_fraction = num_cities / tot_num_node - step = np.gcd(num_intersections, num_cities) / tot_num_node + city_idx = np.random.choice(np.arange(tot_num_node), num_cities) for node_idx in range(num_cities + num_intersections): to_close = True @@ -608,10 +604,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 warnings.warn("Could not set nodes, please change initial parameters!!!!") break else: - fraction = (fraction + step) % 1. x_tmp = x_positions[node_idx % nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row] - if len(city_positions) < num_cities and fraction < city_fraction: + if node_idx in city_idx: city_positions.append((x_tmp, y_tmp)) else: intersection_positions.append((x_tmp, y_tmp)) -- GitLab