diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..71c9cbb42e95985dc7927602d0017479eb75dfc7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 SBB AG + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/examples/sample_10_10_rail.npy b/examples/sample_10_10_rail.npy new file mode 100644 index 0000000000000000000000000000000000000000..a8dc0d41ecfff0c5c3a8b7446b1dd6246573608e Binary files /dev/null and b/examples/sample_10_10_rail.npy differ diff --git a/examples/temporary_example.py b/examples/temporary_example.py index 2444d3895cf43499ae5108ea16981e55a5989156..02c282cb374914651d063a2b118fb688257e7631 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -1,41 +1,75 @@ -from flatland.core.env import RailEnv -from flatland.utils.rail_env_generator import * +import random +import numpy as np +import matplotlib.pyplot as plt + +from flatland.envs.rail_env import * +from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.utils.rendertools import * -random.seed(1) -np.random.seed(1) +random.seed(0) +np.random.seed(0) +transition_probability = [1.0, # empty cell - Case 0 + 1.0, # Case 1 - straight + 1.0, # Case 2 - simple switch + 0.3, # Case 3 - diamond drossing + 0.5, # Case 4 - single slip + 0.5, # Case 5 - double slip + 0.2, # Case 6 - symmetrical + 0.0] # Case 7 - dead end # Example generate a random rail -rail = generate_random_rail(20, 20) +env = RailEnv(width=20, + height=20, + rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), + number_of_agents=10) + +# env = RailEnv(width=20, +# height=20, +# rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(['examples/sample_10_10_rail.npy']), +# number_of_agents=10) -env = RailEnv(rail, number_of_agents=10) env.reset() env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) - +""" # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)]] -rail = generate_rail_from_manual_specifications(specs) -env = RailEnv(rail, number_of_agents=1) +env = RailEnv(width=6, + height=2, + rail_generator=rail_from_manual_specifications_generator(specs), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2)) handle = env.get_agent_handles() - -env.reset() - -env.agents_position = [[1, 4]] -env.agents_target = [[1, 1]] -env.agents_direction = [1] +env.agents_position[0] = [1, 4] +env.agents_target[0] = [1, 1] +env.agents_direction[0] = 1 +# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too! +env.obs_builder.reset() +""" +env = RailEnv(width=7, + height=7, + rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), + number_of_agents=2) + +# Print the distance map of each cell to the target of the first agent +# for i in range(4): +# print(env.obs_builder.distance_map[0, :, :, i]) + +# Print the observation vector for agent 0 +obs, all_rewards, done, _ = env.step({0:0}) +for i in range(env.number_of_agents): + env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=5) env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) - print("Manual control: s=perform step, q=quit, [agent id] [1-2-3 action] \ (turnleft+move, move to front, turnright+move)") for step in range(100): diff --git a/examples/training_navigation.py b/examples/training_navigation.py new file mode 100644 index 0000000000000000000000000000000000000000..975d33fb3139ebc2040b941dbe73d8aeb3b225eb --- /dev/null +++ b/examples/training_navigation.py @@ -0,0 +1,104 @@ +from flatland.envs.rail_env import * +from flatland.core.env_observation_builder import TreeObsForRailEnv +from flatland.utils.rendertools import * +from flatland.baselines.dueling_double_dqn import Agent +from collections import deque +import torch,random + +random.seed(1) +np.random.seed(1) + + +# Example generate a rail given a manual specification, +# a map of tuples (cell_type, rotation) +transition_probability = [1.0, # empty cell - Case 0 + 1.0, # Case 1 - straight + 1.0, # Case 2 - simple switch + 0.3, # Case 3 - diamond drossing + 0.5, # Case 4 - single slip + 0.5, # Case 5 - double slip + 0.2, # Case 6 - symmetrical + 0.0] # Case 7 - dead end + +# Example generate a random rail +env = RailEnv(width=7, + height=7, + rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), + number_of_agents=1) +env_renderer = RenderTool(env) +handle = env.get_agent_handles() + +state_size = 105 +action_size = 4 +n_trials = 5000 +eps = 1. +eps_end = 0.005 +eps_decay = 0.998 +action_dict = dict() +scores_window = deque(maxlen=100) +done_window = deque(maxlen=100) +scores = [] +dones_list = [] + +agent = Agent(state_size, action_size, "FC", 0) + +for trials in range(1, n_trials + 1): + + # Reset environment + obs = env.reset() + # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) + + + score = 0 + env_done = 0 + + # Run episode + for step in range(100): + if trials >= 114: + env_renderer.renderEnv(show=True) + + # Action + for a in range(env.number_of_agents): + action = agent.act(np.array(obs[a]), eps=eps) + action_dict.update({a: action}) + + # Environment step + next_obs, all_rewards, done, _ = env.step(action_dict) + + + # Update replay buffer and train agent + for a in range(env.number_of_agents): + agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]) + score += all_rewards[a] + + obs = next_obs.copy() + + if done['__all__']: + env_done = 1 + break + # Epsioln decay + eps = max(eps_end, eps_decay * eps) # decrease epsilon + + done_window.append(env_done) + scores_window.append(score) # save most recent score + scores.append(np.mean(scores_window)) + dones_list.append((np.mean(done_window))) + + print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(env.number_of_agents, + trials, + np.mean( + scores_window), + 100 * np.mean( + done_window), + eps), + end=" ") + if trials % 100 == 0: + print( + '\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(env.number_of_agents, + trials, + np.mean( + scores_window), + 100 * np.mean( + done_window), + eps)) + torch.save(agent.qnetwork_local.state_dict(), '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth') diff --git a/flatland/baselines/__init__.py b/flatland/baselines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/flatland/baselines/dueling_double_dqn.py b/flatland/baselines/dueling_double_dqn.py new file mode 100644 index 0000000000000000000000000000000000000000..ee75a615cda4d26a06810d2b7d109fe5691d5ac4 --- /dev/null +++ b/flatland/baselines/dueling_double_dqn.py @@ -0,0 +1,189 @@ +import numpy as np +import random +from collections import namedtuple, deque +import os +from flatland.baselines.model import QNetwork, QNetwork2 +import torch +import torch.nn.functional as F +import torch.optim as optim +import copy + +BUFFER_SIZE = int(1e5) # replay buffer size +BATCH_SIZE = 512 # minibatch size +GAMMA = 0.99 # discount factor 0.99 +TAU = 1e-3 # for soft update of target parameters +LR = 0.5e-4 # learning rate 5 +UPDATE_EVERY = 10 # how often to update the network +double_dqn = True # If using double dqn algorithm +input_channels = 5 # Number of Input channels + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cpu") +print(device) + + +class Agent: + """Interacts with and learns from the environment.""" + + def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5): + """Initialize an Agent object. + + Params + ====== + state_size (int): dimension of each state + action_size (int): dimension of each action + seed (int): random seed + """ + self.state_size = state_size + self.action_size = action_size + self.seed = random.seed(seed) + self.version = net_type + self.double_dqn = double_dqn + # Q-Network + if self.version == "Conv": + self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + else: + self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + + self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR) + + # Replay memory + self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + # Initialize time step (for updating every UPDATE_EVERY steps) + self.t_step = 0 + + def save(self, filename): + torch.save(self.qnetwork_local.state_dict(), filename + ".local") + torch.save(self.qnetwork_target.state_dict(), filename + ".target") + + def load(self, filename): + if os.path.exists(filename + ".local"): + self.qnetwork_local.load_state_dict(torch.load(filename + ".local")) + if os.path.exists(filename + ".target"): + self.qnetwork_target.load_state_dict(torch.load(filename + ".target")) + + def step(self, state, action, reward, next_state, done): + # Save experience in replay memory + self.memory.add(state, action, reward, next_state, done) + + # Learn every UPDATE_EVERY time steps. + self.t_step = (self.t_step + 1) % UPDATE_EVERY + if self.t_step == 0: + # If enough samples are available in memory, get random subset and learn + if len(self.memory) > BATCH_SIZE: + experiences = self.memory.sample() + self.learn(experiences, GAMMA) + + def act(self, state, eps=0.): + """Returns actions for given state as per current policy. + + Params + ====== + state (array_like): current state + eps (float): epsilon, for epsilon-greedy action selection + """ + state = torch.from_numpy(state).float().unsqueeze(0).to(device) + self.qnetwork_local.eval() + with torch.no_grad(): + action_values = self.qnetwork_local(state) + self.qnetwork_local.train() + + # Epsilon-greedy action selection + if random.random() > eps: + return np.argmax(action_values.cpu().data.numpy()) + else: + return random.choice(np.arange(self.action_size)) + + def learn(self, experiences, gamma): + + """Update value parameters using given batch of experience tuples. + + Params + ====== + experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples + gamma (float): discount factor + """ + states, actions, rewards, next_states, dones = experiences + + # Get expected Q values from local model + Q_expected = self.qnetwork_local(states).gather(1, actions) + + if self.double_dqn: + # Double DQN + q_best_action = self.qnetwork_local(next_states).max(1)[1] + Q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1)) + else: + # DQN + Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1) + + # Compute Q targets for current states + + Q_targets = rewards + (gamma * Q_targets_next * (1 - dones)) + + # Compute loss + loss = F.mse_loss(Q_expected, Q_targets) + # Minimize the loss + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # ------------------- update target network ------------------- # + self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU) + + def soft_update(self, local_model, target_model, tau): + """Soft update model parameters. + θ_target = τ*θ_local + (1 - τ)*θ_target + + Params + ====== + local_model (PyTorch model): weights will be copied from + target_model (PyTorch model): weights will be copied to + tau (float): interpolation parameter + """ + for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): + target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data) + + +class ReplayBuffer: + """Fixed-size buffer to store experience tuples.""" + + def __init__(self, action_size, buffer_size, batch_size, seed): + """Initialize a ReplayBuffer object. + + Params + ====== + action_size (int): dimension of each action + buffer_size (int): maximum size of buffer + batch_size (int): size of each training batch + seed (int): random seed + """ + self.action_size = action_size + self.memory = deque(maxlen=buffer_size) + self.batch_size = batch_size + self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) + self.seed = random.seed(seed) + + def add(self, state, action, reward, next_state, done): + """Add a new experience to memory.""" + e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done) + self.memory.append(e) + + def sample(self): + """Randomly sample a batch of experiences from memory.""" + experiences = random.sample(self.memory, k=self.batch_size) + + states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device) + actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device) + rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device) + next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to( + device) + dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to( + device) + + return (states, actions, rewards, next_states, dones) + + def __len__(self): + """Return the current size of internal memory.""" + return len(self.memory) diff --git a/flatland/baselines/model.py b/flatland/baselines/model.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5b3d613342a4fba8e2c8f1f45df21381e12684 --- /dev/null +++ b/flatland/baselines/model.py @@ -0,0 +1,61 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class QNetwork(nn.Module): + def __init__(self, state_size, action_size, seed, hidsize1=128, hidsize2=128): + super(QNetwork, self).__init__() + + self.fc1_val = nn.Linear(state_size, hidsize1) + self.fc2_val = nn.Linear(hidsize1, hidsize2) + self.fc3_val = nn.Linear(hidsize2, 1) + + self.fc1_adv = nn.Linear(state_size, hidsize1) + self.fc2_adv = nn.Linear(hidsize1, hidsize2) + self.fc3_adv = nn.Linear(hidsize2, action_size) + + def forward(self, x): + val = F.relu(self.fc1_val(x)) + val = F.relu(self.fc2_val(val)) + val = self.fc3_val(val) + + # advantage calculation + adv = F.relu(self.fc1_adv(x)) + adv = F.relu(self.fc2_adv(adv)) + adv = self.fc3_adv(adv) + return val + adv - adv.mean() + + +class QNetwork2(nn.Module): + def __init__(self, state_size, action_size, seed, input_channels, hidsize1=128, hidsize2=64): + super(QNetwork2, self).__init__() + self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1) + self.bn1 = nn.BatchNorm2d(16) + self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=3) + self.bn2 = nn.BatchNorm2d(32) + self.conv3 = nn.Conv2d(32, 64, kernel_size=5, stride=3) + self.bn3 = nn.BatchNorm2d(64) + + self.fc1_val = nn.Linear(6400, hidsize1) + self.fc2_val = nn.Linear(hidsize1, hidsize2) + self.fc3_val = nn.Linear(hidsize2, 1) + + self.fc1_adv = nn.Linear(6400, hidsize1) + self.fc2_adv = nn.Linear(hidsize1, hidsize2) + self.fc3_adv = nn.Linear(hidsize2, action_size) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + + # value function approximation + val = F.relu(self.fc1_val(x.view(x.size(0), -1))) + val = F.relu(self.fc2_val(val)) + val = self.fc3_val(val) + + # advantage calculation + adv = F.relu(self.fc1_adv(x.view(x.size(0), -1))) + adv = F.relu(self.fc2_adv(adv)) + adv = self.fc3_adv(adv) + return val + adv - adv.mean() diff --git a/flatland/core/env.py b/flatland/core/env.py index a7e63fd4e2697996a4dbe0a684735fd46153fb1b..284afdffb6ce46ac481018af469c7d2e024fc792 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -3,9 +3,6 @@ The env module defines the base Environment class. The base Environment class is adapted from rllib.env.MultiAgentEnv (https://github.com/ray-project/ray). """ -import random - -from .env_observation_builder import TreeObsForRailEnv class Environment: @@ -93,303 +90,3 @@ class Environment: function. """ raise NotImplementedError() - - -class RailEnv: - """ - RailEnv environment class. - - RailEnv is an environment inspired by a (simplified version of) a rail - network, in which agents (trains) have to navigate to their target - locations in the shortest time possible, while at the same time cooperating - to avoid bottlenecks. - - The valid actions in the environment are: - 0: do nothing - 1: turn left and move to the next cell - 2: move to the next cell in front of the agent - 3: turn right and move to the next cell - - Moving forward in a dead-end cell makes the agent turn 180 degrees and step - to the cell it came from. - - The actions of the agents are executed in order of their handle to prevent - deadlocks and to allow them to learn relative priorities. - - TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and - beta to be passed as parameters to __init__(). - """ - - def __init__(self, - rail, - number_of_agents=1, - custom_observation_builder=TreeObsForRailEnv): - """ - Environment init. - - Parameters - ------- - rail : numpy.ndarray of type numpy.uint16 - The transition matrix that defines the environment. - number_of_agents : int - Number of agents to spawn on the map. - custom_observation_builder: ObservationBuilder object - ObservationBuilder-derived object that takes this env object - as input as provides observation vectors for each agent. - """ - - self.rail = rail - self.width = rail.width - self.height = rail.height - - self.number_of_agents = number_of_agents - - self.obs_builder = custom_observation_builder(env=self) - - self.actions = [0]*self.number_of_agents - self.rewards = [0]*self.number_of_agents - self.done = False - - self.dones = {"__all__": False} - self.obs_dict = {} - self.rewards_dict = {} - - self.agents_handles = list(range(self.number_of_agents)) - - def get_agent_handles(self): - return self.agents_handles - - def reset(self): - self.dones = {"__all__": False} - for handle in self.agents_handles: - self.dones[handle] = False - - re_generate = True - while re_generate: - valid_positions = [] - for r in range(self.height): - for c in range(self.width): - if self.rail.get_transitions((r, c)) > 0: - valid_positions.append((r, c)) - - self.agents_position = random.sample(valid_positions, - self.number_of_agents) - self.agents_target = random.sample(valid_positions, - self.number_of_agents) - - # agents_direction must be a direction for which a solution is - # guaranteed. - self.agents_direction = [0]*self.number_of_agents - re_generate = False - for i in range(self.number_of_agents): - valid_movements = [] - for direction in range(4): - position = self.agents_position[i] - moves = self.rail.get_transitions( - (position[0], position[1], direction)) - for move_index in range(4): - if moves[move_index]: - valid_movements.append((direction, move_index)) - - valid_starting_directions = [] - for m in valid_movements: - new_position = self._new_position(self.agents_position[i], - m[1]) - if m[0] not in valid_starting_directions and \ - self._path_exists(new_position, m[0], - self.agents_target[i]): - valid_starting_directions.append(m[0]) - - if len(valid_starting_directions) == 0: - re_generate = True - else: - self.agents_direction[i] = random.sample( - valid_starting_directions, 1)[0] - - # Reset the state of the observation builder with the new environment - self.obs_builder.reset() - - # Return the new observation vectors for each agent - return self._get_observations() - - def step(self, action_dict): - alpha = 1.0 - beta = 1.0 - - invalid_action_penalty = -2 - step_penalty = -1 * alpha - global_reward = 1 * beta - - # Reset the step rewards - self.rewards_dict = {} - for handle in self.agents_handles: - self.rewards_dict[handle] = 0 - - if self.dones["__all__"]: - return self._get_observations(), self.rewards_dict, self.dones, {} - - for i in range(len(self.agents_handles)): - handle = self.agents_handles[i] - - if handle not in action_dict: - continue - - action = action_dict[handle] - - if action < 0 or action > 3: - print('ERROR: illegal action=', action, - 'for agent with handle=', handle) - return - - if action > 0: - pos = self.agents_position[i] - direction = self.agents_direction[i] - - movement = direction - if action == 1: - movement = direction - 1 - elif action == 3: - movement = direction + 1 - - if movement < 0: - movement += 4 - if movement >= 4: - movement -= 4 - - is_deadend = False - if action == 2: - # compute number of possible transitions in the current - # cell - nbits = 0 - tmp = self.rail.get_transitions((pos[0], pos[1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - # dead-end; assuming the rail network is consistent, - # this should match the direction the agent has come - # from. But it's better to check in any case. - reverse_direction = 0 - if direction == 0: - reverse_direction = 2 - elif direction == 1: - reverse_direction = 3 - elif direction == 2: - reverse_direction = 0 - elif direction == 3: - reverse_direction = 1 - - valid_transition = self.rail.get_transition( - (pos[0], pos[1], direction), - reverse_direction) - if valid_transition: - direction = reverse_direction - movement = reverse_direction - is_deadend = True - - new_position = self._new_position(pos, movement) - - # Is it a legal move? 1) transition allows the movement in the - # cell, 2) the new cell is not empty (case 0), 3) the cell is - # free, i.e., no agent is currently in that cell - if new_position[1] >= self.width or\ - new_position[0] >= self.height or\ - new_position[0] < 0 or new_position[1] < 0: - new_cell_isValid = False - - elif self.rail.get_transitions((new_position[0], new_position[1])) > 0: - new_cell_isValid = True - else: - new_cell_isValid = False - - transition_isValid = self.rail.get_transition( - (pos[0], pos[1], direction), - movement) or is_deadend - - cell_isFree = True - for j in range(self.number_of_agents): - if self.agents_position[j] == new_position: - cell_isFree = False - break - - if new_cell_isValid and transition_isValid and cell_isFree: - # move and change direction to face the movement that was - # performed - self.agents_position[i] = new_position - self.agents_direction[i] = movement - else: - # the action was not valid, add penalty - self.rewards_dict[handle] += invalid_action_penalty - - # if agent is not in target position, add step penalty - if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1]: - self.dones[handle] = True - else: - self.rewards_dict[handle] += step_penalty - - # Check for end of episode + add global reward to all rewards! - num_agents_in_target_position = 0 - for i in range(self.number_of_agents): - if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1]: - num_agents_in_target_position += 1 - - if num_agents_in_target_position == self.number_of_agents: - self.dones["__all__"] = True - self.rewards_dict = [r+global_reward for r in self.rewards_dict] - - # Reset the step actions (in case some agent doesn't 'register_action' - # on the next step) - self.actions = [0]*self.number_of_agents - - return self._get_observations(), self.rewards_dict, self.dones, {} - - def _new_position(self, position, movement): - if movement == 0: # NORTH - return (position[0]-1, position[1]) - elif movement == 1: # EAST - return (position[0], position[1] + 1) - elif movement == 2: # SOUTH - return (position[0]+1, position[1]) - elif movement == 3: # WEST - return (position[0], position[1] - 1) - - def _path_exists(self, start, direction, end): - # BFS - Check if a path exists between the 2 nodes - - visited = set() - stack = [(start, direction)] - while stack: - node = stack.pop() - if node[0][0] == end[0] and node[0][1] == end[1]: - return 1 - if node not in visited: - visited.add(node) - moves = self.rail.get_transitions((node[0][0], node[0][1], node[1])) - for move_index in range(4): - if moves[move_index]: - stack.append((self._new_position(node[0], move_index), - move_index)) - - # If cell is a dead-end, append previous node with reversed - # orientation! - nbits = 0 - tmp = self.rail.get_transitions((node[0][0], node[0][1])) - while tmp > 0: - nbits += (tmp & 1) - tmp = tmp >> 1 - if nbits == 1: - stack.append((node[0], (node[1] + 2) % 4)) - - return 0 - - def _get_observations(self): - self.obs_dict = {} - for handle in self.agents_handles: - self.obs_dict[handle] = self.obs_builder.get(handle) - return self.obs_dict - - def render(self): - # TODO: - pass diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 1f97ff2b75674026ae98334a74aea6f9d1b60dbd..38e66423509e622d7cd0685b14c7c1b28f8806ca 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -1,3 +1,13 @@ +""" +ObservationBuilder objects are objects that can be passed to environments designed for customizability. +The ObservationBuilder-derived custom classes implement 2 functions, reset() and get() or get(handle). + ++ Reset() is called after each environment reset, to allow for pre-computing relevant data. + ++ Get() is called whenever an observation has to be computed, potentially for each agent independently in +case of multi-agent environments. +""" + import numpy as np from collections import deque @@ -6,46 +16,87 @@ from collections import deque class ObservationBuilder: - def __init__(self, env): + """ + ObservationBuilder base class. + """ + def __init__(self): + pass + + def _set_env(self, env): self.env = env def reset(self): + """ + Called after each environment reset. + """ raise NotImplementedError() - def get(self, handle): + def get(self, handle=0): + """ + Called whenever an observation has to be computed for the `env' environment, possibly + for each agent independently (agent id `handle'). + + Parameters + ------- + handle : int (optional) + Handle of the agent for which to compute the observation vector. + + Returns + ------- + function + An observation structure, specific to the corresponding environment. + """ raise NotImplementedError() class TreeObsForRailEnv(ObservationBuilder): - def __init__(self, env): - self.env = env + """ + TreeObsForRailEnv object. + + This object returns observation vectors for agents in the RailEnv environment. + The information is local to each agent and exploits the tree structure of the rail + network to simplify the representation of the state of the environment for each agent. + """ + def __init__(self, max_depth): + self.max_depth = max_depth def reset(self): self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents, self.env.height, - self.env.width)) + self.env.width, + 4)) self.max_dist = np.zeros(self.env.number_of_agents) for i in range(self.env.number_of_agents): self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i) + # Update local lookup table for all agents' target locations + self.location_has_target = {} + for loc in self.env.agents_target: + self.location_has_target[(loc[0], loc[1])] = 1 + def _distance_map_walker(self, position, target_nr): + """ + Utility function to compute distance maps from each cell in the rail network (and each possible + orientation within it) to each agent's target cell. + """ # Returns max distance to target, from the farthest away node, while filling in distance_map for ori in range(4): - self.distance_map[target_nr, position[0], position[1]] = 0 + self.distance_map[target_nr, position[0], position[1], ori] = 0 # Fill in the (up to) 4 neighboring nodes # nodes_queue = [] # list of tuples (row, col, direction, distance); - # direction is the direction of movement, meaning that at least a possible orientation - # of an agent in cell (row,col) allows a movement in direction `direction' - nodes_queue = deque(self._get_and_update_neighbors(position, - target_nr, 0, enforce_target_direction=-1)) + # direction is the direction of movement, meaning that at least a possible orientation of an agent + # in cell (row,col) allows a movement in direction `direction' + nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1)) # BFS from target `position' to all the reachable nodes in the grid # Stop the search if the target position is re-visited, in any direction - visited = set([(position[0], position[1], 0), (position[0], position[1], 1), - (position[0], position[1], 2), (position[0], position[1], 3)]) + visited = set([(position[0], position[1], 0), + (position[0], position[1], 1), + (position[0], position[1], 2), + (position[0], position[1], 3)]) max_distance = 0 @@ -57,11 +108,9 @@ class TreeObsForRailEnv(ObservationBuilder): if node_id not in visited: visited.add(node_id) - # From the list of possible neighbors that have at least a path to the - # current node, only keep those whose new orientation in the current cell - # would allow a transition to direction node[2] - valid_neighbors = self._get_and_update_neighbors( - (node[0], node[1]), target_nr, node[3], node[2]) + # From the list of possible neighbors that have at least a path to the current node, only keep those + # whose new orientation in the current cell would allow a transition to direction node[2] + valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2]) for n in valid_neighbors: nodes_queue.append(n) @@ -72,6 +121,10 @@ class TreeObsForRailEnv(ObservationBuilder): return max_distance def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1): + """ + Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the + minimum distances from each target cell. + """ neighbors = [] for direction in range(4): @@ -121,9 +174,57 @@ class TreeObsForRailEnv(ObservationBuilder): neighbors.append((new_cell[0], new_cell[1], direction, new_distance)) self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance + possible_directions = [0, 1, 2, 3] + if enforce_target_direction >= 0: + # The agent must land into the current cell with orientation `enforce_target_direction'. + # This is only possible if the agent has arrived from the cell in the opposite direction! + possible_directions = [(enforce_target_direction+2) % 4] + + for neigh_direction in possible_directions: + new_cell = self._new_position(position, neigh_direction) + + if new_cell[0] >= 0 and new_cell[0] < self.env.height and \ + new_cell[1] >= 0 and new_cell[1] < self.env.width: + + desired_movement_from_new_cell = (neigh_direction+2) % 4 + + """ + # Is the next cell a dead-end? + isNextCellDeadEnd = False + nbits = 0 + tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + # Dead-end! + isNextCellDeadEnd = True + """ + + # Check all possible transitions in new_cell + for agent_orientation in range(4): + # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible? + isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation), + desired_movement_from_new_cell) + + if isValid: + """ + # TODO: check that it works with deadends! -- still bugged! + movement = desired_movement_from_new_cell + if isNextCellDeadEnd: + movement = (desired_movement_from_new_cell+2) % 4 + """ + new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation], + current_distance+1) + neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance)) + self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance + return neighbors def _new_position(self, position, movement): + """ + Utility function that converts a compass movement over a 2D grid to new positions (r, c). + """ if movement == 0: # NORTH return (position[0]-1, position[1]) elif movement == 1: # EAST @@ -134,8 +235,233 @@ class TreeObsForRailEnv(ObservationBuilder): return (position[0], position[1] - 1) def get(self, handle): - # TODO: compute the observation for agent `handle' - return [] + """ + Computes the current observation for agent `handle' in env + + The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible + movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv). + The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for + the transitions. The order is: + [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] + + + + + + Each branch data is organized as: + [root node information] + + [recursive branch data from 'left'] + + [... from 'forward'] + + [... from 'right] + + [... from 'back'] + + Finally, each node information is composed of 5 floating point values: + + #1: + + #2: 1 if a target of another agent is detected between the previous node and the current one. + + #3: 1 if another agent is detected between the previous node and the current one. + + #4: + + #5: minimum distance from node to the agent's target (when landing to the node following the corresponding + branch. + + Missing/padding nodes are filled in with -inf (truncated). + Missing values in present node are filled in with +inf (truncated). + + + In case of the root node, the values are [0, 0, 0, 0, distance from agent to target]. + In case the target node is reached, the values are [0, 0, 0, 0, 0]. + """ + + # Update local lookup table for all agents' positions + self.location_has_agent = {} + for loc in self.env.agents_position: + self.location_has_agent[(loc[0], loc[1])] = 1 + + position = self.env.agents_position[handle] + orientation = self.env.agents_direction[handle] + + # Root node - current position + observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]] + + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right, back], relative to the current orientation + for branch_direction in [(orientation+4+i) % 4 for i in range(-1, 3)]: + if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction): + new_cell = self._new_position(position, branch_direction) + + branch_observation = self._explore_branch(handle, new_cell, branch_direction, 1) + observation = observation + branch_observation + else: + num_cells_to_fill_in = 0 + pow4 = 1 + for i in range(self.max_depth): + num_cells_to_fill_in += pow4 + pow4 *= 4 + observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf]*num_cells_to_fill_in + + return observation + + def _explore_branch(self, handle, position, direction, depth): + """ + Utility function to compute tree-based observations. + """ + # [Recursive branch opened] + if depth >= self.max_depth+1: + return [] + + # Continue along direction until next switch or + # until no transitions are possible along the current direction (i.e., dead-ends) + # We treat dead-ends as nodes, instead of going back, to avoid loops + exploring = True + last_isSwitch = False + last_isDeadEnd = False + # TODO: last_isTerminal = False # wrong cell encountered + last_isTarget = False + + other_agent_encountered = False + other_target_encountered = False + while exploring: + # ############################# + # ############################# + # Modify here to compute any useful data required to build the end node's features. This code is called + # for each cell visited between the previous branching node and the next switch / target / dead-end. + + if position in self.location_has_agent: + other_agent_encountered = True + + if position in self.location_has_target: + other_target_encountered = True + + # ############################# + # ############################# + + # If the target node is encountered, pick that as node. Also, no further branching is possible. + if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]: + last_isTarget = True + break + + cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction)) + num_transitions = 0 + for i in range(4): + if cell_transitions[i]: + num_transitions += 1 + + exploring = False + if num_transitions == 1: + # Check if dead-end, or if we can go forward along direction + nbits = 0 + tmp = self.env.rail.get_transitions((position[0], position[1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + # Dead-end! + last_isDeadEnd = True + + if not last_isDeadEnd: + # Keep walking through the tree along `direction' + exploring = True + + for i in range(4): + if cell_transitions[i]: + position = self._new_position(position, i) + direction = i + break + + elif num_transitions > 0: + # Switch detected + last_isSwitch = True + break + + elif num_transitions == 0: + # Wrong cell type, but let's cover it and treat it as a dead-end, just in case + # TODO: last_isTerminal = True + break + + # `position' is either a terminal node or a switch + + observation = [] + + # ############################# + # ############################# + # Modify here to append new / different features for each visited cell! + + if last_isTarget: + observation = [0, + 1 if other_target_encountered else 0, + 1 if other_agent_encountered else 0, + 0, + 0] + + else: + observation = [0, + 1 if other_target_encountered else 0, + 1 if other_agent_encountered else 0, + 0, + self.distance_map[handle, position[0], position[1], direction]] + + # ############################# + # ############################# + + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right, back], relative to the current orientation + for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]: + if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction), + (branch_direction+2) % 4): + # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes + # it back + new_cell = self._new_position(position, (branch_direction+2) % 4) + + branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2) % 4, depth+1) + observation = observation + branch_observation + + elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), + branch_direction): + new_cell = self._new_position(position, branch_direction) + + branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1) + observation = observation + branch_observation + + else: + num_cells_to_fill_in = 0 + pow4 = 1 + for i in range(self.max_depth-depth): + num_cells_to_fill_in += pow4 + pow4 *= 4 + observation = observation + [-np.inf, -np.inf, -np.inf, -np.inf, -np.inf]*num_cells_to_fill_in + + return observation + + def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0): + """ + Utility function to pretty-print tree observations returned by this object. + """ + if len(tree) < num_features_per_node: + return + + depth = 0 + tmp = len(tree)/num_features_per_node-1 + pow4 = 4 + while tmp > 0: + tmp -= pow4 + depth += 1 + pow4 *= 4 + + prompt_ = ['L:', 'F:', 'R:', 'B:'] + + print(" "*current_depth + prompt, tree[0:num_features_per_node]) + child_size = (len(tree)-num_features_per_node)//4 + for children in range(4): + child_tree = tree[(num_features_per_node+children*child_size): + (num_features_per_node+(children+1)*child_size)] + self.util_print_obs_subtree(child_tree, + num_features_per_node, + prompt=prompt_[children], + current_depth=current_depth+1) class GlobalObsForRailEnv(ObservationBuilder): @@ -152,8 +478,8 @@ class GlobalObsForRailEnv(ObservationBuilder): - A 4 elements array with one of encoding of the direction. """ - def __init__(self, env): - super(GlobalObsForRailEnv, self).__init__(env) + def __init__(self): + super(GlobalObsForRailEnv, self).__init__() def reset(self): self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) @@ -184,196 +510,3 @@ class GlobalObsForRailEnv(ObservationBuilder): direction[self.env.agents_direction[handle]] = 1 return self.rail_obs, obs_agents_targets_pos, direction - - -class Tree_State: - """ - Keep track of the current state while building the tree - """ - def __init__(self, agent, position, direction, depth, distance): - self.agent = agent - self.position = position - self.direction = direction - self.depth = depth - self.initial_direction = None - self.distance = distance - self.data = [np.inf, np.inf, np.inf, np.inf, np.inf] - - -class Node(): - """ - Define a tree node to get populated during search - """ - def __init__(self, position, data): - self.n_children = 4 - self.children = [None]*self.n_children - self.data = data - self.position = position - - def insert(self, position, data, child_idx): - """ - Insert new node with data - - @param data node data object to insert - """ - new_node = Node(position, data) - self.children[child_idx] = new_node - return new_node - - def print_tree(self, i=0, depth=0): - """ - Print tree content inorder - """ - current_i = i - curr_depth = depth+1 - if i < self.n_children: - if self.children[i] is not None: - self.children[i].print_tree(depth=curr_depth) - current_i += 1 - if self.children[i] is not None: - self.children[i].print_tree(i, depth=curr_depth) - - -""" - - def get_observation(self, agent): - # Get the current observation for an agent - current_position = self.internal_position[agent] - #target_heading = self._compass(agent).tolist() - coordinate = tuple(np.transpose(self._position_to_coordinate([current_position]))) - agent_distance = self.distance_map[agent][coordinate][0] - # Start tree search - if current_position == self.target[agent]: - agent_tree = Node(current_position, [-np.inf, -np.inf, -np.inf, -np.inf, -1]) - else: - agent_tree = Node(current_position, [0, 0, 0, 0, agent_distance]) - - initial_tree_state = Tree_State(agent, current_position, -1, 0, 0) - self._tree_search(initial_tree_state, agent_tree, agent) - observation = [] - distance_data = [] - - self._flatten_tree(agent_tree, observation, distance_data, self.max_depth+1) - # This is probably very slow!!!! - #max_obs = np.max([i for i in observation if i < np.inf]) - #if max_obs != 0: - # observation = np.array(observation)/ max_obs - - #print([i for i in distance_data if i >= 0]) - observation = np.concatenate((observation, distance_data)) - #observation = np.concatenate((observation, np.identity(5)[int(self.last_action[agent])])) - #return np.clip(observation / self.max_dist[agent], -1, 1) - return np.clip(observation / 15., -1, 1) - - - - - def _tree_search(self, in_tree_state, parent_node, agent): - if in_tree_state.depth >= self.max_depth: - return - target_distance = np.inf - other_target = np.inf - other_agent = np.inf - coordinate = tuple(np.transpose(self._position_to_coordinate([in_tree_state.position]))) - curr_target_dist = self.distance_map[agent][coordinate][0] - forbidden_action = (in_tree_state.direction + 2) % 4 - # Update the position - failed_move = 0 - leaf_distance = in_tree_state.distance - for child_idx in range(4): - if child_idx != forbidden_action or in_tree_state.direction == -1: - tree_state = copy.deepcopy(in_tree_state) - tree_state.direction = child_idx - current_position, invalid_move = self._detect_path( - tree_state.position, tree_state.direction) - if tree_state.initial_direction == None: - tree_state.initial_direction = child_idx - if not invalid_move: - coordinate = tuple(np.transpose(self._position_to_coordinate([current_position]))) - curr_target_dist = self.distance_map[agent][coordinate][0] - #if tree_state.initial_direction == None: - # tree_state.initial_direction = child_idx - tree_state.position = current_position - tree_state.distance += 1 - - - # Collect information at the current position - detection_distance = tree_state.distance - if current_position == self.target[tree_state.agent]: - target_distance = detection_distance - - elif current_position in self.target: - other_target = detection_distance - - if current_position in self.internal_position: - other_agent = detection_distance - - tree_state.data[0] = self._min_greater_zero(target_distance, tree_state.data[0]) - tree_state.data[1] = self._min_greater_zero(other_target, tree_state.data[1]) - tree_state.data[2] = self._min_greater_zero(other_agent, tree_state.data[2]) - tree_state.data[3] = tree_state.distance - tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4]) - - if self._switch_detection(tree_state.position): - tree_state.depth += 1 - new_tree_state = copy.deepcopy(tree_state) - new_node = parent_node.insert(tree_state.position, - tree_state.data, tree_state.initial_direction) - new_tree_state.initial_direction = None - new_tree_state.data = [np.inf, np.inf, np.inf, np.inf, np.inf] - self._tree_search(new_tree_state, new_node, agent) - else: - self._tree_search(tree_state, parent_node, agent) - else: - failed_move += 1 - if failed_move == 3 and in_tree_state.direction != -1: - tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4]) - parent_node.insert(tree_state.position, tree_state.data, tree_state.initial_direction) - return - return - - def _flatten_tree(self, node, observation_vector, distance_sensor, depth): - if depth <= 0: - return - if node != None: - observation_vector.extend(node.data[:-1]) - distance_sensor.extend([node.data[-1]]) - else: - observation_vector.extend([-np.inf, -np.inf, -np.inf, -np.inf]) - distance_sensor.extend([-np.inf]) - for child_idx in range(4): - if node != None: - child = node.children[child_idx] - else: - child = None - self._flatten_tree(child, observation_vector, distance_sensor, depth -1) - - - - def _switch_detection(self, position): - # Hack to detect switches - # This can later directly be derived from the transition matrix - paths = 0 - for i in range(4): - _, invalid_move = self._detect_path(position, i) - if not invalid_move: - paths +=1 - if paths >= 3: - return True - return False - - - - - def _min_greater_zero(self, x, y): - if x <= 0 and y <= 0: - return 0 - if x < 0: - return y - if y < 0: - return x - return min(x, y) - - - -""" diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index d3fcf5c8467586053ca4ab624f9b8536bdfba2de..78dd9110c6ab61de1d5c38aab6e93d0180431cd5 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -118,7 +118,7 @@ class GridTransitionMap(TransitionMap): Width of the grid. height : int Height of the grid. - transitions_class : Transitions object + transitions : Transitions object The Transitions object to use to encode/decode transitions over the grid. @@ -243,6 +243,54 @@ class GridTransitionMap(TransitionMap): return self.transitions.set_transition(self.grid[cell_id[0]][cell_id[1]], cell_id[2], transition_index, new_transition) + def save_transition_map(self, filename): + """ + Save the transitions grid as `filename', in npy format. + + Parameters + ---------- + filename : string + Name of the file to which to save the transitions grid. + + """ + np.save(filename, self.grid) + + def load_transition_map(self, filename, override_gridsize=True): + """ + Load the transitions grid from `filename' (npy format). + The load function only updates the transitions grid, and possibly width and height, but the object has to be + initialized with the correct `transitions' object anyway. + + Parameters + ---------- + filename : string + Name of the file from which to load the transitions grid. + override_gridsize : bool + If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size + of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if + the grid size is larger than (height,width) ) or padded with zeros (if the grid size is smaller than + (height,width) ) + + """ + new_grid = np.load(filename) + + new_height = new_grid.shape[0] + new_width = new_grid.shape[1] + + if override_gridsize: + self.width = new_width + self.height = new_height + self.grid = new_grid + + else: + if new_grid.dtype == np.uint16: + self.grid = np.zeros((self.height, self.width), dtype=np.uint16) + elif new_grid.dtype == np.uint64: + self.grid = np.zeros((self.height, self.width), dtype=np.uint64) + + self.grid[0:min(self.height, new_height), + 0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height), + 0:min(self.width, new_width)] # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids # (most general implementation) or to make Grid-class specific methods for diff --git a/flatland/envs/__init__.py b/flatland/envs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py new file mode 100644 index 0000000000000000000000000000000000000000..36040d5cf66f1adc74faaade7e45acbd711676e0 --- /dev/null +++ b/flatland/envs/rail_env.py @@ -0,0 +1,721 @@ +""" +Definition of the RailEnv environment and related level-generation functions. + +Generator functions are functions that take width, height and num_resets as arguments and return +a GridTransitionMap object. +""" +import numpy as np + +from flatland.core.env import Environment +from flatland.core.env_observation_builder import TreeObsForRailEnv + +from flatland.core.transitions import Grid8Transitions, RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap + + +def rail_from_manual_specifications_generator(rail_spec): + """ + Utility to convert a rail given by manual specification as a map of tuples + (cell_type, rotation), to a transition map with the correct 16-bit + transitions specifications. + + Parameters + ------- + rail_spec : list of list of tuples + List (rows) of lists (columns) of tuples, each specifying a cell for + the RailEnv environment as (cell_type, rotation), with rotation being + clock-wise and in [0, 90, 180, 270]. + + Returns + ------- + function + Generator function that always returns a GridTransitionMap object with + the matrix of correct 16-bit bitmaps for each cell. + """ + def generator(width, height, num_resets=0): + t_utils = RailEnvTransitions() + + height = len(rail_spec) + width = len(rail_spec[0]) + rail = GridTransitionMap(width=width, height=height, transitions=t_utils) + + for r in range(height): + for c in range(width): + cell = rail_spec[r][c] + if cell[0] < 0 or cell[0] >= len(t_utils.transitions): + print("ERROR - invalid cell type=", cell[0]) + return [] + rail.set_transitions((r, c), t_utils.rotate_transition( + t_utils.transitions[cell[0]], cell[1])) + + return rail + + return generator + + +def rail_from_GridTransitionMap_generator(rail_map): + """ + Utility to convert a rail given by a GridTransitionMap map with the correct + 16-bit transitions specifications. + + Parameters + ------- + rail_map : GridTransitionMap object + GridTransitionMap object to return when the generator is called. + + Returns + ------- + function + Generator function that always returns the given `rail_map' object. + """ + def generator(width, height, num_resets=0): + return rail_map + + return generator + + +def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): + """ + Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset. + + Parameters + ------- + list_of_filenames : list + List of filenames with the saved grids to load. + + Returns + ------- + function + Generator function that always returns the given `rail_map' object. + """ + def generator(width, height, num_resets=0): + t_utils = RailEnvTransitions() + rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils) + rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False) + + if rail_map.grid.dtype == np.uint64: + rail_map.transitions = Grid8Transitions() + + return rail_map + + return generator + + +""" +def generate_rail_from_list_of_manual_specifications(list_of_specifications) + def generator(width, height, num_resets=0): + return generate_rail_from_manual_specifications(list_of_specifications) + + return generator +""" + + +def random_rail_generator(cell_type_relative_proportion=[1.0]*8): + """ + Dummy random level generator: + - fill in cells at random in [width-2, height-2] + - keep filling cells in among the unfilled ones, such that all transitions + are legit; if no cell can be filled in without violating some + transitions, pick one among those that can satisfy most transitions + (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were + incompatible. + - keep trying for a total number of insertions + (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the + board and try again from scratch. + - finally pad the border of the map with dead-ends to avoid border issues. + + Dead-ends are not allowed inside the grid, only at the border; however, if + no cell type can be inserted in a given cell (because of the neighboring + transitions), deadends are allowed if they solve the problem. This was + found to turn most un-genereatable levels into valid ones. + + Parameters + ------- + width : int + The width (number of cells) of the grid to generate. + height : int + The height (number of cells) of the grid to generate. + + Returns + ------- + numpy.ndarray of type numpy.uint16 + The matrix with the correct 16-bit bitmaps for each cell. + """ + + def generator(width, height, num_resets=0): + t_utils = RailEnvTransitions() + + transition_probability = cell_type_relative_proportion + + transitions_templates_ = [] + transition_probabilities = [] + for i in range(len(t_utils.transitions)-1): # don't include dead-ends + all_transitions = 0 + for dir_ in range(4): + trans = t_utils.get_transitions(t_utils.transitions[i], dir_) + all_transitions |= (trans[0] << 3) | \ + (trans[1] << 2) | \ + (trans[2] << 1) | \ + (trans[3]) + + template = [int(x) for x in bin(all_transitions)[2:]] + template = [0]*(4-len(template)) + template + + # add all rotations + for rot in [0, 90, 180, 270]: + transitions_templates_.append((template, + t_utils.rotate_transition( + t_utils.transitions[i], + rot))) + transition_probabilities.append(transition_probability[i]) + template = [template[-1]]+template[:-1] + + def get_matching_templates(template): + ret = [] + for i in range(len(transitions_templates_)): + is_match = True + for j in range(4): + if template[j] >= 0 and \ + template[j] != transitions_templates_[i][0][j]: + is_match = False + break + if is_match: + ret.append((transitions_templates_[i][1], transition_probabilities[i])) + return ret + + MAX_INSERTIONS = (width-2) * (height-2) * 10 + MAX_ATTEMPTS_FROM_SCRATCH = 10 + + attempt_number = 0 + while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH: + cells_to_fill = [] + rail = [] + for r in range(height): + rail.append([None]*width) + if r > 0 and r < height-1: + cells_to_fill = cells_to_fill \ + + [(r, c) for c in range(1, width-1)] + + num_insertions = 0 + while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: + # cell = random.sample(cells_to_fill, 1)[0] + cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]] + cells_to_fill.remove(cell) + row = cell[0] + col = cell[1] + + # look at its neighbors and see what are the possible transitions + # that can be chosen from, if any. + valid_template = [-1, -1, -1, -1] + + for el in [(0, 2, (-1, 0)), + (1, 3, (0, 1)), + (2, 0, (1, 0)), + (3, 1, (0, -1))]: # N, E, S, W + neigh_trans = rail[row+el[2][0]][col+el[2][1]] + if neigh_trans is not None: + # select transition coming from facing direction el[1] and + # moving to direction el[1] + max_bit = 0 + for k in range(4): + max_bit |= \ + t_utils.get_transition(neigh_trans, k, el[1]) + + if max_bit: + valid_template[el[0]] = 1 + else: + valid_template[el[0]] = 0 + + possible_cell_transitions = get_matching_templates(valid_template) + + if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS + # no cell can be filled in without violating some transitions + # can a dead-end solve the problem? + if valid_template.count(1) == 1: + for k in range(4): + if valid_template[k] == 1: + rot = 0 + if k == 0: + rot = 180 + elif k == 1: + rot = 270 + elif k == 2: + rot = 0 + elif k == 3: + rot = 90 + + rail[row][col] = t_utils.rotate_transition( + int('0010000000000000', 2), rot) + num_insertions += 1 + + break + + else: + # can I get valid transitions by removing a single + # neighboring cell? + bestk = -1 + besttrans = [] + for k in range(4): + tmp_template = valid_template[:] + tmp_template[k] = -1 + possible_cell_transitions = get_matching_templates( + tmp_template) + if len(possible_cell_transitions) > len(besttrans): + besttrans = possible_cell_transitions + bestk = k + + if bestk >= 0: + # Replace the corresponding cell with None, append it + # to cells to fill, fill in a transition in the current + # cell. + replace_row = row - 1 + replace_col = col + if bestk == 1: + replace_row = row + replace_col = col + 1 + elif bestk == 2: + replace_row = row + 1 + replace_col = col + elif bestk == 3: + replace_row = row + replace_col = col - 1 + + cells_to_fill.append((replace_row, replace_col)) + rail[replace_row][replace_col] = None + + possible_transitions, possible_probabilities = zip(*besttrans) + possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities] + + rail[row][col] = np.random.choice(possible_transitions, + p=possible_probabilities) + num_insertions += 1 + + else: + print('WARNING: still nothing!') + rail[row][col] = int('0000000000000000', 2) + num_insertions += 1 + pass + + else: + possible_transitions, possible_probabilities = zip(*possible_cell_transitions) + possible_probabilities = [p/sum(possible_probabilities) for p in possible_probabilities] + + rail[row][col] = np.random.choice(possible_transitions, + p=possible_probabilities) + num_insertions += 1 + + if num_insertions == MAX_INSERTIONS: + # Failed to generate a valid level; try again for a number of times + attempt_number += 1 + else: + break + + if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH: + print('ERROR: failed to generate level') + + # Finally pad the border of the map with dead-ends to avoid border issues; + # at most 1 transition in the neigh cell + for r in range(height): + # Check for transitions coming from [r][1] to WEST + max_bit = 0 + neigh_trans = rail[r][1] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & 1) + if max_bit: + rail[r][0] = t_utils.rotate_transition( + int('0010000000000000', 2), 270) + else: + rail[r][0] = int('0000000000000000', 2) + + # Check for transitions coming from [r][-2] to EAST + max_bit = 0 + neigh_trans = rail[r][-2] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) + if max_bit: + rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2), + 90) + else: + rail[r][-1] = int('0000000000000000', 2) + + for c in range(width): + # Check for transitions coming from [1][c] to NORTH + max_bit = 0 + neigh_trans = rail[1][c] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) + if max_bit: + rail[0][c] = int('0010000000000000', 2) + else: + rail[0][c] = int('0000000000000000', 2) + + # Check for transitions coming from [-2][c] to SOUTH + max_bit = 0 + neigh_trans = rail[-2][c] + if neigh_trans is not None: + for k in range(4): + neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ + & (2**4-1) + max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) + if max_bit: + rail[-1][c] = t_utils.rotate_transition( + int('0010000000000000', 2), 180) + else: + rail[-1][c] = int('0000000000000000', 2) + + # For display only, wrong levels + for r in range(height): + for c in range(width): + if rail[r][c] is None: + rail[r][c] = int('0000000000000000', 2) + + tmp_rail = np.asarray(rail, dtype=np.uint16) + + return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) + return_rail.grid = tmp_rail + return return_rail + + return generator + + +class RailEnv(Environment): + """ + RailEnv environment class. + + RailEnv is an environment inspired by a (simplified version of) a rail + network, in which agents (trains) have to navigate to their target + locations in the shortest time possible, while at the same time cooperating + to avoid bottlenecks. + + The valid actions in the environment are: + 0: do nothing + 1: turn left and move to the next cell + 2: move to the next cell in front of the agent + 3: turn right and move to the next cell + + Moving forward in a dead-end cell makes the agent turn 180 degrees and step + to the cell it came from. + + The actions of the agents are executed in order of their handle to prevent + deadlocks and to allow them to learn relative priorities. + + TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and + beta to be passed as parameters to __init__(). + """ + + def __init__(self, + width, + height, + rail_generator=random_rail_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2)): + """ + Environment init. + + Parameters + ------- + rail_generator : function + The rail_generator function is a function that takes the width and + height of a rail map along with the number of times the env has + been reset, and returns a GridTransitionMap object. + Implemented functions are: + random_rail_generator : generate a random rail of given size + rail_from_GridTransitionMap_generator(rail_map) : generate a rail from + a GridTransitionMap object + rail_from_manual_specifications_generator(rail_spec) : generate a rail from + a rail specifications array + TODO: generate_rail_from_saved_list or from list of ndarray bitmaps --- + width : int + The width of the rail map. Potentially in the future, + a range of widths to sample from. + height : int + The height of the rail map. Potentially in the future, + a range of heights to sample from. + number_of_agents : int + Number of agents to spawn on the map. Potentially in the future, + a range of number of agents to sample from. + obs_builder_object: ObservationBuilder object + ObservationBuilder-derived object that takes builds observation + vectors for each agent. + """ + + self.rail_generator = rail_generator + self.rail = None + self.width = width + self.height = height + + self.number_of_agents = number_of_agents + + self.obs_builder = obs_builder_object + self.obs_builder._set_env(self) + + self.actions = [0]*self.number_of_agents + self.rewards = [0]*self.number_of_agents + self.done = False + + self.dones = {"__all__": False} + self.obs_dict = {} + self.rewards_dict = {} + + self.agents_handles = list(range(self.number_of_agents)) + + # self.agents_position = [] + # self.agents_target = [] + # self.agents_direction = [] + self.num_resets = 0 + self.reset() + self.num_resets = 0 + + def get_agent_handles(self): + return self.agents_handles + + def reset(self): + self.rail = self.rail_generator(self.width, self.height, self.num_resets) + self.num_resets += 1 + + self.dones = {"__all__": False} + for handle in self.agents_handles: + self.dones[handle] = False + + re_generate = True + while re_generate: + valid_positions = [] + for r in range(self.height): + for c in range(self.width): + if self.rail.get_transitions((r, c)) > 0: + valid_positions.append((r, c)) + + # self.agents_position = random.sample(valid_positions, + # self.number_of_agents) + self.agents_position = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), self.number_of_agents)] + self.agents_target = [ + valid_positions[i] for i in + np.random.choice(len(valid_positions), self.number_of_agents)] + + # agents_direction must be a direction for which a solution is + # guaranteed. + self.agents_direction = [0]*self.number_of_agents + re_generate = False + for i in range(self.number_of_agents): + valid_movements = [] + for direction in range(4): + position = self.agents_position[i] + moves = self.rail.get_transitions( + (position[0], position[1], direction)) + for move_index in range(4): + if moves[move_index]: + valid_movements.append((direction, move_index)) + + valid_starting_directions = [] + for m in valid_movements: + new_position = self._new_position(self.agents_position[i], + m[1]) + if m[0] not in valid_starting_directions and \ + self._path_exists(new_position, m[0], + self.agents_target[i]): + valid_starting_directions.append(m[0]) + + if len(valid_starting_directions) == 0: + re_generate = True + else: + self.agents_direction[i] = valid_starting_directions[ + np.random.choice(len(valid_starting_directions), 1)[0]] + + # Reset the state of the observation builder with the new environment + self.obs_builder.reset() + + # Return the new observation vectors for each agent + return self._get_observations() + + def step(self, action_dict): + alpha = 1.0 + beta = 1.0 + + invalid_action_penalty = -2 + step_penalty = -1 * alpha + global_reward = 1 * beta + + # Reset the step rewards + self.rewards_dict = dict() + for handle in self.agents_handles: + self.rewards_dict[handle] = 0 + + if self.dones["__all__"]: + return self._get_observations(), self.rewards_dict, self.dones, {} + + for i in range(len(self.agents_handles)): + handle = self.agents_handles[i] + + if handle not in action_dict: + continue + + action = action_dict[handle] + + if action < 0 or action > 3: + print('ERROR: illegal action=', action, + 'for agent with handle=', handle) + return + + if action > 0: + pos = self.agents_position[i] + direction = self.agents_direction[i] + + movement = direction + if action == 1: + movement = direction - 1 + elif action == 3: + movement = direction + 1 + + if movement < 0: + movement += 4 + if movement >= 4: + movement -= 4 + + is_deadend = False + if action == 2: + # compute number of possible transitions in the current + # cell + nbits = 0 + tmp = self.rail.get_transitions((pos[0], pos[1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + # dead-end; assuming the rail network is consistent, + # this should match the direction the agent has come + # from. But it's better to check in any case. + reverse_direction = 0 + if direction == 0: + reverse_direction = 2 + elif direction == 1: + reverse_direction = 3 + elif direction == 2: + reverse_direction = 0 + elif direction == 3: + reverse_direction = 1 + + valid_transition = self.rail.get_transition( + (pos[0], pos[1], direction), + reverse_direction) + if valid_transition: + direction = reverse_direction + movement = reverse_direction + is_deadend = True + + new_position = self._new_position(pos, movement) + + # Is it a legal move? 1) transition allows the movement in the + # cell, 2) the new cell is not empty (case 0), 3) the cell is + # free, i.e., no agent is currently in that cell + if new_position[1] >= self.width or\ + new_position[0] >= self.height or\ + new_position[0] < 0 or new_position[1] < 0: + new_cell_isValid = False + + elif self.rail.get_transitions((new_position[0], new_position[1])) > 0: + new_cell_isValid = True + else: + new_cell_isValid = False + + transition_isValid = self.rail.get_transition( + (pos[0], pos[1], direction), + movement) or is_deadend + + cell_isFree = True + for j in range(self.number_of_agents): + if self.agents_position[j] == new_position: + cell_isFree = False + break + + if new_cell_isValid and transition_isValid and cell_isFree: + # move and change direction to face the movement that was + # performed + self.agents_position[i] = new_position + self.agents_direction[i] = movement + else: + # the action was not valid, add penalty + self.rewards_dict[handle] += invalid_action_penalty + + # if agent is not in target position, add step penalty + if self.agents_position[i][0] == self.agents_target[i][0] and \ + self.agents_position[i][1] == self.agents_target[i][1]: + self.dones[handle] = True + else: + self.rewards_dict[handle] += step_penalty + + # Check for end of episode + add global reward to all rewards! + num_agents_in_target_position = 0 + for i in range(self.number_of_agents): + if self.agents_position[i][0] == self.agents_target[i][0] and \ + self.agents_position[i][1] == self.agents_target[i][1]: + num_agents_in_target_position += 1 + + if num_agents_in_target_position == self.number_of_agents: + self.dones["__all__"] = True + self.rewards_dict = [r+global_reward for r in self.rewards_dict] + + # Reset the step actions (in case some agent doesn't 'register_action' + # on the next step) + self.actions = [0]*self.number_of_agents + + return self._get_observations(), self.rewards_dict, self.dones, {} + + def _new_position(self, position, movement): + if movement == 0: # NORTH + return (position[0]-1, position[1]) + elif movement == 1: # EAST + return (position[0], position[1] + 1) + elif movement == 2: # SOUTH + return (position[0]+1, position[1]) + elif movement == 3: # WEST + return (position[0], position[1] - 1) + + def _path_exists(self, start, direction, end): + # BFS - Check if a path exists between the 2 nodes + + visited = set() + stack = [(start, direction)] + while stack: + node = stack.pop() + if node[0][0] == end[0] and node[0][1] == end[1]: + return 1 + if node not in visited: + visited.add(node) + moves = self.rail.get_transitions((node[0][0], node[0][1], node[1])) + for move_index in range(4): + if moves[move_index]: + stack.append((self._new_position(node[0], move_index), + move_index)) + + # If cell is a dead-end, append previous node with reversed + # orientation! + nbits = 0 + tmp = self.rail.get_transitions((node[0][0], node[0][1])) + while tmp > 0: + nbits += (tmp & 1) + tmp = tmp >> 1 + if nbits == 1: + stack.append((node[0], (node[1] + 2) % 4)) + + return 0 + + def _get_observations(self): + self.obs_dict = {} + for handle in self.agents_handles: + self.obs_dict[handle] = self.obs_builder.get(handle) + return self.obs_dict + + def render(self): + # TODO: + pass diff --git a/flatland/utils/rail_env_generator.py b/flatland/utils/rail_env_generator.py deleted file mode 100644 index 69e5b831be67c6a808e22b5413255789acf90f27..0000000000000000000000000000000000000000 --- a/flatland/utils/rail_env_generator.py +++ /dev/null @@ -1,307 +0,0 @@ -""" -The rail_env_generator module defines provides utilities to generate env -bitmaps for the RailEnv environment. -""" -import random -import numpy as np - -from flatland.core.transitions import RailEnvTransitions -from flatland.core.transition_map import GridTransitionMap - - -def generate_rail_from_manual_specifications(rail_spec): - """ - Utility to convert a rail given by manual specification as a map of tuples - (cell_type, rotation), to a transition map with the correct 16-bit - transitions specifications. - - Parameters - ------- - rail_spec : list of list of tuples - List (rows) of lists (columns) of tuples, each specifying a cell for - the RailEnv environment as (cell_type, rotation), with rotation being - clock-wise and in [0, 90, 180, 270]. - - Returns - ------- - numpy.ndarray of type numpy.uint16 - The matrix with the correct 16-bit bitmaps for each cell. - """ - t_utils = RailEnvTransitions() - - height = len(rail_spec) - width = len(rail_spec[0]) - rail = GridTransitionMap(width=width, height=height, transitions=t_utils) - - for r in range(height): - for c in range(width): - cell = rail_spec[r][c] - if cell[0] < 0 or cell[0] >= len(t_utils.transitions): - print("ERROR - invalid cell type=", cell[0]) - return [] - rail.set_transitions((r, c), t_utils.rotate_transition( - t_utils.transitions[cell[0]], cell[1])) - - return rail - - -def generate_random_rail(width, height): - """ - Dummy random level generator: - - fill in cells at random in [width-2, height-2] - - keep filling cells in among the unfilled ones, such that all transitions - are legit; if no cell can be filled in without violating some - transitions, pick one among those that can satisfy most transitions - (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were - incompatible. - - keep trying for a total number of insertions - (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the - board and try again from scratch. - - finally pad the border of the map with dead-ends to avoid border issues. - - Dead-ends are not allowed inside the grid, only at the border; however, if - no cell type can be inserted in a given cell (because of the neighboring - transitions), deadends are allowed if they solve the problem. This was - found to turn most un-genereatable levels into valid ones. - - Parameters - ------- - width : int - The width (number of cells) of the grid to generate. - height : int - The height (number of cells) of the grid to generate. - - Returns - ------- - numpy.ndarray of type numpy.uint16 - The matrix with the correct 16-bit bitmaps for each cell. - """ - - t_utils = RailEnvTransitions() - - transitions_templates_ = [] - for i in range(len(t_utils.transitions)-1): # don't include dead-ends - all_transitions = 0 - for dir_ in range(4): - trans = t_utils.get_transitions(t_utils.transitions[i], dir_) - all_transitions |= (trans[0] << 3) | \ - (trans[1] << 2) | \ - (trans[2] << 1) | \ - (trans[3]) - - template = [int(x) for x in bin(all_transitions)[2:]] - template = [0]*(4-len(template)) + template - - # add all rotations - for rot in [0, 90, 180, 270]: - transitions_templates_.append((template, - t_utils.rotate_transition( - t_utils.transitions[i], - rot))) - template = [template[-1]]+template[:-1] - - def get_matching_templates(template): - ret = [] - for i in range(len(transitions_templates_)): - is_match = True - for j in range(4): - if template[j] >= 0 and \ - template[j] != transitions_templates_[i][0][j]: - is_match = False - break - if is_match: - ret.append(transitions_templates_[i][1]) - return ret - - MAX_INSERTIONS = (width-2) * (height-2) * 10 - MAX_ATTEMPTS_FROM_SCRATCH = 10 - - attempt_number = 0 - while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH: - cells_to_fill = [] - rail = [] - for r in range(height): - rail.append([None]*width) - if r > 0 and r < height-1: - cells_to_fill = cells_to_fill \ - + [(r, c) for c in range(1, width-1)] - - num_insertions = 0 - while num_insertions < MAX_INSERTIONS and len(cells_to_fill) > 0: - cell = random.sample(cells_to_fill, 1)[0] - cells_to_fill.remove(cell) - row = cell[0] - col = cell[1] - - # look at its neighbors and see what are the possible transitions - # that can be chosen from, if any. - valid_template = [-1, -1, -1, -1] - - for el in [(0, 2, (-1, 0)), - (1, 3, (0, 1)), - (2, 0, (1, 0)), - (3, 1, (0, -1))]: # N, E, S, W - neigh_trans = rail[row+el[2][0]][col+el[2][1]] - if neigh_trans is not None: - # select transition coming from facing direction el[1] and - # moving to direction el[1] - max_bit = 0 - for k in range(4): - max_bit |= \ - t_utils.get_transition(neigh_trans, k, el[1]) - - if max_bit: - valid_template[el[0]] = 1 - else: - valid_template[el[0]] = 0 - - possible_cell_transitions = get_matching_templates(valid_template) - - if len(possible_cell_transitions) == 0: # NO VALID TRANSITIONS - # no cell can be filled in without violating some transitions - # can a dead-end solve the problem? - if valid_template.count(1) == 1: - for k in range(4): - if valid_template[k] == 1: - rot = 0 - if k == 0: - rot = 180 - elif k == 1: - rot = 270 - elif k == 2: - rot = 0 - elif k == 3: - rot = 90 - - rail[row][col] = t_utils.rotate_transition( - int('0000000000100000', 2), rot) - num_insertions += 1 - - break - - else: - # can I get valid transitions by removing a single - # neighboring cell? - bestk = -1 - besttrans = [] - for k in range(4): - tmp_template = valid_template[:] - tmp_template[k] = -1 - possible_cell_transitions = get_matching_templates( - tmp_template) - if len(possible_cell_transitions) > len(besttrans): - besttrans = possible_cell_transitions - bestk = k - - if bestk >= 0: - # Replace the corresponding cell with None, append it - # to cells to fill, fill in a transition in the current - # cell. - replace_row = row - 1 - replace_col = col - if bestk == 1: - replace_row = row - replace_col = col + 1 - elif bestk == 2: - replace_row = row + 1 - replace_col = col - elif bestk == 3: - replace_row = row - replace_col = col - 1 - - cells_to_fill.append((replace_row, replace_col)) - rail[replace_row][replace_col] = None - - rail[row][col] = random.sample( - besttrans, 1)[0] - num_insertions += 1 - - else: - print('WARNING: still nothing!') - rail[row][col] = int('0000000000000000', 2) - num_insertions += 1 - pass - - else: - rail[row][col] = random.sample( - possible_cell_transitions, 1)[0] - num_insertions += 1 - - if num_insertions == MAX_INSERTIONS: - # Failed to generate a valid level; try again for a number of times - attempt_number += 1 - else: - break - - if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH: - print('ERROR: failed to generate level') - - # Finally pad the border of the map with dead-ends to avoid border issues; - # at most 1 transition in the neigh cell - for r in range(height): - # Check for transitions coming from [r][1] to WEST - max_bit = 0 - neigh_trans = rail[r][1] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) - max_bit = max_bit | (neigh_trans_from_direction & 1) - if max_bit: - rail[r][0] = t_utils.rotate_transition( - int('0000000000100000', 2), 270) - else: - rail[r][0] = int('0000000000000000', 2) - - # Check for transitions coming from [r][-2] to EAST - max_bit = 0 - neigh_trans = rail[r][-2] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) - if max_bit: - rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2), - 90) - else: - rail[r][-1] = int('0000000000000000', 2) - - for c in range(width): - # Check for transitions coming from [1][c] to NORTH - max_bit = 0 - neigh_trans = rail[1][c] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) - if max_bit: - rail[0][c] = int('0000000000100000', 2) - else: - rail[0][c] = int('0000000000000000', 2) - - # Check for transitions coming from [-2][c] to SOUTH - max_bit = 0 - neigh_trans = rail[-2][c] - if neigh_trans is not None: - for k in range(4): - neigh_trans_from_direction = (neigh_trans >> ((3-k) * 4)) \ - & (2**4-1) - max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) - if max_bit: - rail[-1][c] = t_utils.rotate_transition( - int('0000000000100000', 2), 180) - else: - rail[-1][c] = int('0000000000000000', 2) - - # For display only, wrong levels - for r in range(height): - for c in range(width): - if rail[r][c] is None: - rail[r][c] = int('0000000000000000', 2) - - tmp_rail = np.asarray(rail, dtype=np.uint16) - return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) - return_rail.grid = tmp_rail - return return_rail diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index d0a78917dcee833eac2d5ac7567436546635d044..9f008f00c261f76e71771732aa02c3c3071f9542 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -6,6 +6,8 @@ import xarray as xr import matplotlib.pyplot as plt +# TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! + class RenderTool(object): Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"]) @@ -401,7 +403,7 @@ class RenderTool(object): # cell_size is a bit pointless with matplotlib - it does not relate to pixels, # so for now I've changed it to 1 (from 10) cell_size = 1 - + plt.clf() # if oFigure is None: # oFigure = plt.figure() @@ -549,7 +551,9 @@ class RenderTool(object): plt.xlim([0, env.width * cell_size]) plt.ylim([-env.height * cell_size, 0]) if show: - plt.show() + plt.show(block=False) + plt.pause(0.00001) + return def _draw_square(self, center, size, color): x0 = center[0]-size/2 diff --git a/images/basic-env.png b/images/basic-env.png index 0b21b26887236b4e7d7e82b34bfa074ec9d05c38..850d6ecad2d1adb6d3d4f829116acee67b9441db 100644 Binary files a/images/basic-env.png and b/images/basic-env.png differ diff --git a/images/env-path.png b/images/env-path.png index 5f49e744754237889dd7331213d3084cb19b1555..95b9faa9fbe78e36c49216058274dbd18495cc12 100644 Binary files a/images/env-path.png and b/images/env-path.png differ diff --git a/images/env-tree-graph.png b/images/env-tree-graph.png index 9b2a2d6a9cc4792c962e0fdbdd87f233aaac7d7d..f33b5f4c8d69ab2c028e6cc1689f4ccbb7600ce3 100644 Binary files a/images/env-tree-graph.png and b/images/env-tree-graph.png differ diff --git a/images/env-tree-spatial.png b/images/env-tree-spatial.png index 06f2054027c5da517c8eff0b2c142dd183a7e3fb..54ac9bfc0c2cb0853a319368add4d6ac5514fd28 100644 Binary files a/images/env-tree-spatial.png and b/images/env-tree-spatial.png differ diff --git a/tests/test_env_observation_builder.py b/tests/test_env_observation_builder.py index 1a797de858beb064d1d5338b746bc0047adf6ba8..44dbfc6d8f14a7293e148f125b47eed66c9ca08d 100644 --- a/tests/test_env_observation_builder.py +++ b/tests/test_env_observation_builder.py @@ -3,7 +3,7 @@ from flatland.core.env_observation_builder import GlobalObsForRailEnv from flatland.core.transition_map import GridTransitionMap, Grid4Transitions -from flatland.core.env import RailEnv +from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator from flatland.utils.rendertools import * """Tests for `flatland` package.""" @@ -57,29 +57,30 @@ def test_global_obs(): rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - env = RailEnv(rail, number_of_agents=1) + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv()) - env.reset() + global_obs = env.reset() # env_renderer = RenderTool(env) # env_renderer.renderEnv(show=True) - global_obs = GlobalObsForRailEnv(env) - global_obs.reset() - assert(global_obs.rail_obs.shape == rail_map.shape + (16,)) + # global_obs.reset() + assert(global_obs[0][0].shape == rail_map.shape + (16,)) rail_map_recons = np.zeros_like(rail_map) - for i in range(global_obs.rail_obs.shape[0]): - for j in range(global_obs.rail_obs.shape[1]): - rail_map_recons[i,j] = int( - ''.join(global_obs.rail_obs[i, j].astype(int).astype(str)), 2) + for i in range(global_obs[0][0].shape[0]): + for j in range(global_obs[0][0].shape[1]): + rail_map_recons[i, j] = int( + ''.join(global_obs[0][0][i, j].astype(int).astype(str)), 2) assert(rail_map_recons.all() == rail_map.all()) - obs = global_obs.get(0) - # If this assertion is wrong, it means that the observation returned # places the agent on an empty cell - assert(np.sum(rail_map * obs[1][0]) > 0) + assert(np.sum(rail_map * global_obs[0][1][0]) > 0) diff --git a/tests/test_environments.py b/tests/test_environments.py index a89133d25378fec20bb44e43b7c43ce0432f789a..ea8748b8aa4b50a1371a013be98f3b42d0d01228 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -1,10 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import numpy as np -from flatland.core.env import RailEnv +from flatland.envs.rail_env import RailEnv, rail_from_GridTransitionMap_generator from flatland.core.transitions import Grid4Transitions from flatland.core.transition_map import GridTransitionMap -import numpy as np +from flatland.core.env_observation_builder import GlobalObsForRailEnv """Tests for `flatland` package.""" @@ -46,7 +47,12 @@ def test_rail_environment_single_agent(): rail = GridTransitionMap(width=3, height=3, transitions=transitions) rail.grid = rail_map - rail_env = RailEnv(rail, number_of_agents=1) + rail_env = RailEnv(width=3, + height=3, + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv()) + for _ in range(200): _ = rail_env.reset() @@ -118,7 +124,11 @@ def test_dead_end(): transitions=transitions) rail.grid = rail_map - rail_env = RailEnv(rail, number_of_agents=1) + rail_env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv()) def check_consistency(rail_env): # We run step to check that trains do not move anymore @@ -141,14 +151,14 @@ def test_dead_end(): # We try the configuration in the 4 directions: rail_env.reset() - rail_env.agents_target[0] = [0, 0] - rail_env.agents_position[0] = [0, 2] + rail_env.agents_target[0] = (0, 0) + rail_env.agents_position[0] = (0, 2) rail_env.agents_direction[0] = 1 check_consistency(rail_env) rail_env.reset() - rail_env.agents_target[0] = [0, 4] - rail_env.agents_position[0] = [0, 2] + rail_env.agents_target[0] = (0, 4) + rail_env.agents_position[0] = (0, 2) rail_env.agents_direction[0] = 3 check_consistency(rail_env) @@ -164,16 +174,20 @@ def test_dead_end(): transitions=transitions) rail.grid = rail_map - rail_env = RailEnv(rail, number_of_agents=1) + rail_env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv()) rail_env.reset() - rail_env.agents_target[0] = [0, 0] - rail_env.agents_position[0] = [2, 0] + rail_env.agents_target[0] = (0, 0) + rail_env.agents_position[0] = (2, 0) rail_env.agents_direction[0] = 2 check_consistency(rail_env) rail_env.reset() - rail_env.agents_target[0] = [4, 0] - rail_env.agents_position[0] = [2, 0] + rail_env.agents_target[0] = (4, 0) + rail_env.agents_position[0] = (2, 0) rail_env.agents_direction[0] = 0 check_consistency(rail_env) diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index ae9a9e1867a9e9d65877e81e73370550a233d206..e45b7d1815365afda98f699d628a7e6f51c92395 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -4,15 +4,14 @@ Tests for `flatland` package. """ -from flatland.core.env import RailEnv +from flatland.envs.rail_env import RailEnv, random_rail_generator import numpy as np -import random import os import matplotlib.pyplot as plt -from flatland.utils import rail_env_generator import flatland.utils.rendertools as rt +from flatland.core.env_observation_builder import GlobalObsForRailEnv def checkFrozenImage(sFileImage): @@ -36,9 +35,12 @@ def checkFrozenImage(sFileImage): def test_render_env(): - random.seed(100) - oRail = rail_env_generator.generate_random_rail(10, 10) - oEnv = RailEnv(oRail, number_of_agents=2) + # random.seed(100) + np.random.seed(100) + oEnv = RailEnv(width=10, height=10, + rail_generator=random_rail_generator(), + number_of_agents=2, + obs_builder_object=GlobalObsForRailEnv()) oEnv.reset() oRT = rt.RenderTool(oEnv) plt.figure(figsize=(10, 10))