Skip to content
Snippets Groups Projects
Commit bf30a586 authored by Erik Nygren's avatar Erik Nygren
Browse files

updated city placement in sparse_rail_generator

Updated flatland 2.0 example
added new feature: Malfunctioning agents will automatically resume operation when they are fixed
parent 0da9e3c3
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,7 @@ np.random.seed(1) ...@@ -13,7 +13,7 @@ np.random.seed(1)
# Training on simple small tasks is the best way to get familiar with the environment # Training on simple small tasks is the best way to get familiar with the environment
# Use a the malfunction generator to break agents from time to time # Use a the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence 'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction 'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction 'max_duration': 20 # Max duration of malfunction
...@@ -31,11 +31,11 @@ speed_ration_map = {1.: 0.25, # Fast passenger train ...@@ -31,11 +31,11 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=50, env = RailEnv(width=50,
height=50, height=50,
rail_generator=sparse_rail_generator(num_cities=25, # Number of cities in map (where train stations are) rail_generator=sparse_rail_generator(num_cities=25, # Number of cities in map (where train stations are)
num_intersections=0, # Number of intersections (no start / target) num_intersections=10, # Number of intersections (no start / target)
num_trainstations=50, # Number of possible start/targets on map num_trainstations=50, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes min_node_dist=3, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center node_radius=4, # Proximity of stations to city center
num_neighb=3, # Number of connections to other cities/intersections num_neighb=4, # Number of connections to other cities/intersections
seed=15, # Random seed seed=15, # Random seed
grid_mode=True, grid_mode=True,
enhance_intersection=False enhance_intersection=False
......
...@@ -266,7 +266,8 @@ class RailEnv(Environment): ...@@ -266,7 +266,8 @@ class RailEnv(Environment):
def _agent_malfunction(self, agent): def _agent_malfunction(self, agent):
# Decrease counter for next event # Decrease counter for next event
agent.malfunction_data['next_malfunction'] -= 1 if agent.malfunction_data['malfunction_rate'] > 0:
agent.malfunction_data['next_malfunction'] -= 1
# Only agents that have a positive rate for malfunctions and are not currently broken are considered # Only agents that have a positive rate for malfunctions and are not currently broken are considered
if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']: if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']:
...@@ -328,15 +329,22 @@ class RailEnv(Environment): ...@@ -328,15 +329,22 @@ class RailEnv(Environment):
# The train is broken # The train is broken
if agent.malfunction_data['malfunction'] > 0: if agent.malfunction_data['malfunction'] > 0:
agent.malfunction_data['malfunction'] -= 1 # Last step of malfunction --> Agent starts moving again
if agent.malfunction_data['malfunction'] < 2:
agent.malfunction_data['malfunction'] -= 1
self.agents[i_agent].moving = True
action_dict[i_agent] = RailEnvActions.DO_NOTHING
# Broken agents are stopped else:
self.rewards_dict[i_agent] += step_penalty # * agent.speed_data['speed'] agent.malfunction_data['malfunction'] -= 1
self.agents[i_agent].moving = False
action_dict[i_agent] = RailEnvActions.DO_NOTHING
# Nothing left to do with broken agent # Broken agents are stopped
continue self.rewards_dict[i_agent] += step_penalty # * agent.speed_data['speed']
self.agents[i_agent].moving = False
action_dict[i_agent] = RailEnvActions.DO_NOTHING
# Nothing left to do with broken agent
continue
if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions): if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
print('ERROR: illegal action=', action_dict[i_agent], print('ERROR: illegal action=', action_dict[i_agent],
......
...@@ -563,7 +563,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -563,7 +563,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
node_positions = [] node_positions = []
city_positions = [] city_positions = []
intersection_positions = [] intersection_positions = []
# Evenly distribute cities and intersections # Evenly distribute cities and intersections
if grid_mode: if grid_mode:
tot_num_node = num_intersections + num_cities tot_num_node = num_intersections + num_cities
...@@ -572,10 +571,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -572,10 +571,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row)) nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row))
x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int)
y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int)
city_idx = np.random.choice(np.arange(tot_num_node), num_cities)
fraction = 0
city_fraction = num_cities / tot_num_node
step = np.gcd(num_intersections, num_cities) / tot_num_node
for node_idx in range(num_cities + num_intersections): for node_idx in range(num_cities + num_intersections):
to_close = True to_close = True
...@@ -608,10 +604,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -608,10 +604,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
warnings.warn("Could not set nodes, please change initial parameters!!!!") warnings.warn("Could not set nodes, please change initial parameters!!!!")
break break
else: else:
fraction = (fraction + step) % 1.
x_tmp = x_positions[node_idx % nodes_per_row] x_tmp = x_positions[node_idx % nodes_per_row]
y_tmp = y_positions[node_idx // nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row]
if len(city_positions) < num_cities and fraction < city_fraction: if node_idx in city_idx:
city_positions.append((x_tmp, y_tmp)) city_positions.append((x_tmp, y_tmp))
else: else:
intersection_positions.append((x_tmp, y_tmp)) intersection_positions.append((x_tmp, y_tmp))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment