diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index a1ad9a85d9a41f57c317a0b4b0bc61796e4e0f4a..da8cacef2a3e9972ef4b90094c59f448ca07bc37 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -13,7 +13,7 @@ np.random.seed(1) # Training on simple small tasks is the best way to get familiar with the environment # Use a the malfunction generator to break agents from time to time -stochastic_data = {'prop_malfunction': 0.1, # Percentage of defective agents +stochastic_data = {'prop_malfunction': 1., # Percentage of defective agents 'malfunction_rate': 30, # Rate of malfunction occurence 'min_duration': 3, # Minimal duration of malfunction 'max_duration': 20 # Max duration of malfunction @@ -31,11 +31,11 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=25, # Number of cities in map (where train stations are) - num_intersections=0, # Number of intersections (no start / target) + num_intersections=10, # Number of intersections (no start / target) num_trainstations=50, # Number of possible start/targets on map min_node_dist=3, # Minimal distance of nodes - node_radius=2, # Proximity of stations to city center - num_neighb=3, # Number of connections to other cities/intersections + node_radius=4, # Proximity of stations to city center + num_neighb=4, # Number of connections to other cities/intersections seed=15, # Random seed grid_mode=True, enhance_intersection=False diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index bb954998688772a7ce69e5228cff3e16d037f2af..232d6fdab02c57da95bf04c631e4905986c71327 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -8,6 +8,7 @@ from numpy import array from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transitions import Transitions @@ -482,6 +483,14 @@ class GridTransitionMap(TransitionMap): grcPos = array(rcPos) grcMax = self.grid.shape + # Transition elements + transitions = RailEnvTransitions() + cells = transitions.transition_list + simple_switch_east_south = transitions.rotate_transition(cells[10], 90) + simple_switch_west_south = transitions.rotate_transition(cells[2], 270) + symmetrical = cells[6] + double_slip = cells[5] + three_way_transitions = [simple_switch_east_south, simple_switch_west_south, symmetrical] # loop over available outbound directions (indices) for rcPos self.set_transitions(rcPos, 0) @@ -517,25 +526,18 @@ class GridTransitionMap(TransitionMap): self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) - # Find feasible connection fro three entries + # Find feasible connection for three entries if number_of_incoming == 3: + transition = np.random.choice(three_way_transitions, 1) hole = np.argwhere(incoming_connections < 1)[0][0] - connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4] - self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1) - self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1) - self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1) - self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[2])), connect_directions[0], 1) - # Make a cross + transition = transitions.rotate_transition(transition, int(hole * 90)) + self.set_transitions((rcPos[0], rcPos[1]), transition) + + # Make a double slip switch if number_of_incoming == 4: - connect_directions = np.arange(4) - self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[0], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[0]), connect_directions[1], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[0], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[1]), connect_directions[1], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[2], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[2]), connect_directions[3], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[2], 1) - self.set_transition((rcPos[0], rcPos[1], connect_directions[3]), connect_directions[3], 1) + rotation = np.random.randint(2) + transition = transitions.rotate_transition(double_slip, int(rotation * 90)) + self.set_transitions((rcPos[0], rcPos[1]), transition) return True diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 280fd345d8c1db206c42dc30ba2d7b5fa2e8a69e..d59ca7dc8f15d7f550e59f426a34cdd3275d954a 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -266,7 +266,8 @@ class RailEnv(Environment): def _agent_malfunction(self, agent): # Decrease counter for next event - agent.malfunction_data['next_malfunction'] -= 1 + if agent.malfunction_data['malfunction_rate'] > 0: + agent.malfunction_data['next_malfunction'] -= 1 # Only agents that have a positive rate for malfunctions and are not currently broken are considered if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction']: @@ -294,7 +295,8 @@ class RailEnv(Environment): alpha = 1.0 beta = 1.0 - + # Epsilon to avoid rounding errors + epsilon = 0.01 invalid_action_penalty = 0 # previously -2; GIACOMO: we decided that invalid actions will carry no penalty step_penalty = -1 * alpha global_reward = 1 * beta @@ -310,7 +312,6 @@ class RailEnv(Environment): self.rewards_dict = {i: r + global_reward for i, r in self.rewards_dict.items()} return self._get_observations(), self.rewards_dict, self.dones, {} - # for i in range(len(self.agents_handles)): for i_agent in range(self.get_num_agents()): agent = self.agents[i_agent] agent.old_direction = agent.direction @@ -328,15 +329,22 @@ class RailEnv(Environment): # The train is broken if agent.malfunction_data['malfunction'] > 0: - agent.malfunction_data['malfunction'] -= 1 + # Last step of malfunction --> Agent starts moving again after getting fixed + if agent.malfunction_data['malfunction'] < 2: + agent.malfunction_data['malfunction'] -= 1 + self.agents[i_agent].moving = True + action_dict[i_agent] = RailEnvActions.DO_NOTHING - # Broken agents are stopped - self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] - self.agents[i_agent].moving = False - action_dict[i_agent] = RailEnvActions.DO_NOTHING + else: + agent.malfunction_data['malfunction'] -= 1 - # Nothing left to do with broken agent - continue + # Broken agents are stopped + self.rewards_dict[i_agent] += step_penalty * agent.speed_data['speed'] + self.agents[i_agent].moving = False + action_dict[i_agent] = RailEnvActions.DO_NOTHING + + # Nothing left to do with broken agent + continue if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions): print('ERROR: illegal action=', action_dict[i_agent], @@ -350,7 +358,8 @@ class RailEnv(Environment): # Keep moving action = RailEnvActions.MOVE_FORWARD - if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data['position_fraction'] == 0.: + if action == RailEnvActions.STOP_MOVING and agent.moving and agent.speed_data[ + 'position_fraction'] <= epsilon: # Only allow halting an agent on entering new cells. agent.moving = False self.rewards_dict[i_agent] += stop_penalty @@ -372,7 +381,7 @@ class RailEnv(Environment): # If the agent can make an action action_selected = False - if agent.speed_data['position_fraction'] == 0.: + if agent.speed_data['position_fraction'] <= epsilon: if action != RailEnvActions.DO_NOTHING and action != RailEnvActions.STOP_MOVING: cell_free, new_cell_valid, new_direction, new_position, transition_valid = \ self._check_action_on_agent(action, agent) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 39796515f73c7702ba4dc162301a63b0186dc1d3..c23593463c2b679cfd09fcbcf390c3d5a05acde4 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -563,7 +563,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 node_positions = [] city_positions = [] intersection_positions = [] - # Evenly distribute cities and intersections if grid_mode: tot_num_node = num_intersections + num_cities @@ -572,10 +571,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row)) x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) - - fraction = 0 - city_fraction = num_cities / tot_num_node - step = np.gcd(num_intersections, num_cities) / tot_num_node + city_idx = np.random.choice(np.arange(tot_num_node), num_cities) for node_idx in range(num_cities + num_intersections): to_close = True @@ -608,10 +604,9 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 warnings.warn("Could not set nodes, please change initial parameters!!!!") break else: - fraction = (fraction + step) % 1. x_tmp = x_positions[node_idx % nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row] - if len(city_positions) < num_cities and fraction < city_fraction: + if node_idx in city_idx: city_positions.append((x_tmp, y_tmp)) else: intersection_positions.append((x_tmp, y_tmp)) diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 7a116f9e6b5a233112f2e4b5b73556f13b761be6..e81da924d1157fa9d5deef56306643c6a356cc8d 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -229,7 +229,6 @@ def schedule_from_file(filename) -> ScheduleGenerator: # agents are always reset as not moving if len(data['agents_static'][0]) > 5: - print(len(data['agents_static'][0])) agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data["agents_static"]] else: agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data["agents_static"]]