diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 4463b1e3a0a457b895cbd8253a7ea87ae2633318..0c1b7d15406e054959b1af35e67e1060c1c3e47b 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -842,6 +842,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation node_positions = [] city_positions = [] intersection_positions = [] + for node_idx in range(num_cities + num_intersections): to_close = True tries = 0 diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index f86f301cdda4676855ad7a52623822f3053d6ea1..d8dbbfffe36662f610adb62393666a3a24a37102 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -52,10 +52,10 @@ class PILGL(GraphicsLayer): self.background_grid = np.zeros(shape=(self.width, self.height)) if jupyter is False: - # NOTE: Currently removed the dependency on - # screeninfo. We have to find an alternate + # NOTE: Currently removed the dependency on + # screeninfo. We have to find an alternate # way to compute the screen width and height - # In the meantime, we are harcoding the 800x600 + # In the meantime, we are harcoding the 800x600 # assumption self.screen_width = 800 self.screen_height = 600 @@ -114,7 +114,7 @@ class PILGL(GraphicsLayer): for rc in dTargets: r = rc[1] c = rc[0] - d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2))) + d = int(np.floor(np.sqrt((x - r) ** 2 + (y - c) ** 2)) / 0.5) distance = min(d, distance) self.background_grid[x][y] = distance @@ -444,7 +444,7 @@ class PILSVG(PILGL): for transition, file in file_directory.items(): - # Translate the ascii transition description in the format "NE WS" to the + # Translate the ascii transition description in the format "NE WS" to the # binary list of transitions as per RailEnv - NESW (in) x NESW (out) transition_16_bit = ["0"] * 16 for sTran in transition.split(" "): diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index f49893aeb78cbb09c45993738b0466efe442e0db..1b274bcae21bcd89cc86f5d10b4ae2603ff915b4 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -25,15 +25,15 @@ def test_realistic_rail_generator(): def test_sparse_rail_generator(): env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=2, # Number of cities in map - num_intersections=3, # Number of interesections in map - num_trainstations=5, # Number of possible start/targets on map - min_node_dist=10, # Minimal distance of nodes - node_radius=2, # Proximity of stations to city center - num_neighb=3, # Number of connections to other cities + rail_generator=sparse_rail_generator(num_cities=3, # Number of cities in map + num_intersections=2, # Number of interesections in map + num_trainstations=15, # Number of possible start/targets on map + min_node_dist=6, # Minimal distance of nodes + node_radius=3, # Proximity of stations to city center + num_neighb=2, # Number of connections to other cities seed=5, # Random seed ), - number_of_agents=0, + number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", )