Skip to content
Snippets Groups Projects
Commit 1e4c1044 authored by Erik Nygren :bullettrain_front:
Browse files

Merge branch '153_bug_fixes_multi_speed' into 'master'

153 bug fixes multi speed

Closes #153

See merge request flatland/flatland!168
parents c6dc451e 0fddb4ed
No related branches found
No related tags found
No related merge requests found
...@@ -13,7 +13,7 @@ np.random.seed(1) ...@@ -13,7 +13,7 @@ np.random.seed(1)
# Training on simple small tasks is the best way to get familiar with the environment # Training on simple small tasks is the best way to get familiar with the environment
# Use the malfunction generator to break agents from time to time # Use the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence 'malfunction_rate': 30, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction 'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction 'max_duration': 20 # Max duration of malfunction
...@@ -31,11 +31,11 @@ speed_ration_map = {1.: 0.25, # Fast passenger train ...@@ -31,11 +31,11 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=50, env = RailEnv(width=50,
height=50, height=50,
rail_generator=sparse_rail_generator(num_cities=25, # Number of cities in map (where train stations are) rail_generator=sparse_rail_generator(num_cities=25, # Number of cities in map (where train stations are)
num_intersections=0, # Number of intersections (no start / target) num_intersections=10, # Number of intersections (no start / target)
num_trainstations=50, # Number of possible start/targets on map num_trainstations=50, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes min_node_dist=3, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center node_radius=4, # Proximity of stations to city center
num_neighb=3, # Number of connections to other cities/intersections num_neighb=4, # Number of connections to other cities/intersections
seed=15, # Random seed seed=15, # Random seed
grid_mode=True, grid_mode=True,
enhance_intersection=False enhance_intersection=False
......
...@@ -8,6 +8,7 @@ from numpy import array ...@@ -8,6 +8,7 @@ from numpy import array
from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid4 import Grid4Transitions
from flatland.core.grid.grid4_utils import get_new_position from flatland.core.grid.grid4_utils import get_new_position
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transitions import Transitions from flatland.core.transitions import Transitions
...@@ -482,6 +483,14 @@ class GridTransitionMap(TransitionMap): ...@@ -482,6 +483,14 @@ class GridTransitionMap(TransitionMap):
grcPos = array(rcPos) grcPos = array(rcPos)
grcMax = self.grid.shape grcMax = self.grid.shape
# Transition elements
transitions = RailEnvTransitions()
cells = transitions.transition_list
simple_switch_east_south = transitions.rotate_transition(cells[10], 90)
simple_switch_west_south = transitions.rotate_transition(cells[2], 270)
symmetrical = cells[6]
double_slip = cells[5]
three_way_transitions = [simple_switch_east_south, simple_switch_west_south, symmetrical]
# loop over available outbound directions (indices) for rcPos # loop over available outbound directions (indices) for rcPos
self.set_transitions(rcPos, 0) self.set_transitions(rcPos, 0)
...@@ -517,25 +526,18 @@ class GridTransitionMap(TransitionMap): ...@@ -517,25 +526,18 @@ class GridTransitionMap(TransitionMap):
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
# Find feasible connection fro three entries # Find feasible connection for three entries
if number_of_incoming == 3: if number_of_incoming == 3:
transition = np.random.choice(three_way_transitions, 1)
hole = np.argwhere(incoming_connections < 1)[0][0] hole = np.argwhere(incoming_connections < 1)[0][0]
connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4] transition = transitions.rotate_transition(transition, int(hole * 90))
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) self.set_transitions((rcPos[0], rcPos[1]), transition)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) # Make a double slip switch
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[2])), connect_directions[0], 1)
# Make a cross
if number_of_incoming == 4: if number_of_incoming == 4:
connect_directions = np.arange(4) rotation = np.random.randint(2)
self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[0], 1) transition = transitions.rotate_transition(double_slip, int(rotation * 90))
self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[1], 1) self.set_transitions((rcPos[0], rcPos[1]), transition)
self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[0], 1)
self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[2], 1)
self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[3], 1)
self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[2], 1)
self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[3], 1)
return True return True
......
...@@ -266,7 +266,8 @@ class RailEnv(Environment): ...@@ -266,7 +266,8 @@ class RailEnv(Environment):
def _agent_malfunction(self, agent): def _agent_malfunction(self, agent):
# Decrease counter for next event # Decrease counter for next event
agent.malfunction_data['next_malfunction'] -= 1 if agent.malfunction_data['malfunction_rate'] > 0:
agent.malfunction_data['next_malfunction'] -= 1
# Only agents that have a positive rate for malfunctions and are not currently broken are considered # Only agents that have a positive rate for malfunctions and are not currently broken are considered
if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']: if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']:
...@@ -294,7 +295,8 @@ class RailEnv(Environment): ...@@ -294,7 +295,8 @@ class RailEnv(Environment):
alpha = 1.0 alpha = 1.0
beta = 1.0 beta = 1.0
# Epsilon to avoid rounding errors
epsilon = 0.01
invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty
step_penalty = -1 * alpha step_penalty = -1 * alpha
global_reward = 1 * beta global_reward = 1 * beta
...@@ -310,7 +312,6 @@ class RailEnv(Environment): ...@@ -310,7 +312,6 @@ class RailEnv(Environment):
self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()} self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()}
return self._get_observations(), self.rewards_dict, self.dones, {} return self._get_observations(), self.rewards_dict, self.dones, {}
# for i in range(len(self.agents_handles)):
for i_agent in range(self.get_num_agents()): for i_agent in range(self.get_num_agents()):
agent = self.agents[i_agent] agent = self.agents[i_agent]
agent.old_direction = agent.direction agent.old_direction = agent.direction
...@@ -328,15 +329,22 @@ class RailEnv(Environment): ...@@ -328,15 +329,22 @@ class RailEnv(Environment):
# The train is broken # The train is broken
if agent.malfunction_data['malfunction'] > 0: if agent.malfunction_data['malfunction'] > 0:
agent.malfunction_data['malfunction'] -= 1 # Last step of malfunction --> Agent starts moving again after getting fixed
if agent.malfunction_data['malfunction'] < 2:
agent.malfunction_data['malfunction'] -= 1
self.agents[i_agent].moving = True
action_dict[i_agent] = RailEnvActions.DO_NOTHING
# Broken agents are stopped else:
self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] agent.malfunction_data['malfunction'] -= 1
self.agents[i_agent].moving = False
action_dict[i_agent] = RailEnvActions.DO_NOTHING
# Nothing left to do with broken agent # Broken agents are stopped
continue self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed']
self.agents[i_agent].moving = False
action_dict[i_agent] = RailEnvActions.DO_NOTHING
# Nothing left to do with broken agent
continue
if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions): if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions):
print('ERROR: illegal action=', action_dict[i_agent], print('ERROR: illegal action=', action_dict[i_agent],
...@@ -350,7 +358,8 @@ class RailEnv(Environment): ...@@ -350,7 +358,8 @@ class RailEnv(Environment):
# Keep moving # Keep moving
action = RailEnvActions.MOVE_FORWARD action = RailEnvActions.MOVE_FORWARD
if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.: if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data[
'position_fraction'] <= epsilon:
# Only allow halting an agent on entering new cells. # Only allow halting an agent on entering new cells.
agent.moving = False agent.moving = False
self.rewards_dict[i_agent] += stop_penalty self.rewards_dict[i_agent] += stop_penalty
...@@ -372,7 +381,7 @@ class RailEnv(Environment): ...@@ -372,7 +381,7 @@ class RailEnv(Environment):
# If the agent can make an action # If the agent can make an action
action_selected = False action_selected = False
if agent.speed_data['position_fraction'] == 0.: if agent.speed_data['position_fraction'] <= epsilon:
if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING:
cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(action, agent) self._check_action_on_agent(action, agent)
......
...@@ -563,7 +563,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -563,7 +563,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
node_positions = [] node_positions = []
city_positions = [] city_positions = []
intersection_positions = [] intersection_positions = []
# Evenly distribute cities and intersections # Evenly distribute cities and intersections
if grid_mode: if grid_mode:
tot_num_node = num_intersections + num_cities tot_num_node = num_intersections + num_cities
...@@ -572,10 +571,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -572,10 +571,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row)) nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row))
x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int)
y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int)
city_idx = np.random.choice(np.arange(tot_num_node), num_cities)
fraction = 0
city_fraction = num_cities / tot_num_node
step = np.gcd(num_intersections, num_cities) / tot_num_node
for node_idx in range(num_cities + num_intersections): for node_idx in range(num_cities + num_intersections):
to_close = True to_close = True
...@@ -608,10 +604,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -608,10 +604,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
warnings.warn("Could not set nodes, please change initial parameters!!!!") warnings.warn("Could not set nodes, please change initial parameters!!!!")
break break
else: else:
fraction = (fraction + step) % 1.
x_tmp = x_positions[node_idx % nodes_per_row] x_tmp = x_positions[node_idx % nodes_per_row]
y_tmp = y_positions[node_idx // nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row]
if len(city_positions) < num_cities and fraction < city_fraction: if node_idx in city_idx:
city_positions.append((x_tmp, y_tmp)) city_positions.append((x_tmp, y_tmp))
else: else:
intersection_positions.append((x_tmp, y_tmp)) intersection_positions.append((x_tmp, y_tmp))
......
...@@ -229,7 +229,6 @@ def schedule_from_file(filename) -> ScheduleGenerator: ...@@ -229,7 +229,6 @@ def schedule_from_file(filename) -> ScheduleGenerator:
# agents are always reset as not moving # agents are always reset as not moving
if len(data['agents_static'][0]) > 5: if len(data['agents_static'][0]) > 5:
print(len(data['agents_static'][0]))
agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data["agents_static"]] agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data["agents_static"]]
else: else:
agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data["agents_static"]] agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data["agents_static"]]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment