From d583e91e3610d191dee7cd4fd1518f0c275316cd Mon Sep 17 00:00:00 2001 From: flaurent <florian.laurent@gmail.com> Date: Mon, 20 Jul 2020 05:18:48 +0200 Subject: [PATCH] Added test checking that flatland can be called from a multiprocessing pool --- tests/test_flatland_multiprocessing.py | 35 ++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 tests/test_flatland_multiprocessing.py diff --git a/tests/test_flatland_multiprocessing.py b/tests/test_flatland_multiprocessing.py new file mode 100644 index 00000000..23cfeeac --- /dev/null +++ b/tests/test_flatland_multiprocessing.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from multiprocessing.pool import Pool + +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator +from flatland.utils.simple_rail import make_simple_rail + +"""Tests for `flatland` package.""" + + +def test_multiprocessing_tree_obs(): + number_of_agents = 5 + rail, rail_map = make_simple_rail() + + obs_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) + + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), number_of_agents=number_of_agents, + obs_builder_object=obs_builder) + env.reset(True, True) + + pool = Pool() + pool.map(obs_builder.get, range(number_of_agents)) + + +def main(): + test_multiprocessing_tree_obs() + + +if __name__ == "__main__": + main() -- GitLab