From 87a34e74f1e95c491156fcf19297553dc75d6997 Mon Sep 17 00:00:00 2001
From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch>
Date: Wed, 14 Aug 2019 17:18:49 +0200
Subject: [PATCH] first draft

---
 flatland/envs/generators.py | 193 ++++++++++++++++++++----------------
 1 file changed, 107 insertions(+), 86 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 2de26ddc..4cb8b8cc 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -1,3 +1,5 @@
+from enum import IntEnum
+
 import msgpack
 import numpy as np
 
@@ -555,6 +557,82 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis
         The matrix with the correct 16-bit bitmaps for each cell.
     """
 
+    def add_rail(grid_map, pt_from, pt_via, pt_to, bAddRemove=True):
+        gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])  # NESW in RC
+
+        lrcStroke = [[pt_from[0], pt_from[1]], [pt_via[0], pt_via[1]], [pt_to[0], pt_to[1]]]
+
+        rc3Cells = np.array(lrcStroke[:3])  # the 3 cells
+        rcMiddle = rc3Cells[1]  # the middle cell which we will update
+        bDeadend = np.all(lrcStroke[0] == lrcStroke[2])  # deadend means cell 0 == cell 2
+
+        # get the 2 row, col deltas between the 3 cells, eg [[-1,0],[0,1]] = North, East
+        rc2Trans = np.diff(rc3Cells, axis=0)
+
+        # get the direction index for the 2 transitions
+        liTrans = []
+        for rcTrans in rc2Trans:
+            # gRCTrans - rcTrans gives an array of vector differences between our rcTrans
+            # and the 4 directions stored in gRCTrans.
+            # Where the vector difference is zero, we have a match...
+            # np.all detects where the whole row,col vector is zero.
+            # argwhere gives the index of the zero vector, ie the direction index
+            iTrans = np.argwhere(np.all(gRCTrans - rcTrans == 0, axis=1))
+            if len(iTrans) > 0:
+                iTrans = iTrans[0][0]
+                liTrans.append(iTrans)
+
+        # check that we have two transitions
+        if len(liTrans) == 2:
+            # Set the transition
+            # If this transition spans 3 cells, it is not a deadend, so remove any deadends.
+            # The user will need to resolve any conflicts.
+            grid_map.set_transition((*rcMiddle, liTrans[0]),
+                                    liTrans[1],
+                                    bAddRemove,
+                                    remove_deadends=not bDeadend)
+
+            # Also set the reverse transition
+            # use the reversed outbound transition for inbound
+            # and the reversed inbound transition for outbound
+            grid_map.set_transition((*rcMiddle, mirror(liTrans[1])),
+                                    mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
+
+    def make_switch_w_e(grid_map, center):
+        # e -> w
+        start = (center[0]+1, center[1]-1)
+        via = (center[0], center[1] - 1)
+        goal = (center[0], center[1])
+        add_rail(grid_map, start, via, goal)
+        start = (center[0], center[1]-1)
+        via = (center[0]+1, center[1]-1)
+        goal = (center[0]+1, center[1]-2)
+        add_rail(grid_map, start, via, goal)
+
+    def make_switch_e_w(grid_map, center):
+        # e -> w
+        start = (center[0] + 1, center[1])
+        via = (center[0] + 1, center[1] - 1)
+        goal = (center[0], center[1] - 1)
+        add_rail(grid_map, start, via, goal)
+        start = (center[0] + 1, center[1] - 1)
+        via = (center[0], center[1] - 1)
+        goal = (center[0], center[1] - 2)
+        add_rail(grid_map, start, via, goal)
+
+    class Grid4TransitionsEnum(IntEnum):
+        NORTH = 0
+        EAST = 1
+        SOUTH = 2
+        WEST = 3
+
+        @staticmethod
+        def to_char(int: int):
+            return {0: 'N',
+                    1: 'E',
+                    2: 'S',
+                    3: 'W'}[int]
+
     def generator(width, height, num_agents, num_resets=0):
 
         if num_agents > nr_start_goal:
@@ -567,97 +645,40 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis
 
         np.random.seed(seed + num_resets)
 
-        # generate rail array
-        # step 1:
-        # - generate a start and goal position
-        #   - validate min/max distance allowed
-        #   - validate that start/goals are not placed too close to other start/goals
-        #   - draw a rail from [start,goal]
-        #     - if rail crosses existing rail then validate new connection
-        #     - possibility that this fails to create a path to goal
-        #     - on failure generate new start/goal
-        #
-        # step 2:
-        # - add more rails to map randomly between cells that have rails
-        #   - validate all new rails, on failure don't add new rails
-        #
-        # step 3:
-        # - return transition map + list of [start_pos, start_dir, goal_pos] points
-        #
 
-        start_goal = []
-        start_dir = []
-        nr_created = 0
-        created_sanity = 0
-        sanity_max = 9000
-        while nr_created < nr_start_goal and created_sanity < sanity_max:
-            all_ok = False
-            for _ in range(sanity_max):
-                start = (np.random.randint(0, height), np.random.randint(0, width))
-                goal = (np.random.randint(0, height), np.random.randint(0, width))
-
-                # check to make sure start,goal pos is empty?
-                if rail_array[goal] != 0 or rail_array[start] != 0:
-                    continue
-                # check min/max distance
-                dist_sg = distance_on_rail(start, goal)
-                if dist_sg < min_dist:
-                    continue
-                if dist_sg > max_dist:
-                    continue
-                # check distance to existing points
-                sg_new = [start, goal]
-
-                def check_all_dist(sg_new):
-                    for sg in start_goal:
-                        for i in range(2):
-                            for j in range(2):
-                                dist = distance_on_rail(sg_new[i], sg[j])
-                                if dist < 2:
-                                    return False
-                    return True
-
-                if check_all_dist(sg_new):
-                    all_ok = True
-                    break
-
-            if not all_ok:
-                # we might as well give up at this point
-                break
+        max_n_track_seg=4
+        x_offsets = np.arange(0,height,max_n_track_seg).astype(int)
+        for off_set in x_offsets:
+            # second track
+            data = np.arange(int((width - max_n_track_seg) / max_n_track_seg)) * max_n_track_seg + 2
+            n_track_seg = np.random.choice(max_n_track_seg) + 1
 
+            # track one (full track : left right)
+            start = (off_set, 0)
+            goal = (off_set, width - 1)
             new_path = connect_rail(rail_trans, rail_array, start, goal)
-            if len(new_path) >= 2:
-                nr_created += 1
-                start_goal.append([start, goal])
-                start_dir.append(mirror(get_direction(new_path[0], new_path[1])))
-            else:
-                # after too many failures we will give up
-                created_sanity += 1
 
-        # add extra connections between existing rail
-        created_sanity = 0
-        nr_created = 0
-        while nr_created < nr_extra and created_sanity < sanity_max:
-            all_ok = False
-            for _ in range(sanity_max):
-                start = (np.random.randint(0, height), np.random.randint(0, width))
-                goal = (np.random.randint(0, height), np.random.randint(0, width))
-                # check to make sure start,goal pos are not empty
-                if rail_array[goal] == 0 or rail_array[start] == 0:
-                    continue
-                else:
-                    all_ok = True
-                    break
-            if not all_ok:
-                break
-            new_path = connect_rail(rail_trans, rail_array, start, goal)
-            if len(new_path) >= 2:
-                nr_created += 1
+            agents_position = [new_path[0]]
+            agents_target = [new_path[1]]  # len(new_path) - 1]]
+            agents_direction = [3]
 
-        agents_position = [sg[0] for sg in start_goal[:num_agents]]
-        agents_target = [sg[1] for sg in start_goal[:num_agents]]
-        agents_direction = start_dir[:num_agents]
+            for nbr_track_loop in range(height-1):
+                if len(data) < 2*n_track_seg+1:
+                    break
+                x = np.sort(np.random.choice(data, 2 * n_track_seg, False)).astype(int)
+                data = []
+                for x_loop in range(int(len(x) / 2)):
+                    start = (off_set+nbr_track_loop+1, x[2 * x_loop])
+                    goal = (off_set+nbr_track_loop+1, x[2 * x_loop + 1])
+                    d = np.arange(x[2 * x_loop]+1,x[2 * x_loop+1]-1,2)
+                    data.extend(d)
+                    new_path = connect_rail(rail_trans, rail_array, start, goal)
+                    if len(new_path) >0:
+                        c = (off_set+nbr_track_loop, x[2 * x_loop] + 1)
+                        make_switch_e_w(grid_map, c)
+                        c = (off_set+nbr_track_loop, x[2 * x_loop + 1]+1)
+                        make_switch_w_e(grid_map, c)
 
         return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
 
-    return generator
\ No newline at end of file
+    return generator
-- 
GitLab