diff --git a/examples/training_example.py b/examples/training_example.py index 313920939aabb8bc63b2198ff77d27a24d699468..67e0c740c2b15e25c65dde8036831f3ef062a831 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -1,9 +1,10 @@ import numpy as np from flatland.envs.generators import complex_rail_generator -from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv, LocalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.utils.rendertools import RenderTool np.random.seed(1) @@ -12,12 +13,14 @@ np.random.seed(1) # TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) +LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) env = RailEnv(width=20, height=20, rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), - obs_builder_object=TreeObservation, + obs_builder_object=LocalGridObs, number_of_agents=2) +env_renderer = RenderTool(env, gl="PILSVG", ) # Import your own Agent or use RLlib to train agents on Flatland # As an example we use a random agent here @@ -66,6 +69,7 @@ for trials in range(1, n_trials + 1): # Reset environment and get initial observations for all agents obs = env.reset() + env_renderer.reset() # Here you can also further enhance the provided observation by means of normalization # See training navigation example in the baseline repository @@ -80,6 +84,8 @@ for trials in range(1, n_trials + 1): # Environment step which returns the observations for all agents, their corresponding # reward and whether their are done next_obs, all_rewards, done, _ = env.step(action_dict) + env_renderer.render_env(show=True, show_observations=True, show_predictions=False) + # Update replay buffer and train agent for a in range(env.get_num_agents()): agent.step((obs[a], action_dict[a], 
all_rewards[a], next_obs[a], done[a])) diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 060785f537251484ab4fd1c520fc92bc8a564cbc..4acdf16f292a1b3ef5b78620e588dea8c3ff27e3 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -74,6 +74,7 @@ class ObservationBuilder: direction[agent.direction] = 1 return direction + class DummyObservationBuilder(ObservationBuilder): """ DummyObservationBuilder class which returns dummy observations diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 498d98eef653faa97b1f4f9d7a17048f0a8b9b70..fb8e107dead2c0bb765c01b223e2317d120dc0ee 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -698,71 +698,115 @@ class LocalObsForRailEnv(ObservationBuilder): The observation is composed of the following elements: - transition map array of the local environment around the given agent, - with dimensions (2*view_radius + 1, 2*view_radius + 1, 16), + with dimensions (view_height,2*view_width+1, 16), assuming 16 bits encoding of transitions. - - Two 2D arrays (2*view_radius + 1, 2*view_radius + 1, 2) containing respectively, + - Two 3D arrays (view_height,2*view_width+1, 2) containing respectively, if they are in the agent's vision range, its target position, the positions of the other targets. - - A 3D array (2*view_radius + 1, 2*view_radius + 1, 4) containing the one hot encoding of directions + - A 3D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions of the other agents at their position coordinates, if they are in the agent's vision range. - A 4 elements array with one hot encoding of the direction. 
""" - def __init__(self, view_radius): + def __init__(self, view_width, view_height, center): """ :param view_radius: """ super(LocalObsForRailEnv, self).__init__() - self.view_radius = view_radius + self.view_width = view_width + self.view_height = view_height + self.center = center + self.max_padding = max(self.view_width, self.view_height - self.center) def reset(self): # We build the transition map with a view_radius empty cells expansion on each side. # This helps to collect the local transition map view when the agent is close to a border. - - self.rail_obs = np.zeros((self.env.height + 2 * self.view_radius, - self.env.width + 2 * self.view_radius, 16)) + self.max_padding = max(self.view_width, self.view_height) + self.rail_obs = np.zeros((self.env.height + 2 * self.max_padding, + self.env.width + 2 * self.max_padding, 16)) for i in range(self.env.height): for j in range(self.env.width): bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist) + self.rail_obs[i + self.view_height, j + self.view_width] = np.array(bitlist) def get(self, handle): agents = self.env.agents agent = agents[handle] + agent_rel_pos = [0, 0] - local_rail_obs = self.rail_obs[agent.position[0]: agent.position[0] + 2 * self.view_radius + 1, - agent.position[1]:agent.position[1] + 2 * self.view_radius + 1] - - obs_map_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 2)) + # Correct agents position for padding + agent_rel_pos[0] = agent.position[0] + self.max_padding + agent_rel_pos[1] = agent.position[1] + self.max_padding - obs_other_agents_state = np.zeros((2 * self.view_radius + 1, 2 * self.view_radius + 1, 4)) + # Collect the rail information in the local field of view + local_rail_obs = self.field_of_view(agent_rel_pos, agent.direction, state=self.rail_obs) - def relative_pos(pos): - return [agent.position[0] - 
pos[0], agent.position[1] - pos[1]] + # Locate observed agents and their corresponding targets + obs_map_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 2)) + obs_other_agents_state = np.zeros((self.view_height + 1, 2 * self.view_width + 1, 4)) - def is_in(rel_pos): - return (abs(rel_pos[0]) <= self.view_radius) and (abs(rel_pos[1]) <= self.view_radius) + # Collect visible cells as set to be plotted + visited = self.field_of_view(agent.position, agent.direction) - target_rel_pos = relative_pos(agent.target) - if is_in(target_rel_pos): - obs_map_state[self.view_radius + np.array(target_rel_pos)][0] += 1 + # Add the visible cells to the observed cells + self.env.dev_obs_dict[handle] = visited - for i in range(len(agents)): - if i != handle: # TODO: handle used as index...? - agent2 = agents[i] + direction = self._get_one_hot_for_agent_direction(agent) - agent_2_rel_pos = relative_pos(agent2.position) - if is_in(agent_2_rel_pos): - obs_other_agents_state[self.view_radius + agent_2_rel_pos[0], - self.view_radius + agent_2_rel_pos[1]][agent2.direction] += 1 + return local_rail_obs, obs_map_state, obs_other_agents_state, direction - target_rel_pos_2 = relative_pos(agent2.position) - if is_in(target_rel_pos_2): - obs_map_state[self.view_radius + np.array(target_rel_pos_2)][1] += 1 + def get_many(self, handles=None): + """ + Called whenever an observation has to be computed for the `env' environment, for each agent with handle + in the `handles' list.
+ """ - direction = self._get_one_hot_for_agent_direction(agent) + observations = {} + for h in handles: + observations[h] = self.get(h) + return observations - return local_rail_obs, obs_map_state, obs_other_agents_state, direction + def field_of_view(self, position, direction, state=None): + # Compute the local field of view for an agent in the environment + data_collection = False + if state is not None: + temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16)) + data_collection = True + if direction == 0: + origin = (position[0] + self.center, position[1] - self.view_width) + elif direction == 1: + origin = (position[0] - self.view_width, position[1] - self.center) + elif direction == 2: + origin = (position[0] - self.center, position[1] + self.view_width) + else: + origin = (position[0] + self.view_width, position[1] + self.center) + visible = set() + for h in range(self.view_height): + for w in range(2 * self.view_width + 1): + if direction == 0: + if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: + visible.add((origin[0] - h, origin[1] + w)) + if data_collection: + temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :] + elif direction == 1: + if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width: + visible.add((origin[0] + w, origin[1] + h)) + if data_collection: + temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :] + elif direction == 2: + if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: + visible.add((origin[0] + h, origin[1] - w)) + if data_collection: + temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :] + else: + if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: + visible.add((origin[0] - w, origin[1] - h)) + if data_collection: + temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :] + if data_collection: + return temp_visible_data + else: + return 
visible