diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index a1ad9a85d9a41f57c317a0b4b0bc61796e4e0f4a..da8cacef2a3e9972ef4b90094c59f448ca07bc37 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -13,7 +13,7 @@ np.random.seed(1) # Training on simple small tasks is the best way to get familiar with the environment # Use a the malfunction generator to break agents from time to time -stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents +stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 30, # Rate of malfunction occurence 'min_duration': 3, # Minimal duration of malfunction 'max_duration': 20 # Max duration of malfunction @@ -31,11 +31,11 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=25, # Number of cities in map (where train stations are) - num_intersections=0, # Number of intersections (no start / target) + num_intersections=10, # Number of intersections (no start / target) num_trainstations=50, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes - node_radius=2, # Proximity of stations to city center - num_neighb=3, # Number of connections to other cities/intersections + node_radius=4, # Proximity of stations to city center + num_neighb=4, # Number of connections to other cities/intersections seed=15, # Random seed grid_mode=True, enhance_intersection=False diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 8486fc7f5f024023b44a5c6139bdb0d628e12303..4926fb1fc153f25ade27f59db5d546ada8804c15 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -266,7 +266,8 @@ class RailEnv(Environment): def _agent_malfunction(self, agent): # Decrease counter for next event - agent.malfunction_data['next_malfunction'] -= 1 + if agent.malfunction_data['malfunction_rate'] > 0: + agent.malfunction_data['next_malfunction'] -= 1 # Only agents that have a positive rate for malfunctions and are not currently broken are considered if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']: @@ -328,15 +329,22 @@ class RailEnv(Environment): # The train is broken if agent.malfunction_data['malfunction'] > 0: - agent.malfunction_data['malfunction'] -= 1 + # Last step of malfunction --> Agent starts moving again + if agent.malfunction_data['malfunction'] < 2: + agent.malfunction_data['malfunction'] -= 1 + self.agents[i_agent].moving = True + action_dict[i_agent] = RailEnvActions.DO_NOTHING - # Broken agents are stopped - self.rewards_dict[i_agent] += step_penalty # * agent.speed_data['speed'] - self.agents[i_agent].moving = False - action_dict[i_agent] = RailEnvActions.DO_NOTHING + else: + agent.malfunction_data['malfunction'] -= 1 - # Nothing left to do with broken agent - continue + # Broken agents are stopped + self.rewards_dict[i_agent] += step_penalty # * agent.speed_data['speed'] + self.agents[i_agent].moving = False + action_dict[i_agent] = RailEnvActions.DO_NOTHING + + # Nothing left to do with broken agent + continue if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions): print('ERROR: illegal action=', action_dict[i_agent], diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 39796515f73c7702ba4dc162301a63b0186dc1d3..c23593463c2b679cfd09fcbcf390c3d5a05acde4 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -563,7 +563,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 node_positions = [] city_positions = [] intersection_positions = [] - # Evenly distribute cities and intersections if grid_mode: tot_num_node = num_intersections + num_cities @@ -572,10 +571,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row)) x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) - - fraction = 0 - city_fraction = num_cities / tot_num_node - step = np.gcd(num_intersections, num_cities) / tot_num_node + city_idx = np.random.choice(np.arange(tot_num_node), num_cities) for node_idx in range(num_cities + num_intersections): to_close = True @@ -608,10 +604,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 warnings.warn("Could not set nodes, please change initial parameters!!!!") break else: - fraction = (fraction + step) % 1. x_tmp = x_positions[node_idx % nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row] - if len(city_positions) < num_cities and fraction < city_fraction: + if node_idx in city_idx: city_positions.append((x_tmp, y_tmp)) else: intersection_positions.append((x_tmp, y_tmp))