From 482a73e5379ac6c0506b35d951eab47f2322a299 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Wed, 1 May 2019 08:19:12 +0200
Subject: [PATCH] updated curve calculation

---
 flatland/envs/rail_env.py | 73 +++++++++++++++++++++------------------
 1 file changed, 39 insertions(+), 34 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 552e5f39..55e0e2ed 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -163,9 +163,9 @@ def a_star(rail_trans, rail_array, start, end):
         for new_pos in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
             node_pos = (current_node.pos[0] + new_pos[0], current_node.pos[1] + new_pos[1])
             if node_pos[0] >= rail_shape[0] or \
-               node_pos[0] < 0 or \
-               node_pos[1] >= rail_shape[1] or \
-               node_pos[1] < 0:
+                node_pos[0] < 0 or \
+                node_pos[1] >= rail_shape[1] or \
+                node_pos[1] < 0:
                 continue
 
             # validate positions
@@ -232,7 +232,7 @@ def connect_rail(rail_trans, rail_array, start, end):
     end_pos = path[-1]
     for index in range(len(path) - 1):
         current_pos = path[index]
-        new_pos = path[index+1]
+        new_pos = path[index + 1]
         new_dir = get_direction(current_pos, new_pos)
 
         new_trans = rail_array[current_pos]
@@ -359,6 +359,7 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
                                     # print("too close:", dist, sg_new[i], sg[j])
                                     return False
                     return True
+
                 if check_all_dist(sg_new):
                     break
             start_goal.append([start, goal])
@@ -394,6 +395,7 @@ def rail_from_manual_specifications_generator(rail_spec):
         Generator function that always returns a GridTransitionMap object with
         the matrix of correct 16-bit bitmaps for each cell.
     """
+
     def generator(width, height, num_resets=0):
         t_utils = RailEnvTransitions()
 
@@ -429,6 +431,7 @@ def rail_from_GridTransitionMap_generator(rail_map):
     function
         Generator function that always returns the given `rail_map' object.
     """
+
     def generator(width, height, num_resets=0):
         return rail_map
 
@@ -449,6 +452,7 @@ def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
     function
         Generator function that always returns the given `rail_map' object.
     """
+
     def generator(width, height, num_resets=0):
         t_utils = RailEnvTransitions()
         rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
@@ -525,9 +529,9 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
             # add all rotations
             for rot in [0, 90, 180, 270]:
                 transitions_templates_.append((template,
-                                              t_utils.rotate_transition(
-                                               t_utils.transitions[i],
-                                               rot)))
+                                               t_utils.rotate_transition(
+                                                   t_utils.transitions[i],
+                                                   rot)))
                 transition_probabilities.append(transition_probability[i])
                 template = [template[-1]] + template[:-1]
 
@@ -537,7 +541,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
                 is_match = True
                 for j in range(4):
                     if template[j] >= 0 and \
-                       template[j] != transitions_templates_[i][0][j]:
+                        template[j] != transitions_templates_[i][0][j]:
                         is_match = False
                         break
                 if is_match:
@@ -678,7 +682,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
             neigh_trans = rail[r][1]
             if neigh_trans is not None:
                 for k in range(4):
-                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
+                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
                     max_bit = max_bit | (neigh_trans_from_direction & 1)
             if max_bit:
                 rail[r][0] = t_utils.rotate_transition(int('0010000000000000', 2), 270)
@@ -690,7 +694,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
             neigh_trans = rail[r][-2]
             if neigh_trans is not None:
                 for k in range(4):
-                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
+                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
                     max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
             if max_bit:
                 rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
@@ -704,7 +708,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
             neigh_trans = rail[1][c]
             if neigh_trans is not None:
                 for k in range(4):
-                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
+                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
                     max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
             if max_bit:
                 rail[0][c] = int('0010000000000000', 2)
@@ -716,7 +720,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 8):
             neigh_trans = rail[-2][c]
             if neigh_trans is not None:
                 for k in range(4):
-                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2**4 - 1)
+                    neigh_trans_from_direction = (neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1)
                     max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
             if max_bit:
                 rail[-1][c] = t_utils.rotate_transition(int('0010000000000000', 2), 180)
@@ -840,9 +844,9 @@ class RailEnv(Environment):
 
     def check_agent_lists(self):
         for lAgents, name in zip(
-                [self.agents_handles, self.agents_position, self.agents_direction],
-                ["handles", "positions", "directions"]):
-            assert self.number_of_agents == len(lAgents), "Inconsistent agent list:"+name
+            [self.agents_handles, self.agents_position, self.agents_direction],
+            ["handles", "positions", "directions"]):
+            assert self.number_of_agents == len(lAgents), "Inconsistent agent list:" + name
 
     def check_agent_locdirpath(self, iAgent):
         valid_movements = []
@@ -857,7 +861,7 @@ class RailEnv(Environment):
         for m in valid_movements:
             new_position = self._new_position(self.agents_position[iAgent], m[1])
             if m[0] not in valid_starting_directions and \
-                    self._path_exists(new_position, m[0], self.agents_target[iAgent]):
+                self._path_exists(new_position, m[0], self.agents_target[iAgent]):
                 valid_starting_directions.append(m[0])
 
         if len(valid_starting_directions) == 0:
@@ -876,7 +880,7 @@ class RailEnv(Environment):
         for m in valid_movements:
             new_position = self._new_position(rcPos, m[1])
             if m[0] not in valid_starting_directions and \
-                    self._path_exists(new_position, m[0], rcTarget):
+                self._path_exists(new_position, m[0], rcTarget):
                 valid_starting_directions.append(m[0])
 
         if len(valid_starting_directions) == 0:
@@ -891,7 +895,7 @@ class RailEnv(Environment):
             rcPos = np.random.choice(len(self.valid_positions))
 
         iAgent = self.number_of_agents
-        
+
         self.agents_position.append(tuple(rcPos))  # ensure it's a tuple not a list
         self.agents_handles.append(max(self.agents_handles + [-1]) + 1)  # max(handles) + 1, starting at 0
 
@@ -902,7 +906,7 @@ class RailEnv(Environment):
         self.number_of_agents += 1
         self.check_agent_lists()
         return iAgent
-    
+
     def reset(self, regen_rail=True, replace_agents=True):
         if regen_rail or self.rail is None:
             # TODO: Import not only rail information but also start and goal positions
@@ -961,7 +965,7 @@ class RailEnv(Environment):
                         for m in valid_movements:
                             new_position = self._new_position(self.agents_position[i], m[1])
                             if m[0] not in valid_starting_directions and \
-                                    self._path_exists(new_position, m[0], self.agents_target[i]):
+                                self._path_exists(new_position, m[0], self.agents_target[i]):
                                 valid_starting_directions.append(m[0])
 
                         if len(valid_starting_directions) == 0:
@@ -1011,6 +1015,15 @@ class RailEnv(Environment):
                 pos = self.agents_position[i]
                 direction = self.agents_direction[i]
 
+                # count the possible transitions (set bits) of the current
+                # cell; the count is used to check for invalid actions
+
+                nbits = 0
+                tmp = self.rail.get_transitions((pos[0], pos[1]))
+                while tmp > 0:
+                    nbits += (tmp & 1)
+                    tmp = tmp >> 1
+
                 movement = direction
                 if action == 1:
                     movement = direction - 1
@@ -1024,14 +1037,6 @@ class RailEnv(Environment):
 
                 is_deadend = False
                 if action == 2:
-
-                    # compute number of possible transitions in the current
-                    # cell
-                    nbits = 0
-                    tmp = self.rail.get_transitions((pos[0], pos[1]))
-                    while tmp > 0:
-                        nbits += (tmp & 1)
-                        tmp = tmp >> 1
                     if nbits == 1:
                         # dead-end;  assuming the rail network is consistent,
                         # this should match the direction the agent has come
@@ -1074,9 +1079,9 @@ class RailEnv(Environment):
                 # Is it a legal move?  1) transition allows the movement in the
                 # cell,  2) the new cell is not empty (case 0),  3) the cell is
                 # free, i.e., no agent is currently in that cell
-                if new_position[1] >= self.width or\
-                   new_position[0] >= self.height or\
-                   new_position[0] < 0 or new_position[1] < 0:
+                if new_position[1] >= self.width or \
+                    new_position[0] >= self.height or \
+                    new_position[0] < 0 or new_position[1] < 0:
                     new_cell_isValid = False
 
                 elif self.rail.get_transitions((new_position[0], new_position[1])) > 0:
@@ -1105,7 +1110,7 @@ class RailEnv(Environment):
 
             # if agent is not in target position, add step penalty
             if self.agents_position[i][0] == self.agents_target[i][0] and \
-               self.agents_position[i][1] == self.agents_target[i][1]:
+                self.agents_position[i][1] == self.agents_target[i][1]:
                 self.dones[handle] = True
             else:
                 self.rewards_dict[handle] += step_penalty
@@ -1114,7 +1119,7 @@ class RailEnv(Environment):
         num_agents_in_target_position = 0
         for i in range(self.number_of_agents):
             if self.agents_position[i][0] == self.agents_target[i][0] and \
-               self.agents_position[i][1] == self.agents_target[i][1]:
+                self.agents_position[i][1] == self.agents_target[i][1]:
                 num_agents_in_target_position += 1
 
         if num_agents_in_target_position == self.number_of_agents:
@@ -1127,7 +1132,7 @@ class RailEnv(Environment):
         return self._get_observations(), self.rewards_dict, self.dones, {}
 
     def _new_position(self, position, movement):
-        if movement == 0:    # NORTH
+        if movement == 0:  # NORTH
             return (position[0] - 1, position[1])
         elif movement == 1:  # EAST
             return (position[0], position[1] + 1)
-- 
GitLab