diff --git a/flatland/baselines/dueling_double_dqn.py b/flatland/baselines/dueling_double_dqn.py
index 66fe3a3effec0dfa9dc35d07fec887eaa05be6fc..41a27bf8431df7812f1b4f63e797aa426c17edf1 100644
--- a/flatland/baselines/dueling_double_dqn.py
+++ b/flatland/baselines/dueling_double_dqn.py
@@ -1,12 +1,14 @@
-import numpy as np
-import random
-from collections import namedtuple, deque
+import copy
 import os
-from flatland.baselines.model import QNetwork, QNetwork2
+import random
+from collections import namedtuple, deque, Iterable
+
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
-import copy
+
+from flatland.baselines.model import QNetwork, QNetwork2
 
 BUFFER_SIZE = int(1e5)  # replay buffer size
 BATCH_SIZE = 512  # minibatch size
@@ -175,16 +177,24 @@ class ReplayBuffer:
         """Randomly sample a batch of experiences from memory."""
         experiences = random.sample(self.memory, k=self.batch_size)
 
-        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
-        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
-        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
-        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(
-            device)
-        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(
-            device)
+        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
+            .float().to(device)
+        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
+            .long().to(device)
+        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
+            .float().to(device)
+        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
+            .float().to(device)
+        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
+            .float().to(device)
 
         return (states, actions, rewards, next_states, dones)
 
     def __len__(self):
         """Return the current size of internal memory."""
         return len(self.memory)
+
+    def __v_stack_impr(self, states):
+        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
+        np_states = np.reshape(np.array(states), (len(states), sub_dim))
+        return np_states
diff --git a/flatland/core/transitions.py b/flatland/core/transitions.py
index 622d900598bba6bbc48750d6bb48923975af9b5e..add047b6c7895e391211258bb10561110e0f1a19 100644
--- a/flatland/core/transitions.py
+++ b/flatland/core/transitions.py
@@ -556,16 +556,16 @@ class RailEnvTransitions(Grid4Transitions):
         self.maskDeadEnds = 0b0010000110000100
 
         # create this to make validation faster
-        self.transitions_all = []
+        self.transitions_all = set()
         for index, trans in enumerate(self.transitions):
-            self.transitions_all.append(trans)
+            self.transitions_all.add(trans)
             if index in (2, 4, 6, 7, 8, 9, 10):
                 for _ in range(3):
                     trans = self.rotate_transition(trans, rotation=90)
-                    self.transitions_all.append(trans)
+                    self.transitions_all.add(trans)
             elif index in (1, 5):
                 trans = self.rotate_transition(trans, rotation=90)
-                self.transitions_all.append(trans)
+                self.transitions_all.add(trans)
 
     def print(self, cell_transition):
         print("  NESW")
@@ -620,10 +620,7 @@ class RailEnvTransitions(Grid4Transitions):
             Boolean True or False
 
         """
-        for trans in self.transitions_all:
-            if cell_transition == trans:
-                return True
-        return False
+        return cell_transition in self.transitions_all
 
     def has_deadend(self, cell_transition):
         if cell_transition & self.maskDeadEnds > 0:
diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index b58604c6d7ededa28a33d30e87e13777a3cd54ec..1482b4438bebd82638b873f3232198172a05e6d0 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -1,12 +1,13 @@
-
 """
 Definition of the RailEnv environment and related level-generation functions.
 
 Generator functions are functions that take width, height and num_resets as
 arguments and return a GridTransitionMap object.
 """
+
 import numpy as np
+
 # from flatland.core.env import Environment
 # from flatland.envs.observations import TreeObsForRailEnv
@@ -53,7 +54,6 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
         else:
             # check if matches existing layout
             new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
-            # new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
     else:
         # set the forward path
         new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
@@ -68,7 +68,6 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
         else:
             # check if matches existing layout
             new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
-            # new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
 
     if not rail_trans.is_valid(new_trans_e):
         return False
@@ -90,6 +89,9 @@ class AStarNode():
     def __eq__(self, other):
         return self.pos == other.pos
 
+    def __hash__(self):
+        return hash(self.pos)
+
     def update_if_better(self, other):
         if other.g < self.g:
             self.parent = other.parent
@@ -106,30 +108,23 @@ def a_star(rail_trans, rail_array, start, end):
     rail_shape = rail_array.shape
     start_node = AStarNode(None, start)
     end_node = AStarNode(None, end)
-    open_list = []
-    closed_list = []
+    open_nodes = set()
+    closed_nodes = set()
+    open_nodes.add(start_node)
 
-    open_list.append(start_node)
-
-    # this could be optimized
-    def is_node_in_list(node, the_list):
-        for o_node in the_list:
-            if node == o_node:
-                return o_node
-        return None
-
-    while len(open_list) > 0:
+    while len(open_nodes) > 0:
         # get node with current shortest est. path (lowest f)
-        current_node = open_list[0]
-        current_index = 0
-        for index, item in enumerate(open_list):
+        current_node = None
+        for item in open_nodes:
+            if current_node is None:
+                current_node = item
+                continue
             if item.f < current_node.f:
                 current_node = item
-                current_index = index
 
         # pop current off open list, add to closed list
-        open_list.pop(current_index)
-        closed_list.append(current_node)
+        open_nodes.remove(current_node)
+        closed_nodes.add(current_node)
 
         # found the goal
         if current_node == end_node:
@@ -149,10 +144,7 @@ def a_star(rail_trans, rail_array, start, end):
         prev_pos = None
         for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
             node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
-            if node_pos[0] >= rail_shape[0] or \
-                node_pos[0] < 0 or \
-                node_pos[1] >= rail_shape[1] or \
-                node_pos[1] < 0:
+            if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
                 continue
 
             # validate positions
@@ -166,8 +158,7 @@ def a_star(rail_trans, rail_array, start, end):
         # loop through children
         for child in children:
            # already in closed list?
-            closed_node = is_node_in_list(child, closed_list)
-            if closed_node is not None:
+            if child in closed_nodes:
                 continue
 
             # create the f, g, and h values
@@ -180,16 +171,14 @@ def a_star(rail_trans, rail_array, start, end):
             child.f = child.g + child.h
 
             # already in the open list?
-            open_node = is_node_in_list(child, open_list)
-            if open_node is not None:
-                open_node.update_if_better(child)
+            if child in open_nodes:
                 continue
 
             # add the child to the open list
-            open_list.append(child)
+            open_nodes.add(child)
 
         # no full path found
-        if len(open_list) == 0:
+        if len(open_nodes) == 0:
             return []
 
@@ -323,8 +312,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
             valid_starting_directions = []
             for m in valid_movements:
                 new_position = get_new_position(agents_position[i], m[1])
-                if m[0] not in valid_starting_directions and \
-                    _path_exists(rail, new_position, m[0], agents_target[i]):
+                if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]):
                     valid_starting_directions.append(m[0])
 
             if len(valid_starting_directions) == 0:
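
The `ReplayBuffer` change routes every sampled field through the new private `__v_stack_impr` helper, which stacks both scalar entries (actions, rewards, done flags) and per-agent observation vectors into a consistent `(batch, dim)` array before the `torch.from_numpy` conversion. Below is a minimal sketch of that reshaping as a free function; the name `v_stack` and the toy inputs are illustrative only, and `Iterable` is imported from `collections.abc`, which is required on Python 3.10+ where the bare `collections` alias no longer exists.

```python
from collections.abc import Iterable  # "collections.Iterable" was removed in Python 3.10

import numpy as np


def v_stack(states):
    """Stack a batch of scalar or per-agent vector entries into shape (batch, dim)."""
    # Vector entries (e.g. observations stored as a 1 x obs_dim array) keep their
    # feature dimension; scalars (actions, rewards, done flags) become a (batch, 1) column.
    sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
    return np.reshape(np.array(states), (len(states), sub_dim))


actions = [0, 2, 1]                                    # scalar fields -> shape (3, 1)
observations = [[[0.1, 0.2, 0.3]], [[0.4, 0.5, 0.6]]]  # wrapped vectors -> shape (2, 3)
print(v_stack(actions).shape)
print(v_stack(observations).shape)
```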
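In `RailEnvTransitions`, the constructor now stores every rotated variant of the base transition templates in a set, so `is_valid` collapses to a single O(1) membership test instead of a linear scan. The sketch below reproduces that precompute-the-closure pattern on the 16-bit Grid4 bitmap; it assumes the layout of four 4-bit blocks (one per heading N, E, S, W, most significant block first) and uses only two base templates, so it illustrates the technique rather than the full template table.

```python
def rotate_transition(cell_transition, rotation=90):
    """Rotate a 16-bit Grid4 transition bitmap by a multiple of 90 degrees."""
    steps = (rotation // 90) % 4
    # Split into 4 blocks of 4 bits, one block per incoming heading (N, E, S, W).
    blocks = [(cell_transition >> (4 * (3 - i))) & 0xF for i in range(4)]
    # A 90-degree turn relabels N->E->S->W for both headings and exits:
    # rotate the bits inside each block, then rotate the blocks themselves.
    blocks = [((b >> steps) | (b << (4 - steps))) & 0xF for b in blocks]
    blocks = blocks[-steps:] + blocks[:-steps] if steps else blocks
    value = 0
    for b in blocks:
        value = (value << 4) | b
    return value


# Precompute the closure of the templates under rotation once; validity checks
# then reduce to a set-membership test, exactly as in the new is_valid().
base_templates = [
    0b0000000000000000,  # Case 0: empty cell
    0b1000000000100000,  # Case 1: straight track
]
transitions_all = set()
for trans in base_templates:
    for _ in range(4):
        transitions_all.add(trans)
        trans = rotate_transition(trans, rotation=90)


def is_valid(cell_transition):
    return cell_transition in transitions_all


print(is_valid(0b0000010000000001))  # True: the straight track rotated by 90 degrees
print(is_valid(0b1111111111111111))  # False: not a template or a rotation of one
```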
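The A* rewrite in `env_utils.py` works because `AStarNode` now defines `__hash__` consistently with `__eq__` (both keyed on `pos`), which is what makes `child in closed_nodes` an O(1) lookup by position. Here is a stripped-down sketch of the same bookkeeping on a plain occupancy grid (no rail-transition validation, Manhattan heuristic, illustrative names throughout); like the diff, it skips a child that is already in the open set rather than re-relaxing its cost.

```python
import numpy as np


class Node:
    def __init__(self, parent, pos):
        self.parent = parent
        self.pos = pos
        self.g = self.h = self.f = 0.0

    # __eq__ and __hash__ must agree: two nodes at the same cell count as the
    # same entry in the open/closed sets, regardless of parent or cost.
    def __eq__(self, other):
        return self.pos == other.pos

    def __hash__(self):
        return hash(self.pos)


def a_star_grid(grid, start, end):
    """A* on a 0/1 occupancy grid using sets for the open/closed bookkeeping."""
    start_node, end_node = Node(None, start), Node(None, end)
    open_nodes, closed_nodes = {start_node}, set()

    while open_nodes:
        current = min(open_nodes, key=lambda n: n.f)   # lowest estimated total cost
        open_nodes.remove(current)
        closed_nodes.add(current)

        if current == end_node:                        # rebuild the path from parents
            path = []
            while current is not None:
                path.append(current.pos)
                current = current.parent
            return path[::-1]

        for dr, dc in ((0, -1), (0, 1), (-1, 0), (1, 0)):
            r, c = current.pos[0] + dr, current.pos[1] + dc
            if not (0 <= r < grid.shape[0] and 0 <= c < grid.shape[1]) or grid[r, c]:
                continue
            child = Node(current, (r, c))
            if child in closed_nodes or child in open_nodes:   # O(1) membership tests
                continue
            child.g = current.g + 1
            child.h = abs(r - end[0]) + abs(c - end[1])        # Manhattan heuristic
            child.f = child.g + child.h
            open_nodes.add(child)

    return []                                          # no path found


grid = np.zeros((5, 5), dtype=int)
grid[1:4, 2] = 1                                       # a wall to route around
print(a_star_grid(grid, (0, 0), (4, 4)))
```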