From 7ca1ed3b05e9d0ef189e4f114a84ce676b6e7e99 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Sat, 17 Aug 2019 11:12:05 -0400 Subject: [PATCH] Fixed invalid transitions: now all nodes are valid All starting positions are valid. Check for feasibility of map is not yet done. --- examples/training_example.py | 2 +- flatland/core/transition_map.py | 56 +++++++++++++++++++ flatland/envs/generators.py | 39 +++++++++---- ...test_flatland_env_sparse_rail_generator.py | 11 ++-- 4 files changed, 89 insertions(+), 19 deletions(-) diff --git a/examples/training_example.py b/examples/training_example.py index c038e7b4..60dc455f 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -77,7 +77,7 @@ for trials in range(1, n_trials + 1): score = 0 # Run episode - for step in range(100): + for step in range(500): # Chose an action for each agent in the environment for a in range(env.get_num_agents()): action = agent.act(obs[a]) diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 5e0f6cd7..018c8cd5 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -350,4 +350,60 @@ class GridTransitionMap(TransitionMap): return True + def fix_neighbours(self, rcPos, check_this_cell=False): + """ + Check validity of cell at rcPos = tuple(row, column) + Checks that: + - surrounding cells have inbound transitions for all the + outbound transitions of this cell. + + These are NOT checked - see transition.is_valid: + - all transitions have the mirror transitions (N->E <=> W->S) + - Reverse transitions (N -> S) only exist for a dead-end + - a cell contains either no dead-ends or exactly one + + Returns: True (valid) or False (invalid) + """ + cell_transition = self.grid[tuple(rcPos)] + + if check_this_cell: + if not self.transitions.is_valid(cell_transition): + return False + + gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc] + grcPos = array(rcPos) + grcMax = self.grid.shape + + binTrans = self.get_full_transitions(*rcPos) # 16bit integer - all trans in/out + lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8 + g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1) + gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4) + giDirOut = np.argwhere(gDirOut)[:, 0] # valid outbound directions as array of int + + # loop over available outbound directions (indices) for rcPos + for iDirOut in giDirOut: + gdRC = gDir2dRC[iDirOut] # row,col increment + gPos2 = grcPos + gdRC # next cell in that direction + + # Check the adjacent cell is within bounds + # if not, then this transition is invalid! + if np.any(gPos2 < 0): + return False + if np.any(gPos2 >= grcMax): + return False + + # Get the transitions out of gPos2, using iDirOut as the inbound direction + # if there are no available transitions, ie (0,0,0,0), then rcPos is invalid + t4Trans2 = self.get_transitions(*gPos2, iDirOut) + if any(t4Trans2): + continue + else: + self.set_transition((gPos2[0], gPos2[1], iDirOut), mirror(iDirOut), 1) + return False + + return True + + +def mirror(dir): + return (dir + 2) % 4 # TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 2441e3e7..a0eb7ca7 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -841,8 +841,8 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation to_close = True tries = 0 while to_close: - x_tmp = 1 + np.random.randint(height - 1) - y_tmp = 1 + np.random.randint(width - 1) + x_tmp = 1 + np.random.randint(height - 2) + y_tmp = 1 + np.random.randint(width - 2) to_close = False for node_pos in node_positions: if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: @@ -871,8 +871,6 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation available_nodes = available_nodes[np.argsort(node_dist)] # Set number of neighboring nodes - # np.random.randint(1, max_neigbours) - if len(available_nodes) >= num_neighb: connected_neighb_idx = available_nodes[ 0:2] # np.random.choice(available_nodes, num_neighb, replace=False) @@ -887,6 +885,8 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation node_stack.pop(0) # Place train stations close to the node + # We currently place them uniformly distirbuted among all cities + train_stations = [[] for i in range(num_cities)] for station in range(num_trainstations): @@ -911,6 +911,11 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation new_path = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], (station_x, station_y)) + # Fix all nodes with illegal transition maps + for current_node in node_positions: + if not grid_map.cell_neighbours_valid(current_node): + grid_map.fix_neighbours(current_node) + # Generate start and target node directory for all agents. # Assure that start and target are not in the same node agent_start_targets_nodes = [] @@ -924,20 +929,24 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation # Assign agents to slots for agent_idx in range(num_agents): - av_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0] - av_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0] - start_node = np.random.choice(av_start_nodes) - target_node = np.random.choice(av_target_nodes) + avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0] + avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0] + start_node = np.random.choice(avail_start_nodes) + target_node = np.random.choice(avail_target_nodes) + tries = 0 while target_node == start_node: - target_node = np.random.choice(av_target_nodes) + target_node = np.random.choice(avail_target_nodes) + tries += 1 + # Test again with new start node if no pair is found (This code needs to be improved) + if tries > 10: + start_node = np.random.choice(avail_start_nodes) + node_available_start[start_node] -= 1 node_available_target[target_node] -= 1 - print(node_available_target, node_available_start) agent_start_targets_nodes.append((start_node, target_node)) # Place agents and targets within available train stations - agents_position = [] agents_target = [] agents_direction = [] @@ -956,7 +965,13 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation start = train_stations[current_start_node][start_station_idx] agents_position.append((start[0], start[1])) agents_target.append((target[0], target[1])) - agents_direction.append(0) + + # Orient the agent correctly + for orientation in range(4): + transitions = grid_map.get_transitions(start[0], start[1], orientation) + if any(transitions) > 0: + agents_direction.append(orientation) + continue agent_idx += 1 return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index abb03528..34e6c2b4 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -23,23 +23,22 @@ def test_realistic_rail_generator(): env_renderer.close_window() def test_sparse_rail_generator(): - - env = RailEnv(width=50, + env = RailEnv(width=20, height=50, rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map num_intersections=3, # Number of interesections in map - num_trainstations=30, # Number of possible start/targets on map + num_trainstations=10, # Number of possible start/targets on map min_node_dist=10, # Minimal distance of nodes node_radius=2, # Proximity of stations to city center num_neighb=4, # Number of connections to other cities seed=15, # Random seed ), - number_of_agents=20, + number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer.render_env(show=True, show_observations=True, show_predictions=False) - time.sleep(2) + time.sleep(20) env_renderer.gl.save_image("flatalnd_2_0.png") - time.sleep(100) + time.sleep(1) -- GitLab