From f5748de299099235ee57966504c02abb472d2b53 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Tue, 1 Oct 2019 09:27:47 -0400
Subject: [PATCH] improved city orientation with grid_mode=True

---
 examples/flatland_2_0_example.py | 4 ++--
 flatland/envs/rail_generators.py | 5 ++++-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 621f2c79..92542c99 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -30,9 +30,9 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 
 env = RailEnv(width=50,
               height=50,
-              rail_generator=sparse_rail_generator(num_cities=8,  # Number of cities in map (where train stations are)
+              rail_generator=sparse_rail_generator(num_cities=12,  # Number of cities in map (where train stations are)
                                                    seed=1,  # Random seed
-                                                   grid_mode=False,
+                                                   grid_mode=True,
                                                    max_inter_city_rails=2,
                                                    max_tracks_in_city=4,
                                                    ),
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 337b301a..94e6ddb6 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -692,7 +692,10 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
             # Store the directions to these neighbours and orient city to face closest neighbour
             connection_sides_idx = []
             idx = 1
-            current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]])
+            if grid_mode:
+                current_closest_direction = np.random.randint(4)
+            else:
+                current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]])
             connection_sides_idx.append(current_closest_direction)
             connection_sides_idx.append((current_closest_direction + 2) % 4)
             city_orientations.append(current_closest_direction)
-- 
GitLab