Skip to content
Snippets Groups Projects
Commit 8bfa7470 authored by hagrid67's avatar hagrid67
Browse files

Merge branch '22-performance-tuning' into 'master'

Some performance enhancements

See merge request flatland/flatland!19
parents 1a7d49ff f3c323b6
No related branches found
No related tags found
No related merge requests found
import numpy as np
import random
from collections import namedtuple, deque
import copy
import os
from flatland.baselines.model import QNetwork, QNetwork2
import random
from collections import namedtuple, deque, Iterable
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import copy
from flatland.baselines.model import QNetwork, QNetwork2
BUFFER_SIZE = int(1e5) # replay buffer size
BATCH_SIZE = 512 # minibatch size
......@@ -175,16 +177,24 @@ class ReplayBuffer:
"""Randomly sample a batch of experiences from memory."""
experiences = random.sample(self.memory, k=self.batch_size)
states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)
actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)
rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(
device)
dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(
device)
states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
.float().to(device)
actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
.long().to(device)
rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
.float().to(device)
next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
.float().to(device)
dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
.float().to(device)
return (states, actions, rewards, next_states, dones)
def __len__(self):
"""Return the current size of internal memory."""
# Delegates to the underlying container (a deque per the imports above),
# so len(buffer) reports how many experiences are currently stored.
return len(self.memory)
def __v_stack_impr(self, states):
"""Stack a batch of sampled fields into a 2-D float array of shape (batch, sub_dim).

Faster replacement for np.vstack over per-experience values: build one
array from the whole list and reshape it once.
"""
# If the elements are themselves iterable, the per-element width is taken
# from len(states[0][0]) — i.e. the length of the FIRST ITEM of the first
# element, not of the element itself. NOTE(review): this assumes each
# state is a sequence of equally-sized rows (nested one level); confirm
# against the shape of e.state pushed by the caller.
sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
# Scalars (rewards, dones, actions) fall through with sub_dim == 1,
# yielding column vectors of shape (batch, 1) like np.vstack produced.
np_states = np.reshape(np.array(states), (len(states), sub_dim))
return np_states
......@@ -556,16 +556,16 @@ class RailEnvTransitions(Grid4Transitions):
self.maskDeadEnds = 0b0010000110000100
# create this to make validation faster
self.transitions_all = []
self.transitions_all = set()
for index, trans in enumerate(self.transitions):
self.transitions_all.append(trans)
self.transitions_all.add(trans)
if index in (2, 4, 6, 7, 8, 9, 10):
for _ in range(3):
trans = self.rotate_transition(trans, rotation=90)
self.transitions_all.append(trans)
self.transitions_all.add(trans)
elif index in (1, 5):
trans = self.rotate_transition(trans, rotation=90)
self.transitions_all.append(trans)
self.transitions_all.add(trans)
def print(self, cell_transition):
print(" NESW")
......@@ -620,10 +620,7 @@ class RailEnvTransitions(Grid4Transitions):
Boolean
True or False
"""
for trans in self.transitions_all:
if cell_transition == trans:
return True
return False
return cell_transition in self.transitions_all
def has_deadend(self, cell_transition):
if cell_transition & self.maskDeadEnds > 0:
......
"""
Definition of the RailEnv environment and related level-generation functions.
Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
import numpy as np
# from flatland.core.env import Environment
# from flatland.envs.observations import TreeObsForRailEnv
......@@ -53,7 +54,6 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
else:
# check if matches existing layout
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
else:
# set the forward path
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
......@@ -68,7 +68,6 @@ def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_p
else:
# check if matches existing layout
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
# new_trans_e = rail_trans.set_transition(new_trans_e, mirror(new_dir), mirror(new_dir), 1)
if not rail_trans.is_valid(new_trans_e):
return False
......@@ -90,6 +89,9 @@ class AStarNode():
def __eq__(self, other):
"""Two A* nodes are equal iff they sit on the same grid position.

Cost fields (g/h/f) and parent links are deliberately ignored, so a
freshly expanded child compares equal to a node already stored for the
same cell.
"""
return self.pos == other.pos
def __hash__(self):
"""Hash by position, consistent with __eq__, so nodes can live in sets.

Required for the set-based open/closed bookkeeping in a_star below:
defining __eq__ alone would make the class unhashable.
"""
return hash(self.pos)
def update_if_better(self, other):
if other.g < self.g:
self.parent = other.parent
......@@ -106,30 +108,23 @@ def a_star(rail_trans, rail_array, start, end):
rail_shape = rail_array.shape
start_node = AStarNode(None, start)
end_node = AStarNode(None, end)
open_list = []
closed_list = []
open_nodes = set()
closed_nodes = set()
open_nodes.add(start_node)
open_list.append(start_node)
# this could be optimized
def is_node_in_list(node, the_list):
    """Return the stored element of *the_list* that compares equal to *node*.

    Parameters
    ----------
    node : object
        The probe value; matching uses the elements' ``==`` (for AStarNode,
        position equality).
    the_list : iterable
        Candidates to scan, in order.

    Returns
    -------
    object or None
        The FIRST element equal to *node* — which may be a distinct object
        that merely compares equal — or None when there is no match.
    """
    # next() over a generator is the idiomatic first-match scan: it stops
    # at the first hit and the default argument covers the not-found case.
    return next((o_node for o_node in the_list if o_node == node), None)
while len(open_list) > 0:
while len(open_nodes) > 0:
# get node with current shortest est. path (lowest f)
current_node = open_list[0]
current_index = 0
for index, item in enumerate(open_list):
current_node = None
for item in open_nodes:
if current_node is None:
current_node = item
continue
if item.f < current_node.f:
current_node = item
current_index = index
# pop current off open list, add to closed list
open_list.pop(current_index)
closed_list.append(current_node)
open_nodes.remove(current_node)
closed_nodes.add(current_node)
# found the goal
if current_node == end_node:
......@@ -149,10 +144,7 @@ def a_star(rail_trans, rail_array, start, end):
prev_pos = None
for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
if node_pos[0] >= rail_shape[0] or \
node_pos[0] < 0 or \
node_pos[1] >= rail_shape[1] or \
node_pos[1] < 0:
if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
continue
# validate positions
......@@ -166,8 +158,7 @@ def a_star(rail_trans, rail_array, start, end):
# loop through children
for child in children:
# already in closed list?
closed_node = is_node_in_list(child, closed_list)
if closed_node is not None:
if child in closed_nodes:
continue
# create the f, g, and h values
......@@ -180,16 +171,14 @@ def a_star(rail_trans, rail_array, start, end):
child.f = child.g + child.h
# already in the open list?
open_node = is_node_in_list(child, open_list)
if open_node is not None:
open_node.update_if_better(child)
if child in open_nodes:
continue
# add the child to the open list
open_list.append(child)
open_nodes.add(child)
# no full path found
if len(open_list) == 0:
if len(open_nodes) == 0:
return []
......@@ -323,8 +312,7 @@ def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
valid_starting_directions = []
for m in valid_movements:
new_position = get_new_position(agents_position[i], m[1])
if m[0] not in valid_starting_directions and \
_path_exists(rail, new_position, m[0], agents_target[i]):
if m[0] not in valid_starting_directions and _path_exists(rail, new_position, m[0], agents_target[i]):
valid_starting_directions.append(m[0])
if len(valid_starting_directions) == 0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment