diff --git a/examples/training_example.py b/examples/training_example.py index 313920939aabb8bc63b2198ff77d27a24d699468..a05f7c727ac9a56453951de453ea184cab1ea4fc 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -1,9 +1,10 @@ import numpy as np from flatland.envs.generators import complex_rail_generator -from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.utils.rendertools import RenderTool np.random.seed(1) @@ -12,12 +13,14 @@ np.random.seed(1) # TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) -env = RailEnv(width=20, - height=20, +LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) +env = RailEnv(width=50, + height=50, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - obs_builder_object=TreeObservation, - number_of_agents=2) + obs_builder_object=LocalGridObs, + number_of_agents=5) +env_renderer = RenderTool(env, gl="PILSVG", ) # Import your own Agent or use RLlib to train agents on Flatland # As an example we use a random agent here @@ -66,6 +69,7 @@ for trials in range(1, n_trials + 1): # Reset environment and get initial observations for all agents obs = env.reset() + env_renderer.reset() # Here you can also further enhance the provided observation by means of normalization # See training navigation example in the baseline repository @@ -80,6 +84,8 @@ for trials in range(1, n_trials + 1): # Environment step which returns the observations for all agents, their corresponding # reward and whether their are done next_obs, all_rewards, done, _ = env.step(action_dict) + env_renderer.render_env(show=True, show_observations=True, show_predictions=False) + # Update replay buffer and train agent for a in range(env.get_num_agents()): agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])) diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 8b0c94f12300c9d9e44b478c637cac74f636e98f..e6646cb9b7f4325e73496c04dd8bdf837c34fe30 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -706,71 +706,136 @@ class LocalObsForRailEnv(ObservationBuilder): The observation is composed of the following elements: - transition map array of the local environment around the given agent, - with dimensions (2*view_radius + 1, 2*view_radius + 1, 16), + with dimensions (view_height,2*view_width+1, 16), assuming 16 bits encoding of transitions. - - Two 2D arrays (2*view_radius + 1, 2*view_radius + 1, 2) containing respectively, + - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, if they are in the agent's vision range, its target position, the positions of the other targets. - - A 3D array (2*view_radius + 1, 2*view_radius + 1, 4) containing the one hot encoding of directions + - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions of the other agents at their position coordinates, if they are in the agent's vision range. - A 4 elements array with one hot encoding of the direction. + + Use the parameters view_width and view_height to define the rectangular view of the agent. + The center parameters moves the agent along the height axis of this rectangle. If it is 0 the agent only has + observation in front of it. """ - def __init__(self, view_radius): - """ - :param view_radius: - """ + def __init__(self, view_width, view_height, center): + super(LocalObsForRailEnv, self).__init__() - self.view_radius = view_radius + self.view_width = view_width + self.view_height = view_height + self.center = center + self.max_padding = max(self.view_width, self.view_height - self.center) def reset(self): # We build the transition map with a view_radius empty cells expansion on each side. # This helps to collect the local transition map view when the agent is close to a border. - - self.rail_obs = np.zeros((self.env.height + 2 * self.view_radius, - self.env.width + 2 * self.view_radius, 16)) + self.max_padding = max(self.view_width, self.view_height) + self.rail_obs = np.zeros((self.env.height, + self.env.width, 16)) for i in range(self.env.height): for j in range(self.env.width): bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist) + self.rail_obs[i, j] = np.array(bitlist) def get(self, handle): agents = self.env.agents agent = agents[handle] - local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0] + 2 * self.view_radius + 1, - agent.position[1]:agent.position[1] + 2 * self.view_radius + 1] - - obs_map_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 2)) - - obs_other_agents_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 4)) - - def relative_pos(pos): - return [agent.position[0] - pos[0], agent.position[1] - pos[1]] - - def is_in(rel_pos): - return (abs(rel_pos[0]) <= self.view_radius) and (abs(rel_pos[1]) <= self.view_radius) - - target_rel_pos = relative_pos(agent.target) - if is_in(target_rel_pos): - obs_map_state[self.view_radius + np.array(target_rel_pos)][0] += 1 - - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? - agent2 = agents[i] - - agent_2_rel_pos = relative_pos(agent2.position) - if is_in(agent_2_rel_pos): - obs_other_agents_state[self.view_radius + agent_2_rel_pos[0], - self.view_radius + agent_2_rel_pos[1]][agent2.direction] += 1 + # Correct agents position for padding + # agent_rel_pos[0] = agent.position[0] + self.max_padding + # agent_rel_pos[1] = agent.position[1] + self.max_padding + + # Collect visible cells as set to be plotted + visited, rel_coords = self.field_of_view(agent.position, agent.direction, ) + local_rail_obs = None + + # Add the visible cells to the observed cells + self.env.dev_obs_dict[handle] = set(visited) + + # Locate observed agents and their coresponding targets + local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16)) + obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2)) + obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4)) + _idx = 0 + for pos in visited: + curr_rel_coord = rel_coords[_idx] + local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :] + if pos == agent.target: + obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1 + else: + for tmp_agent in agents: + if pos == tmp_agent.target: + obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1 + if pos != agent.position: + for tmp_agent in agents: + if pos == tmp_agent.position: + obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[ + tmp_agent.direction] + + _idx += 1 + + direction = np.identity(4)[agent.direction] + return local_rail_obs, obs_map_state, obs_other_agents_state, direction - target_rel_pos_2 = relative_pos(agent2.position) - if is_in(target_rel_pos_2): - obs_map_state[self.view_radius + np.array(target_rel_pos_2)][1] += 1 + def get_many(self, handles=None): + """ + Called whenever an observation has to be computed for the `env' environment, for each agent with handle + in the `handles' list. + """ - direction = self._get_one_hot_for_agent_direction(agent) + observations = {} + for h in handles: + observations[h] = self.get(h) + return observations - return local_rail_obs, obs_map_state, obs_other_agents_state, direction + def field_of_view(self, position, direction, state=None): + # Compute the local field of view for an agent in the environment + data_collection = False + if state is not None: + temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16)) + data_collection = True + if direction == 0: + origin = (position[0] + self.center, position[1] - self.view_width) + elif direction == 1: + origin = (position[0] - self.view_width, position[1] - self.center) + elif direction == 2: + origin = (position[0] - self.center, position[1] + self.view_width) + else: + origin = (position[0] + self.view_width, position[1] + self.center) + visible = list() + rel_coords = list() + for h in range(self.view_height): + for w in range(2 * self.view_width + 1): + if direction == 0: + if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: + visible.append((origin[0] - h, origin[1] + w)) + rel_coords.append((h, w)) + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :] + elif direction == 1: + if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width: + visible.append((origin[0] + w, origin[1] + h)) + rel_coords.append((h, w)) + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :] + elif direction == 2: + if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width: + visible.append((origin[0] + h, origin[1] - w)) + rel_coords.append((h, w)) + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :] + else: + if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width: + visible.append((origin[0] - w, origin[1] - h)) + rel_coords.append((h, w)) + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :] + if data_collection: + return temp_visible_data + else: + return visible, rel_coords