diff --git a/env_data/tests/Level_distance_map_shortest_path.pkl b/env_data/tests/Level_distance_map_shortest_path.mpk similarity index 100% rename from env_data/tests/Level_distance_map_shortest_path.pkl rename to env_data/tests/Level_distance_map_shortest_path.mpk diff --git a/env_data/tests/test_002.pkl b/env_data/tests/test_002.mpk similarity index 100% rename from env_data/tests/test_002.pkl rename to env_data/tests/test_002.mpk diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index 028c17a7be65a21a759767a45ad0eabde4f938d4..f8d1bc66b4f9a66c9657902aaa67ae42c9fd8c71 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -2,11 +2,11 @@ from typing import Callable, NamedTuple, Optional, Tuple -import msgpack import numpy as np from numpy.random.mtrand import RandomState from flatland.envs.agent_utils import EnvAgent, RailAgentStatus +from flatland.envs import persistence Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)]) MalfunctionParameters = NamedTuple('MalfunctionParameters', @@ -28,7 +28,7 @@ def _malfunction_prob(rate: float) -> float: return 1 - np.exp(- (1 / rate)) -def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]: +def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]: """ Utility to load pickle file @@ -40,18 +40,26 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct ------- generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken """ - with open(filename, "rb") as file_in: - load_data = file_in.read() - data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') + # with open(filename, "rb") as file_in: + # load_data = file_in.read() + + # if filename.endswith("mpk"): + # data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') + # elif filename.endswith("pkl"): + # data = pickle.loads(load_data) + env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package) # TODO: make this better by using namedtuple in the pickle file. See issue 282 - data['malfunction'] = MalfunctionProcessData._make(data['malfunction']) - if "malfunction" in data: + if "malfunction" in env_dict: + env_dict['malfunction'] = oMPD = MalfunctionProcessData._make(env_dict['malfunction']) + else: + oMPD=None + if oMPD is not None: # Mean malfunction in number of time steps - mean_malfunction_rate = data["malfunction"].malfunction_rate + mean_malfunction_rate = oMPD.malfunction_rate # Uniform distribution parameters for malfunction duration - min_number_of_steps_broken = data["malfunction"].min_duration - max_number_of_steps_broken = data["malfunction"].max_duration + min_number_of_steps_broken = oMPD.min_duration + max_number_of_steps_broken = oMPD.max_duration else: # Mean malfunction in number of time steps mean_malfunction_rate = 0. diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py new file mode 100644 index 0000000000000000000000000000000000000000..e77e97aa10405b9b9eb0119a24d463a3402103f5 --- /dev/null +++ b/flatland/envs/persistence.py @@ -0,0 +1,306 @@ + + +import pickle +import msgpack +import numpy as np + +from flatland.envs import rail_env + +#from flatland.core.env import Environment +from flatland.core.env_observation_builder import DummyObservationBuilder +#from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions +#from flatland.core.grid.grid4_utils import get_new_position +#from flatland.core.grid.grid_utils import IntVector2D +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus +from flatland.envs.distance_map import DistanceMap + +#from flatland.envs.observations import GlobalObsForRailEnv + +# cannot import objects / classes directly because of circular import +from flatland.envs import malfunction_generators as mal_gen +from flatland.envs import rail_generators as rail_gen +from flatland.envs import schedule_generators as sched_gen + + +class RailEnvPersister(object): + + @classmethod + def save(cls, env, filename, save_distance_maps=False): + """ + Saves environment and distance map information in a file + + Parameters: + --------- + filename: string + save_distance_maps: bool + """ + + env_dict = cls.get_full_state(env) + + #print(f"env save - agents: {env_dict['agents'][0]}") + #a0 = env_dict["agents"][0] + #print("agent type:", type(a0)) + + + + if save_distance_maps is True: + oDistMap = env.distance_map.get() + if oDistMap is not None: + if len(oDistMap) > 0: + env_dict["distance_map"] = oDistMap + else: + print("[WARNING] Unable to save the distance map for this environment, as none was found !") + else: + print("[WARNING] Unable to save the distance map for this environment, as none was found !") + + with open(filename, "wb") as file_out: + + if filename.endswith("mpk"): + data = msgpack.packb(env_dict) + + + elif filename.endswith("pkl"): + data = pickle.dumps(env_dict) + #pickle.dump(env_dict, file_out) + + file_out.write(data) + + #with open(filename, "rb") as file_in: + if filename.endswith("mpk"): + #bytes_in = file_in.read() + dIn = msgpack.unpackb(data, encoding="utf-8") + #print(f"msgpack check - {dIn.keys()}") + #print(f"msgpack check - {dIn['agents'][0]}") + + + + @classmethod + def save_episode(cls, env, filename): + dict_env = cls.get_full_state(env) + + lAgents = dict_env["agents"] + #print("Saving agents:", len(lAgents)) + #print("Agent 0:", type(lAgents[0]), lAgents[0]) + + dict_env["episode"] = env.cur_episode + dict_env["actions"] = env.list_actions + dict_env["shape"] = (env.width, env.height) + dict_env["max_episode_steps"] = env._max_episode_steps + + with open(filename, "wb") as file_out: + if filename.endswith(".mpk"): + file_out.write(msgpack.packb(dict_env)) + elif filename.endswith(".pkl"): + pickle.dump(dict_env, file_out) + + + @classmethod + def load(cls, env, filename, load_from_package=None): + """ + Load environment with distance map from a file + + Parameters: + ------- + filename: string + """ + env_dict = cls.load_env_dict(filename, load_from_package=load_from_package) + cls.set_full_state(env, env_dict) + + @classmethod + def load_new(cls, filename, load_from_package=None): + + env_dict = cls.load_env_dict(filename, load_from_package=load_from_package) + + + # TODO: inefficient - each one of these generators loads the complete env file. + env = rail_env.RailEnv(width=1, height=1, + rail_generator=rail_gen.rail_from_file(filename), + schedule_generator=sched_gen.schedule_from_file(filename), + malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename), + obs_builder_object=DummyObservationBuilder(), + record_steps=True) + + env.rail = GridTransitionMap(1,1) # dummy + + cls.set_full_state(env, env_dict) + + return env, env_dict + + @classmethod + def load_env_dict(cls, filename, load_from_package=None): + + if load_from_package is not None: + from importlib_resources import read_binary + load_data = read_binary(load_from_package, filename) + else: + with open(filename, "rb") as file_in: + load_data = file_in.read() + + if filename.endswith("mpk"): + env_dict = msgpack.unpackb(load_data, use_list=False, encoding="utf-8") + elif filename.endswith("pkl"): + env_dict = pickle.loads(load_data) + else: + print(f"filename {filename} must end with either pkl or mpk") + env_dict = {} + + if "agents" in env_dict: + env_dict["agents"] = [EnvAgent(*d[0:12]) for d in env_dict["agents"]] + #print(f"env_dict agents: {env_dict['agents']}") + + return env_dict + + @classmethod + def load_resource(cls, package, resource): + """ + Load environment (with distance map?) from a binary + """ + #from importlib_resources import read_binary + #load_data = read_binary(package, resource) + + #if resource.endswith("pkl"): + # env_dict = pickle.loads(load_data) + #elif resource.endswith("mpk"): + # env_dict = msgpack.unpackb(load_data, encoding="utf-8") + + #cls.set_full_state(env, env_dict) + + return cls.load_new(resource, load_from_package=package) + + @classmethod + def set_full_state(cls, env, env_dict): + """ + Sets environment state from env_dict + + Parameters + ------- + env_dict: dict + """ + env.rail.grid = np.array(env_dict["grid"]) + + # agents are always reset as not moving + if "agents_static" in env_dict: + # no idea if this still works + env.agents = EnvAgent.load_legacy_static_agent(env_dict["agents_static"]) + else: + agents_data = env_dict["agents"] + if len(agents_data)>0: + if type(agents_data[0]) is EnvAgent: + env.agents = agents_data + else: + env.agents = [EnvAgent(*d[0:12]) for d in env_dict["agents"]] + + #print(f"env agents: {env.agents}") + + # setup with loaded data + env.height, env.width = env.rail.grid.shape + env.rail.height = env.height + env.rail.width = env.width + env.dones = dict.fromkeys(list(range(env.get_num_agents())) + ["__all__"], False) + + @classmethod + def get_full_state(cls, env): + """ + Returns state of environment in dict object, ready for serialization + + """ + grid_data = env.rail.grid.tolist() + + # msgpack cannot persist EnvAgent so use the Agent namedtuple. + agent_data = [agent.to_agent() for agent in env.agents] + #print("get_full_state - agent_data:", agent_data) + malfunction_data: MalfunctionProcessData = env.malfunction_process_data + + msg_data_dict = { + "grid": grid_data, + "agents": agent_data, + "malfunction": malfunction_data, + "max_episode_steps": env._max_episode_steps, + } + return msg_data_dict + + +################################################################################################ +# deprecated methods moved from RailEnv. Most likely broken. + + def deprecated_get_full_state_msg(self) -> msgpack.Packer: + """ + Returns state of environment in msgpack object + """ + msg_data_dict = self.get_full_state_dict() + return msgpack.packb(msg_data_dict, use_bin_type=True) + + def deprecated_get_agent_state_msg(self) -> msgpack.Packer: + """ + Returns agents information in msgpack object + """ + agent_data = [agent.to_agent() for agent in self.agents] + msg_data = { + "agents": agent_data} + return msgpack.packb(msg_data, use_bin_type=True) + + def deprecated_get_full_state_dist_msg(self) -> msgpack.Packer: + """ + Returns environment information with distance map information as msgpack object + """ + grid_data = self.rail.grid.tolist() + agent_data = [agent.to_agent() for agent in self.agents] + + # I think these calls do nothing - they create packed data and it is discarded + #msgpack.packb(grid_data, use_bin_type=True) + #msgpack.packb(agent_data, use_bin_type=True) + + distance_map_data = self.distance_map.get() + malfunction_data: MalfunctionProcessData = self.malfunction_process_data + #msgpack.packb(distance_map_data, use_bin_type=True) # does nothing + msg_data = { + "grid": grid_data, + "agents": agent_data, + "distance_map": distance_map_data, + "malfunction": malfunction_data} + return msgpack.packb(msg_data, use_bin_type=True) + + def deprecated_set_full_state_msg(self, msg_data): + """ + Sets environment state with msgdata object passed as argument + + Parameters + ------- + msg_data: msgpack object + """ + data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') + self.rail.grid = np.array(data["grid"]) + # agents are always reset as not moving + if "agents_static" in data: + self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) + else: + self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] + # setup with loaded data + self.height, self.width = self.rail.grid.shape + self.rail.height = self.height + self.rail.width = self.width + self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + + def deprecated_set_full_state_dist_msg(self, msg_data): + """ + Sets environment grid state and distance map with msgdata object passed as argument + + Parameters + ------- + msg_data: msgpack object + """ + data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') + self.rail.grid = np.array(data["grid"]) + # agents are always reset as not moving + if "agents_static" in data: + self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) + else: + self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] + if "distance_map" in data.keys(): + self.distance_map.set(data["distance_map"]) + # setup with loaded data + self.height, self.width = self.rail.grid.shape + self.rail.height = self.height + self.rail.width = self.width + self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) \ No newline at end of file diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6778fe8d4f081efa5baa518ecfb73a37a9878837..2a00f2f07c2e0ea9679f8dc51d17d696f32f7ce6 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -20,10 +20,24 @@ from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap -from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData + +# Need to use circular imports for persistence. +from flatland.envs import malfunction_generators as mal_gen +from flatland.envs import rail_generators as rail_gen +from flatland.envs import schedule_generators as sched_gen +from flatland.envs import persistence + +# Direct import of objects / classes does not work with circular imports. +# from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData +# from flatland.envs.observations import GlobalObsForRailEnv +# from flatland.envs.rail_generators import random_rail_generator, RailGenerator +# from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator + from flatland.envs.observations import GlobalObsForRailEnv -from flatland.envs.rail_generators import random_rail_generator, RailGenerator -from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator + + + +import pickle m.patch() @@ -116,11 +130,11 @@ class RailEnv(Environment): def __init__(self, width, height, - rail_generator: RailGenerator = random_rail_generator(), - schedule_generator: ScheduleGenerator = random_schedule_generator(), + rail_generator: rail_gen.RailGenerator = rail_gen.random_rail_generator(), + schedule_generator: sched_gen.ScheduleGenerator = sched_gen.random_schedule_generator(), number_of_agents=1, obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(), - malfunction_generator_and_process_data=no_malfunction_generator(), + malfunction_generator_and_process_data=None, #mal_gen.no_malfunction_generator(), remove_agents_at_target=True, random_seed=1, record_steps=False @@ -162,6 +176,8 @@ class RailEnv(Environment): """ super().__init__() + if malfunction_generator_and_process_data is None: + malfunction_generator_and_process_data = mal_gen.no_malfunction_generator() self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data self.rail_generator: RailGenerator = rail_generator self.schedule_generator: ScheduleGenerator = schedule_generator @@ -206,7 +222,9 @@ class RailEnv(Environment): # save episode timesteps ie agent positions, orientations. (not yet actions / observations) self.record_steps = record_steps # whether to save timesteps - self.cur_episode = [] # save timesteps in here + # save timesteps in here: [[[row, col, dir, malfunction],...nAgents], ...nSteps] + self.cur_episode = [] + self.list_actions = [] # save actions in here def _seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) @@ -239,6 +257,8 @@ class RailEnv(Environment): agent.reset() self.active_agents = [i for i in range(len(self.agents))] + + def action_required(self, agent): """ Check if an agent needs to provide an action @@ -302,6 +322,8 @@ class RailEnv(Environment): if optionals and 'distance_map' in optionals: self.distance_map.set(optionals['distance_map']) + + if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0: agents_hints = None if optionals and 'agents_hints' in optionals: @@ -346,6 +368,9 @@ class RailEnv(Environment): # Reset the malfunction generator self.malfunction_generator(reset=True) + # Empty the episode store of agent positions + self.cur_episode = [] + info_dict: Dict = { 'action_required': {i: self.action_required(agent) for i, agent in enumerate(self.agents)}, 'malfunction': { @@ -470,7 +495,7 @@ class RailEnv(Environment): for i_agent in range(self.get_num_agents()): self.dones[i_agent] = True if self.record_steps: - self.record_timestep() + self.record_timestep(action_dict_) return self._get_observations(), self.rewards_dict, self.dones, info_dict @@ -694,7 +719,7 @@ class RailEnv(Environment): cell_free = False return cell_free, new_cell_valid, new_direction, new_position, transition_valid - def record_timestep(self): + def record_timestep(self, dActions): ''' Record the positions and orientations of all agents in memory, in the cur_episode ''' list_agents_state = [] @@ -707,8 +732,11 @@ class RailEnv(Environment): else: pos = (int(agent.position[0]), int(agent.position[1])) # print("pos:", pos, type(pos[0])) - list_agents_state.append([*pos, int(agent.direction)]) + list_agents_state.append( + [*pos, int(agent.direction), agent.malfunction_data["malfunction"] ]) + self.cur_episode.append(list_agents_state) + self.list_actions.append(dActions) def cell_free(self, position: IntVector2D) -> bool: """ @@ -774,6 +802,7 @@ class RailEnv(Environment): ------ Dict object """ + #print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}") self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict @@ -792,155 +821,7 @@ class RailEnv(Environment): """ return Grid4Transitions.get_entry_directions(self.rail.get_full_transitions(row, col)) - def get_full_state_msg(self) -> Packer: - """ - Returns state of environment in msgpack object - """ - grid_data = self.rail.grid.tolist() - agent_data = [agent.to_agent() for agent in self.agents] - malfunction_data: MalfunctionProcessData = self.malfunction_process_data - msgpack.packb(grid_data, use_bin_type=True) - msgpack.packb(agent_data, use_bin_type=True) - msg_data = { - "grid": grid_data, - "agents": agent_data, - "malfunction": malfunction_data, - "max_episode_steps": self._max_episode_steps} - return msgpack.packb(msg_data, use_bin_type=True) - - def get_agent_state_msg(self) -> Packer: - """ - Returns agents information in msgpack object - """ - agent_data = [agent.to_agent() for agent in self.agents] - msg_data = { - "agents": agent_data} - return msgpack.packb(msg_data, use_bin_type=True) - - def get_full_state_dist_msg(self) -> Packer: - """ - Returns environment information with distance map information as msgpack object - """ - grid_data = self.rail.grid.tolist() - agent_data = [agent.to_agent() for agent in self.agents] - msgpack.packb(grid_data, use_bin_type=True) - msgpack.packb(agent_data, use_bin_type=True) - distance_map_data = self.distance_map.get() - malfunction_data: MalfunctionProcessData = self.malfunction_process_data - msgpack.packb(distance_map_data, use_bin_type=True) - msg_data = { - "grid": grid_data, - "agents": agent_data, - "distance_map": distance_map_data, - "malfunction": malfunction_data, - "max_episode_steps": self._max_episode_steps} - return msgpack.packb(msg_data, use_bin_type=True) - - def set_full_state_msg(self, msg_data): - """ - Sets environment state with msgdata object passed as argument - - Parameters - ------- - msg_data: msgpack object - """ - data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') - self.rail.grid = np.array(data["grid"]) - # agents are always reset as not moving - if "agents_static" in data: - self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) - else: - self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] - # setup with loaded data - self.height, self.width = self.rail.grid.shape - self.rail.height = self.height - self.rail.width = self.width - self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) - - def set_full_state_dist_msg(self, msg_data): - """ - Sets environment grid state and distance map with msgdata object passed as argument - - Parameters - ------- - msg_data: msgpack object - """ - data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') - self.rail.grid = np.array(data["grid"]) - # agents are always reset as not moving - if "agents_static" in data: - self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) - else: - self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] - if "distance_map" in data.keys(): - self.distance_map.set(data["distance_map"]) - # setup with loaded data - self.height, self.width = self.rail.grid.shape - self.rail.height = self.height - self.rail.width = self.width - self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) - - def save(self, filename, save_distance_maps=False): - """ - Saves environment and distance map information in a file - - Parameters: - --------- - filename: string - save_distance_maps: bool - """ - if save_distance_maps is True: - if self.distance_map.get() is not None: - if len(self.distance_map.get()) > 0: - with open(filename, "wb") as file_out: - file_out.write(self.get_full_state_dist_msg()) - else: - print("[WARNING] Unable to save the distance map for this environment, as none was found !") - - else: - print("[WARNING] Unable to save the distance map for this environment, as none was found !") - - else: - with open(filename, "wb") as file_out: - file_out.write(self.get_full_state_msg()) - - def save_episode(self, filename): - episode_data = self.cur_episode - msgpack.packb(episode_data, use_bin_type=True) - dict_data = {"episode": episode_data} - # msgpack.packb(msg_data, use_bin_type=True) - with open(filename, "wb") as file_out: - file_out.write(msgpack.packb(dict_data)) - - def load(self, filename): - """ - Load environment with distance map from a file - - Parameters: - ------- - filename: string - """ - with open(filename, "rb") as file_in: - load_data = file_in.read() - self.set_full_state_dist_msg(load_data) - - def load_pkl(self, pkl_data): - """ - Load environment with distance map from a pickle file - - Parameters: - ------- - pkl_data: pickle file - """ - self.set_full_state_msg(pkl_data) - def load_resource(self, package, resource): - """ - Load environment with distance map from a binary - """ - from importlib_resources import read_binary - load_data = read_binary(package, resource) - self.set_full_state_msg(load_data) def _exp_distirbution_synced(self, rate: float) -> float: """ @@ -966,3 +847,9 @@ class RailEnv(Environment): """ return agent.malfunction_data['malfunction'] < 1 + + def save(self, filename): + print("deprecated call to env.save() - pls call RailEnvPersister.save()") + persistence.RailEnvPersister.save(self, filename) + + diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index 823fe4f11555298c6ab854c40320dfc9d37d9f4d..9143a86334401889fbd9d2494b2f97a8e6ef5435 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -8,7 +8,9 @@ from flatland.envs.schedule_generators import schedule_from_file def load_flatland_environment_from_file(file_name: str, load_from_package: str = None, - obs_builder_object: ObservationBuilder = None) -> RailEnv: + obs_builder_object: ObservationBuilder = None, + record_steps = False, + ) -> RailEnv: """ Parameters ---------- @@ -31,6 +33,9 @@ def load_flatland_environment_from_file(file_name: str, max_depth=2, predictor=ShortestPathPredictorForRailEnv(max_depth=10)) environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package), - schedule_generator=schedule_from_file(file_name, load_from_package), number_of_agents=1, - obs_builder_object=obs_builder_object) + schedule_generator=schedule_from_file(file_name, load_from_package), + number_of_agents=1, + obs_builder_object=obs_builder_object, + record_steps=record_steps, + ) return environment diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index cb8569643698f68612720a175192af885a222d6d..1a73acb93908fda477be25637631a465df8b9183 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -3,7 +3,6 @@ import sys import warnings from typing import Callable, Tuple, Optional, Dict, List -import msgpack import numpy as np from numpy.random.mtrand import RandomState @@ -16,6 +15,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \ fix_inner_nodes, align_cell_to_city +from flatland.envs import persistence RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] @@ -240,21 +240,15 @@ def rail_from_file(filename, load_from_package=None) -> RailGenerator: """ def generator(width: int, height: int, num_agents: int, num_resets: int = 0, - np_random: RandomState = None) -> RailGenerator: + np_random: RandomState = None) -> List: + env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package) rail_env_transitions = RailEnvTransitions() - if load_from_package is not None: - from importlib_resources import read_binary - load_data = read_binary(load_from_package, filename) - else: - with open(filename, "rb") as file_in: - load_data = file_in.read() - data = msgpack.unpackb(load_data, use_list=False) - grid = np.array(data[b"grid"]) + grid = np.array(env_dict["grid"]) rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) rail.grid = grid - if b"distance_map" in data.keys(): - distance_map = data[b"distance_map"] + if "distance_map" in env_dict: + distance_map = env_dict["distance_map"] if len(distance_map) > 0: return rail, {'distance_map': distance_map} return [rail, None] @@ -634,7 +628,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ max_feasible_cities = min(max_num_cities, ((height - 2) // (2 * (city_radius + 1))) * ((width - 2) // (2 * (city_radius + 1)))) if max_feasible_cities < 2: - sys.exit("[ABORT] Cannot fit more than one city in this map, no feasible environment possible! Aborting.") + # sys.exit("[ABORT] Cannot fit more than one city in this map, no feasible environment possible! Aborting.") + raise ValueError("ERROR: Cannot fit more than one city in this map, no feasible environment possible!") # Evenly distribute cities if grid_mode: diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index a84264e04fa2039ea09f9a53df52ab692800ff30..d04047495b8ebf79a82faa472a5608f6da2a0ac2 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -2,7 +2,6 @@ import warnings from typing import Tuple, List, Callable, Mapping, Optional, Any -import msgpack import numpy as np from numpy.random.mtrand import RandomState @@ -10,6 +9,7 @@ from flatland.core.grid.grid4_utils import get_new_position from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent from flatland.envs.schedule_utils import Schedule +from flatland.envs import persistence AgentPosition = Tuple[int, int] ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any], Optional[int]], Schedule] @@ -299,22 +299,24 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0, np_random: RandomState = None) -> Schedule: - if load_from_package is not None: - from importlib_resources import read_binary - load_data = read_binary(load_from_package, filename) - else: - with open(filename, "rb") as file_in: - load_data = file_in.read() - data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') - if "agents_static" in data: - agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) - else: - agents = [EnvAgent(*d[0:12]) for d in data["agents"]] - if "max_episode_steps" in data: - max_episode_steps = data["max_episode_steps"] - else: - # If no max time was found return 0. - max_episode_steps = 0 + # if load_from_package is not None: + # from importlib_resources import read_binary + # load_data = read_binary(load_from_package, filename) + # else: + # with open(filename, "rb") as file_in: + # load_data = file_in.read() + # data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') + # if "agents_static" in data: + # agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) + # else: + # agents = [EnvAgent(*d[0:12]) for d in data["agents"]] + + env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package) + + max_episode_steps = env_dict.get("max_episode_steps", 0) + + agents = env_dict["agents"] + # setup with loaded data agents_position = [a.initial_position for a in agents] agents_direction = [a.direction for a in agents] diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index 5aba8d832d3e48615b20c54a486c60c6239e53f3..d9b303c79e71037d224cb108235805ae1f235f56 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -435,7 +435,7 @@ class FlatlandRemoteEvaluationService: """ self.simulation_rewards_normalized[-1] += \ cumulative_reward / ( - self.env._max_episode_steps + + self.env._max_episode_steps * self.env.get_num_agents() ) diff --git a/svg/Background_#91D1DD.svg b/flatland/svg/Background_#91D1DD.svg similarity index 100% rename from svg/Background_#91D1DD.svg rename to flatland/svg/Background_#91D1DD.svg diff --git a/svg/Background_#9CCB89.svg b/flatland/svg/Background_#9CCB89.svg similarity index 100% rename from svg/Background_#9CCB89.svg rename to flatland/svg/Background_#9CCB89.svg diff --git a/svg/Background_#AA7B55.svg b/flatland/svg/Background_#AA7B55.svg similarity index 100% rename from svg/Background_#AA7B55.svg rename to flatland/svg/Background_#AA7B55.svg diff --git a/svg/Background_#DEBDA0.svg b/flatland/svg/Background_#DEBDA0.svg similarity index 100% rename from svg/Background_#DEBDA0.svg rename to flatland/svg/Background_#DEBDA0.svg diff --git a/svg/Background_Light_green.svg b/flatland/svg/Background_Light_green.svg similarity index 100% rename from svg/Background_Light_green.svg rename to flatland/svg/Background_Light_green.svg diff --git a/svg/Background_city.svg b/flatland/svg/Background_city.svg similarity index 100% rename from svg/Background_city.svg rename to flatland/svg/Background_city.svg diff --git a/svg/Background_rail.svg b/flatland/svg/Background_rail.svg similarity index 100% rename from svg/Background_rail.svg rename to flatland/svg/Background_rail.svg diff --git a/svg/Background_white.svg b/flatland/svg/Background_white.svg similarity index 100% rename from svg/Background_white.svg rename to flatland/svg/Background_white.svg diff --git a/svg/Background_white_filter.svg b/flatland/svg/Background_white_filter.svg similarity index 100% rename from svg/Background_white_filter.svg rename to flatland/svg/Background_white_filter.svg diff --git a/svg/Bahnhof_#d50000.svg b/flatland/svg/Bahnhof_#d50000.svg similarity index 100% rename from svg/Bahnhof_#d50000.svg rename to flatland/svg/Bahnhof_#d50000.svg diff --git a/svg/Bahnhof_#d50000_Deadend_links.svg b/flatland/svg/Bahnhof_#d50000_Deadend_links.svg similarity index 100% rename from svg/Bahnhof_#d50000_Deadend_links.svg rename to flatland/svg/Bahnhof_#d50000_Deadend_links.svg diff --git a/svg/Bahnhof_#d50000_Deadend_oben.svg b/flatland/svg/Bahnhof_#d50000_Deadend_oben.svg similarity index 100% rename from svg/Bahnhof_#d50000_Deadend_oben.svg rename to flatland/svg/Bahnhof_#d50000_Deadend_oben.svg diff --git a/svg/Bahnhof_#d50000_Deadend_rechts.svg b/flatland/svg/Bahnhof_#d50000_Deadend_rechts.svg similarity index 100% rename from svg/Bahnhof_#d50000_Deadend_rechts.svg rename to flatland/svg/Bahnhof_#d50000_Deadend_rechts.svg diff --git a/svg/Bahnhof_#d50000_Deadend_unten.svg b/flatland/svg/Bahnhof_#d50000_Deadend_unten.svg similarity index 100% rename from svg/Bahnhof_#d50000_Deadend_unten.svg rename to flatland/svg/Bahnhof_#d50000_Deadend_unten.svg diff --git a/svg/Bahnhof_#d50000_Gleis_horizontal.svg b/flatland/svg/Bahnhof_#d50000_Gleis_horizontal.svg similarity index 100% rename from svg/Bahnhof_#d50000_Gleis_horizontal.svg rename to flatland/svg/Bahnhof_#d50000_Gleis_horizontal.svg diff --git a/svg/Bahnhof_#d50000_Gleis_vertikal.svg b/flatland/svg/Bahnhof_#d50000_Gleis_vertikal.svg similarity index 100% rename from svg/Bahnhof_#d50000_Gleis_vertikal.svg rename to flatland/svg/Bahnhof_#d50000_Gleis_vertikal.svg diff --git a/svg/Bahnhof_#d50000_target.svg b/flatland/svg/Bahnhof_#d50000_target.svg similarity index 100% rename from svg/Bahnhof_#d50000_target.svg rename to flatland/svg/Bahnhof_#d50000_target.svg diff --git a/svg/Buildings-Bank.svg b/flatland/svg/Buildings-Bank.svg similarity index 100% rename from svg/Buildings-Bank.svg rename to flatland/svg/Buildings-Bank.svg diff --git a/svg/Buildings-Bar.svg b/flatland/svg/Buildings-Bar.svg similarity index 100% rename from svg/Buildings-Bar.svg rename to flatland/svg/Buildings-Bar.svg diff --git a/svg/Buildings-Fabrik_A.svg b/flatland/svg/Buildings-Fabrik_A.svg similarity index 100% rename from svg/Buildings-Fabrik_A.svg rename to flatland/svg/Buildings-Fabrik_A.svg diff --git a/svg/Buildings-Fabrik_B.svg b/flatland/svg/Buildings-Fabrik_B.svg similarity index 100% rename from svg/Buildings-Fabrik_B.svg rename to flatland/svg/Buildings-Fabrik_B.svg diff --git a/svg/Buildings-Fabrik_C.svg b/flatland/svg/Buildings-Fabrik_C.svg similarity index 100% rename from svg/Buildings-Fabrik_C.svg rename to flatland/svg/Buildings-Fabrik_C.svg diff --git a/svg/Buildings-Fabrik_D.svg b/flatland/svg/Buildings-Fabrik_D.svg similarity index 100% rename from svg/Buildings-Fabrik_D.svg rename to flatland/svg/Buildings-Fabrik_D.svg diff --git a/svg/Buildings-Fabrik_E.svg b/flatland/svg/Buildings-Fabrik_E.svg similarity index 100% rename from svg/Buildings-Fabrik_E.svg rename to flatland/svg/Buildings-Fabrik_E.svg diff --git a/svg/Buildings-Fabrik_F.svg b/flatland/svg/Buildings-Fabrik_F.svg similarity index 100% rename from svg/Buildings-Fabrik_F.svg rename to flatland/svg/Buildings-Fabrik_F.svg diff --git a/svg/Buildings-Fabrik_G.svg b/flatland/svg/Buildings-Fabrik_G.svg similarity index 100% rename from svg/Buildings-Fabrik_G.svg rename to flatland/svg/Buildings-Fabrik_G.svg diff --git a/svg/Buildings-Fabrik_H.svg b/flatland/svg/Buildings-Fabrik_H.svg similarity index 100% rename from svg/Buildings-Fabrik_H.svg rename to flatland/svg/Buildings-Fabrik_H.svg diff --git a/svg/Buildings-Fabrik_I.svg b/flatland/svg/Buildings-Fabrik_I.svg similarity index 100% rename from svg/Buildings-Fabrik_I.svg rename to flatland/svg/Buildings-Fabrik_I.svg diff --git a/svg/Buildings-Hochhaus.svg b/flatland/svg/Buildings-Hochhaus.svg similarity index 100% rename from svg/Buildings-Hochhaus.svg rename to flatland/svg/Buildings-Hochhaus.svg diff --git a/svg/Buildings-Hotel.svg b/flatland/svg/Buildings-Hotel.svg similarity index 100% rename from svg/Buildings-Hotel.svg rename to flatland/svg/Buildings-Hotel.svg diff --git a/svg/Buildings-Office.svg b/flatland/svg/Buildings-Office.svg similarity index 100% rename from svg/Buildings-Office.svg rename to flatland/svg/Buildings-Office.svg diff --git a/svg/Buildings-Polizei.svg b/flatland/svg/Buildings-Polizei.svg similarity index 100% rename from svg/Buildings-Polizei.svg rename to flatland/svg/Buildings-Polizei.svg diff --git a/svg/Buildings-Post.svg b/flatland/svg/Buildings-Post.svg similarity index 100% rename from svg/Buildings-Post.svg rename to flatland/svg/Buildings-Post.svg diff --git a/svg/Buildings-Supermarkt.svg b/flatland/svg/Buildings-Supermarkt.svg similarity index 100% rename from svg/Buildings-Supermarkt.svg rename to flatland/svg/Buildings-Supermarkt.svg diff --git a/svg/Buildings-Tankstelle.svg b/flatland/svg/Buildings-Tankstelle.svg similarity index 100% rename from svg/Buildings-Tankstelle.svg rename to flatland/svg/Buildings-Tankstelle.svg diff --git a/svg/Buildings-Wohnhaus.svg b/flatland/svg/Buildings-Wohnhaus.svg similarity index 100% rename from svg/Buildings-Wohnhaus.svg rename to flatland/svg/Buildings-Wohnhaus.svg diff --git a/svg/Cell_occupied.svg b/flatland/svg/Cell_occupied.svg similarity index 100% rename from svg/Cell_occupied.svg rename to flatland/svg/Cell_occupied.svg diff --git a/svg/Gleis_Deadend.svg b/flatland/svg/Gleis_Deadend.svg similarity index 100% rename from svg/Gleis_Deadend.svg rename to flatland/svg/Gleis_Deadend.svg diff --git a/svg/Gleis_Diamond_Crossing.svg b/flatland/svg/Gleis_Diamond_Crossing.svg similarity index 100% rename from svg/Gleis_Diamond_Crossing.svg rename to flatland/svg/Gleis_Diamond_Crossing.svg diff --git a/svg/Gleis_Kurve_oben_links.svg b/flatland/svg/Gleis_Kurve_oben_links.svg similarity index 100% rename from svg/Gleis_Kurve_oben_links.svg rename to flatland/svg/Gleis_Kurve_oben_links.svg diff --git a/svg/Gleis_Kurve_oben_links_unten_rechts.svg b/flatland/svg/Gleis_Kurve_oben_links_unten_rechts.svg similarity index 100% rename from svg/Gleis_Kurve_oben_links_unten_rechts.svg rename to flatland/svg/Gleis_Kurve_oben_links_unten_rechts.svg diff --git a/svg/Gleis_Kurve_oben_rechts.svg b/flatland/svg/Gleis_Kurve_oben_rechts.svg similarity index 100% rename from svg/Gleis_Kurve_oben_rechts.svg rename to flatland/svg/Gleis_Kurve_oben_rechts.svg diff --git a/svg/Gleis_Kurve_unten_links.svg b/flatland/svg/Gleis_Kurve_unten_links.svg similarity index 100% rename from svg/Gleis_Kurve_unten_links.svg rename to flatland/svg/Gleis_Kurve_unten_links.svg diff --git a/svg/Gleis_Kurve_unten_rechts.svg b/flatland/svg/Gleis_Kurve_unten_rechts.svg similarity index 100% rename from svg/Gleis_Kurve_unten_rechts.svg rename to flatland/svg/Gleis_Kurve_unten_rechts.svg diff --git a/svg/Gleis_horizontal.svg b/flatland/svg/Gleis_horizontal.svg similarity index 100% rename from svg/Gleis_horizontal.svg rename to flatland/svg/Gleis_horizontal.svg diff --git a/svg/Gleis_horizontal_Perron.svg b/flatland/svg/Gleis_horizontal_Perron.svg similarity index 100% rename from svg/Gleis_horizontal_Perron.svg rename to flatland/svg/Gleis_horizontal_Perron.svg diff --git a/svg/Gleis_vertikal.svg b/flatland/svg/Gleis_vertikal.svg similarity index 100% rename from svg/Gleis_vertikal.svg rename to flatland/svg/Gleis_vertikal.svg diff --git a/svg/Gleis_vertikal_Perron.svg b/flatland/svg/Gleis_vertikal_Perron.svg similarity index 100% rename from svg/Gleis_vertikal_Perron.svg rename to flatland/svg/Gleis_vertikal_Perron.svg diff --git a/svg/Scenery-Bergwelt_A_Teil_1_links.svg b/flatland/svg/Scenery-Bergwelt_A_Teil_1_links.svg similarity index 100% rename from svg/Scenery-Bergwelt_A_Teil_1_links.svg rename to flatland/svg/Scenery-Bergwelt_A_Teil_1_links.svg diff --git a/svg/Scenery-Bergwelt_A_Teil_2_mitte.svg b/flatland/svg/Scenery-Bergwelt_A_Teil_2_mitte.svg similarity index 100% rename from svg/Scenery-Bergwelt_A_Teil_2_mitte.svg rename to flatland/svg/Scenery-Bergwelt_A_Teil_2_mitte.svg diff --git a/svg/Scenery-Bergwelt_A_Teil_3_rechts.svg b/flatland/svg/Scenery-Bergwelt_A_Teil_3_rechts.svg similarity index 100% rename from svg/Scenery-Bergwelt_A_Teil_3_rechts.svg rename to flatland/svg/Scenery-Bergwelt_A_Teil_3_rechts.svg diff --git a/svg/Scenery-Bergwelt_B.svg b/flatland/svg/Scenery-Bergwelt_B.svg similarity index 100% rename from svg/Scenery-Bergwelt_B.svg rename to flatland/svg/Scenery-Bergwelt_B.svg diff --git a/svg/Scenery-Bergwelt_C_Teil_1_links.svg b/flatland/svg/Scenery-Bergwelt_C_Teil_1_links.svg similarity index 100% rename from svg/Scenery-Bergwelt_C_Teil_1_links.svg rename to flatland/svg/Scenery-Bergwelt_C_Teil_1_links.svg diff --git a/svg/Scenery-Bergwelt_C_Teil_2_rechts.svg b/flatland/svg/Scenery-Bergwelt_C_Teil_2_rechts.svg similarity index 100% rename from svg/Scenery-Bergwelt_C_Teil_2_rechts.svg rename to flatland/svg/Scenery-Bergwelt_C_Teil_2_rechts.svg diff --git a/svg/Scenery-Laubbaume_A.svg b/flatland/svg/Scenery-Laubbaume_A.svg similarity index 100% rename from svg/Scenery-Laubbaume_A.svg rename to flatland/svg/Scenery-Laubbaume_A.svg diff --git a/svg/Scenery-Laubbaume_B.svg b/flatland/svg/Scenery-Laubbaume_B.svg similarity index 100% rename from svg/Scenery-Laubbaume_B.svg rename to flatland/svg/Scenery-Laubbaume_B.svg diff --git a/svg/Scenery-Laubbaume_C.svg b/flatland/svg/Scenery-Laubbaume_C.svg similarity index 100% rename from svg/Scenery-Laubbaume_C.svg rename to flatland/svg/Scenery-Laubbaume_C.svg diff --git a/svg/Scenery-Nadelbaume_A.svg b/flatland/svg/Scenery-Nadelbaume_A.svg similarity index 100% rename from svg/Scenery-Nadelbaume_A.svg rename to flatland/svg/Scenery-Nadelbaume_A.svg diff --git a/svg/Scenery-Nadelbaume_B.svg b/flatland/svg/Scenery-Nadelbaume_B.svg similarity index 100% rename from svg/Scenery-Nadelbaume_B.svg rename to flatland/svg/Scenery-Nadelbaume_B.svg diff --git a/svg/Scenery_Water.svg b/flatland/svg/Scenery_Water.svg similarity index 100% rename from svg/Scenery_Water.svg rename to flatland/svg/Scenery_Water.svg diff --git a/svg/Selected_Agent.svg b/flatland/svg/Selected_Agent.svg similarity index 100% rename from svg/Selected_Agent.svg rename to flatland/svg/Selected_Agent.svg diff --git a/svg/Selected_Target.svg b/flatland/svg/Selected_Target.svg similarity index 100% rename from svg/Selected_Target.svg rename to flatland/svg/Selected_Target.svg diff --git a/svg/Weiche_Double_Slip.svg b/flatland/svg/Weiche_Double_Slip.svg similarity index 100% rename from svg/Weiche_Double_Slip.svg rename to flatland/svg/Weiche_Double_Slip.svg diff --git a/svg/Weiche_Single_Slip.svg b/flatland/svg/Weiche_Single_Slip.svg similarity index 100% rename from svg/Weiche_Single_Slip.svg rename to flatland/svg/Weiche_Single_Slip.svg diff --git a/svg/Weiche_Symetrical.svg b/flatland/svg/Weiche_Symetrical.svg similarity index 100% rename from svg/Weiche_Symetrical.svg rename to flatland/svg/Weiche_Symetrical.svg diff --git a/svg/Weiche_Symetrical_gerade.svg b/flatland/svg/Weiche_Symetrical_gerade.svg similarity index 100% rename from svg/Weiche_Symetrical_gerade.svg rename to flatland/svg/Weiche_Symetrical_gerade.svg diff --git a/svg/Weiche_horizontal_oben_links.svg b/flatland/svg/Weiche_horizontal_oben_links.svg similarity index 100% rename from svg/Weiche_horizontal_oben_links.svg rename to flatland/svg/Weiche_horizontal_oben_links.svg diff --git a/svg/Weiche_horizontal_oben_rechts.svg b/flatland/svg/Weiche_horizontal_oben_rechts.svg similarity index 100% rename from svg/Weiche_horizontal_oben_rechts.svg rename to flatland/svg/Weiche_horizontal_oben_rechts.svg diff --git a/svg/Weiche_horizontal_unten_links.svg b/flatland/svg/Weiche_horizontal_unten_links.svg similarity index 100% rename from svg/Weiche_horizontal_unten_links.svg rename to flatland/svg/Weiche_horizontal_unten_links.svg diff --git a/svg/Weiche_horizontal_unten_rechts.svg b/flatland/svg/Weiche_horizontal_unten_rechts.svg similarity index 100% rename from svg/Weiche_horizontal_unten_rechts.svg rename to flatland/svg/Weiche_horizontal_unten_rechts.svg diff --git a/svg/Weiche_vertikal_oben_links.svg b/flatland/svg/Weiche_vertikal_oben_links.svg similarity index 100% rename from svg/Weiche_vertikal_oben_links.svg rename to flatland/svg/Weiche_vertikal_oben_links.svg diff --git a/svg/Weiche_vertikal_oben_rechts.svg b/flatland/svg/Weiche_vertikal_oben_rechts.svg similarity index 100% rename from svg/Weiche_vertikal_oben_rechts.svg rename to flatland/svg/Weiche_vertikal_oben_rechts.svg diff --git a/svg/Weiche_vertikal_unten_links.svg b/flatland/svg/Weiche_vertikal_unten_links.svg similarity index 100% rename from svg/Weiche_vertikal_unten_links.svg rename to flatland/svg/Weiche_vertikal_unten_links.svg diff --git a/svg/Weiche_vertikal_unten_rechts.svg b/flatland/svg/Weiche_vertikal_unten_rechts.svg similarity index 100% rename from svg/Weiche_vertikal_unten_rechts.svg rename to flatland/svg/Weiche_vertikal_unten_rechts.svg diff --git a/svg/Zug_1_Weiche_#0091ea.svg b/flatland/svg/Zug_1_Weiche_#0091ea.svg similarity index 100% rename from svg/Zug_1_Weiche_#0091ea.svg rename to flatland/svg/Zug_1_Weiche_#0091ea.svg diff --git a/svg/Zug_1_Weiche_#0091ea_old.svg b/flatland/svg/Zug_1_Weiche_#0091ea_old.svg similarity index 100% rename from svg/Zug_1_Weiche_#0091ea_old.svg rename to flatland/svg/Zug_1_Weiche_#0091ea_old.svg diff --git a/svg/Zug_1_Weiche_#00c853_old.svg b/flatland/svg/Zug_1_Weiche_#00c853_old.svg similarity index 100% rename from svg/Zug_1_Weiche_#00c853_old.svg rename to flatland/svg/Zug_1_Weiche_#00c853_old.svg diff --git a/svg/Zug_1_Weiche_#d50000.svg b/flatland/svg/Zug_1_Weiche_#d50000.svg similarity index 100% rename from svg/Zug_1_Weiche_#d50000.svg rename to flatland/svg/Zug_1_Weiche_#d50000.svg diff --git a/svg/Zug_2_Weiche_#0091ea.svg b/flatland/svg/Zug_2_Weiche_#0091ea.svg similarity index 100% rename from svg/Zug_2_Weiche_#0091ea.svg rename to flatland/svg/Zug_2_Weiche_#0091ea.svg diff --git a/svg/Zug_2_Weiche_#0091ea_old.svg b/flatland/svg/Zug_2_Weiche_#0091ea_old.svg similarity index 100% rename from svg/Zug_2_Weiche_#0091ea_old.svg rename to flatland/svg/Zug_2_Weiche_#0091ea_old.svg diff --git a/svg/Zug_2_Weiche_#00c853_old.svg b/flatland/svg/Zug_2_Weiche_#00c853_old.svg similarity index 100% rename from svg/Zug_2_Weiche_#00c853_old.svg rename to flatland/svg/Zug_2_Weiche_#00c853_old.svg diff --git a/svg/Zug_Gleis_#0091ea.svg b/flatland/svg/Zug_Gleis_#0091ea.svg similarity index 100% rename from svg/Zug_Gleis_#0091ea.svg rename to flatland/svg/Zug_Gleis_#0091ea.svg diff --git a/svg/Zug_Gleis_#0091ea_old.svg b/flatland/svg/Zug_Gleis_#0091ea_old.svg similarity index 100% rename from svg/Zug_Gleis_#0091ea_old.svg rename to flatland/svg/Zug_Gleis_#0091ea_old.svg diff --git a/svg/Zug_Gleis_#00c853_old.svg b/flatland/svg/Zug_Gleis_#00c853_old.svg similarity index 100% rename from svg/Zug_Gleis_#00c853_old.svg rename to flatland/svg/Zug_Gleis_#00c853_old.svg diff --git a/svg/Zug_Gleis_#d50000.svg b/flatland/svg/Zug_Gleis_#d50000.svg similarity index 100% rename from svg/Zug_Gleis_#d50000.svg rename to flatland/svg/Zug_Gleis_#d50000.svg diff --git a/svg/__init__.py b/flatland/svg/__init__.py similarity index 100% rename from svg/__init__.py rename to flatland/svg/__init__.py diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index cb779b6fbb0fec0a0bed3b36d5cb5b7358d18925..1438a3ebea1738cf1feedb5a5ff38bcd308dee0e 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -20,7 +20,7 @@ class EditorMVC(object): """ EditorMVC - a class to encompass and assemble the Jupyter Editor Model-View-Controller. """ - def __init__(self, env=None, sGL="PIL"): + def __init__(self, env=None, sGL="PIL", env_filename="temp.mpk"): """ Create an Editor MVC assembly around a railenv, or create one if None. """ if env is None: @@ -29,7 +29,7 @@ class EditorMVC(object): env.reset() - self.editor = EditorModel(env) + self.editor = EditorModel(env, env_filename=env_filename) self.editor.view = self.view = View(self.editor, sGL=sGL) self.view.controller = self.editor.controller = self.controller = Controller(self.editor, self.view) self.view.init_canvas() @@ -40,9 +40,10 @@ class View(object): """ The Jupyter Editor View - creates and holds the widgets comprising the Editor. """ - def __init__(self, editor, sGL="MPL"): + def __init__(self, editor, sGL="MPL", screen_width=1200, screen_height=1200): self.editor = self.model = editor self.sGL = sGL + self.xyScreen = (screen_width, screen_height) def display(self): self.output_generator.clear_output() @@ -139,7 +140,8 @@ class View(object): def new_env(self): """ Tell the view to update its graphics when a new env is created. """ - self.oRT = rt.RenderTool(self.editor.env, gl=self.sGL) + self.oRT = rt.RenderTool(self.editor.env, gl=self.sGL, show_debug=True, + screen_height=self.xyScreen[1], screen_width=self.xyScreen[0]) def redraw(self): with self.output_generator: @@ -151,10 +153,12 @@ class View(object): if hasattr(a, 'old_direction') is False: a.old_direction = a.direction - self.oRT.render_env(agents=True, + self.oRT.render_env(show_agents=True, + show_inactive_agents=True, show=False, selected_agent=self.model.selected_agent, - show_observations=False) + show_observations=False, + ) img = self.oRT.get_image() self.wImage.data = img @@ -180,7 +184,9 @@ class View(object): nY = np.floor((self.yxSize[1] - self.yxBase[1]) / self.model.env.width) rc_cell[0] = max(0, min(np.floor(rc_cell[0] / nY), self.model.env.height - 1)) rc_cell[1] = max(0, min(np.floor(rc_cell[1] / nX), self.model.env.width - 1)) - return rc_cell + + # Using numpy arrays for coords not currently supported downstream in the env, observations, etc + return tuple(rc_cell) def log(self, *args, **kwargs): if self.output_generator: @@ -282,11 +288,14 @@ class Controller(object): else: self.lrcStroke = [] - if self.model.selected_agent is not None: - self.lrcStroke = [] - while len(q_events) > 0: - t, x, y = q_events.popleft() - return + # JW: I think this clause causes all editing to fail once an agent is selected. + # I also can't see why it's necessary. So I've if-falsed it out. + if False: + if self.model.selected_agent is not None: + self.lrcStroke = [] + while len(q_events) > 0: + t, x, y = q_events.popleft() + return # Process the events in our queue: # Draw a black square to indicate a trail @@ -330,7 +339,8 @@ class Controller(object): if agent is None: continue if agent_idx == self.model.selected_agent: - agent.direction = (agent.direction + 1) % 4 + agent.initial_direction = (agent.initial_direction + 1) % 4 + agent.direction = agent.initial_direction agent.old_direction = agent.direction self.model.redraw() @@ -373,7 +383,7 @@ class Controller(object): class EditorModel(object): - def __init__(self, env): + def __init__(self, env, env_filename="temp.mpk"): self.view = None self.env = env self.regen_size_width = 10 @@ -387,7 +397,7 @@ class EditorModel(object): self.debug_move_bool = False self.wid_output = None self.draw_mode = "Draw" - self.env_filename = "temp.pkl" + self.env_filename = env_filename self.set_env(env) self.selected_agent = None self.thread = None @@ -658,6 +668,7 @@ class EditorModel(object): self.env = env self.env.reset(regenerate_rail=True) self.fix_env() + self.selected_agent = None # clear the selected agent. self.set_env(self.env) self.view.new_env() self.redraw() @@ -670,7 +681,11 @@ class EditorModel(object): def find_agent_at(self, cell_row_col): for agent_idx, agent in enumerate(self.env.agents): - if tuple(agent.position) == tuple(cell_row_col): + if agent.position is None: + rc_pos = agent.initial_position + else: + rc_pos = agent.position + if tuple(rc_pos) == tuple(cell_row_col): return agent_idx return None @@ -685,18 +700,33 @@ class EditorModel(object): # Has the user clicked on an existing agent? agent_idx = self.find_agent_at(cell_row_col) + # This is in case we still have a selected agent even though the env has been recreated + # with no agents. + if (self.selected_agent is not None) and (self.selected_agent > len(self.env.agents)): + self.selected_agent = None + + # Defensive coding below - for cell_row_col to be a tuple, not a numpy array: + # numpy array breaks various things when loading the env. + if agent_idx is None: # No if self.selected_agent is None: # Create a new agent and select it. - agent = EnvAgent(position=cell_row_col, direction=0, target=cell_row_col, moving=False) + agent = EnvAgent(initial_position=tuple(cell_row_col), + initial_direction=0, + direction=0, + target=tuple(cell_row_col), + moving=False, + ) self.selected_agent = self.env.add_agent(agent) + # self.env.set_agent_active(agent) self.view.oRT.update_background() else: # Move the selected agent to this cell agent = self.env.agents[self.selected_agent] - agent.position = cell_row_col - agent.old_position = cell_row_col + agent.initial_position = tuple(cell_row_col) + agent.position = tuple(cell_row_col) + agent.old_position = tuple(cell_row_col) else: # Yes # Have they clicked on the agent already selected? @@ -711,7 +741,7 @@ class EditorModel(object): def add_target(self, rc_cell): if self.selected_agent is not None: - self.env.agents[self.selected_agent].target = rc_cell + self.env.agents[self.selected_agent].target = tuple(rc_cell) self.view.oRT.update_background() self.redraw() diff --git a/flatland/utils/flask_util.py b/flatland/utils/flask_util.py new file mode 100644 index 0000000000000000000000000000000000000000..e30fd72dfa30073f1404217257b20f8ee582c617 --- /dev/null +++ b/flatland/utils/flask_util.py @@ -0,0 +1,270 @@ + + +from flask import Flask, request, redirect, Response +from flask_socketio import SocketIO, emit +from flask_cors import CORS, cross_origin +import threading +import os +import time +import webbrowser +import numpy as np +import typing +import socket + +from flatland.envs.rail_env import RailEnv, RailEnvActions + + +#async_mode = None + + +class simple_flask_server(object): + """ I wanted to wrap the flask server in a class but this seems to be quite hard; + eg see: https://stackoverflow.com/questions/40460846/using-flask-inside-class + I have made a messy sort of singleton pattern. + It might be easier to revert to the "standard" flask global functions + decorators. + """ + + static_folder = os.path.join(os.getcwd(), "static") + print("Flask static folder: ", static_folder) + app = Flask(__name__, + static_url_path='', + static_folder=static_folder) + + socketio = SocketIO(app, cors_allowed_origins='*') + + # This is the original format for the I/O. + # It comes from the format used in the msgpack saved episode. + # The lists here are truncated from the original - see CK's original main.py, in flatland-render. + gridmap = [ + # list of rows (?). Each cell is a 16-char binary string. Yes this is inefficient! + ["0000000000000000", "0010000000000000", "0000000000000000", "0000000000000000", "0010000000000000", "0000000000000000", "0000000000000000", "0000000000000000", "0010000000000000", "0000000000000000"], + ["0000000000000000", "1000000000100000", "0000000000000000", "0000000000000000", "0000000001001000", "0001001000000000", "0010000000000000", "0000000000000000", "1000000000100000", "0000000000000000"], # ... + ] + agents_static = [ + # [initial position], initial direction, [target], 0 (?) + [[7, 9], 2, [3, 5], 0, + # Speed and malfunction params + {"position_fraction": 0, "speed": 1, "transition_action_on_cellexit": 3}, + {"malfunction": 0, "malfunction_rate": 0, "next_malfunction": 0, "nr_malfunctions": 0}], + [[8, 8], 1, [1, 6], 0, + {"position_fraction": 0, "speed": 1, "transition_action_on_cellexit": 2}, + {"malfunction": 0, "malfunction_rate": 0, "next_malfunction": 0, "nr_malfunctions": 0}], + [[3, 7], 2, [0, 1], 0, + {"position_fraction": 0, "speed": 1, "transition_action_on_cellexit": 2}, + {"malfunction": 0, "malfunction_rate": 0, "next_malfunction": 0, "nr_malfunctions": 0}] + ] + + # "actions" are not really actions, but [row, col, direction] for each agent, at each time step + # This format does not yet handle agents which are in states inactive or done_removed + actions= [ + [[7, 9, 2], [8, 8, 1], [3, 7, 2]], [[7, 9, 2], [8, 7, 3], [2, 7, 0]], # ... + ] + + def __init__(self, env): + # Some ugly stuff with cls and self here + cls = self.__class__ + cls.instance = self # intended as singleton + + cls.app.config['CORS_HEADERS'] = 'Content-Type' + cls.app.config['SECRET_KEY'] = 'secret!' + + self.app = cls.app + self.socketio = cls.socketio + self.env = env + self.renderer_ready = False # to indicate env background not yet drawn + self.port = None # we only assign a port once we start the background server... + self.host = None + + def run_flask_server(self, host='127.0.0.1', port=None): + self.host = host + + if port is None: + self.port = self._find_available_port(host) + else: + self.port = port + + self.socketio.run(simple_flask_server.app, host=host, port=self.port) + + def run_flask_server_in_thread(self, host="127.0.0.1", port=None): + # daemon=True so that this thread exits when the main / foreground thread exits, + # usually when the episode finishes. + self.thread = threading.Thread( + target=self.run_flask_server, + kwargs={"host": host, "port": port}, + daemon=True) + self.thread.start() + # short sleep to allow thread to start (may be unnnecessary) + time.sleep(1) + + def open_browser(self): + webbrowser.open("http://localhost:{}".format(self.port)) + # short sleep to allow browser to request the page etc (may be unnecessary) + time.sleep(1) + + def _test_listen_port(self, host: str, port: int): + oSock = socket.socket() + try: + oSock.bind((host, port)) + except OSError: + return False # The port is not available + + del oSock # This should release the port + return True # The port is available + + def _find_available_port(self, host: str, port_start: int = 8080): + for nPort in range(port_start, port_start+100): + if self._test_listen_port(host, nPort): + return nPort + print("Could not find an available port for Flask to listen on!") + return None + + def get_endpoint_url(self): + return "http://{}:{}".format(self.host, self.port) + + @app.route('/', methods=['GET']) + def home(): + # redirects from "/" to "/index.html" which is then served from static. + # print("Here - / - cwd:", os.getcwd()) + return redirect("index.html") + + @staticmethod + @socketio.on('connect') + def connected(): + ''' + When the JS Renderer connects, + this method will send the env and agent information + ''' + cls = simple_flask_server + print('Client connected') + + # Do we really need this? + cls.socketio.emit('message', {'message': 'Connected'}) + + print('Send Env grid and agents') + # cls.socketio.emit('grid', {'grid': cls.gridmap, 'agents_static': cls.agents_static}, broadcast=False) + cls.instance.send_env() + print("Env and agents sent") + + @staticmethod + @socketio.on('disconnect') + def disconnected(): + print('Client disconnected') + + def send_actions(self, dict_actions): + ''' Sends the agent positions and directions, not really actions. + ''' + llAgents = self.agents_to_list() + self.socketio.emit('agentsAction', {'actions': llAgents}) + + def send_observation(self, agent_handles, dict_obs): + """ Send an observation. + TODO: format observation message. + """ + self.socketio.emit("observation", {"agents": agent_handles, "observations": dict_obs}) + + def send_env(self): + """ Sends the env, ie the rail grid, and the agents (static) information + """ + # convert 2d array of int into 2d array of 16char strings + g2sGrid = np.vectorize(np.binary_repr)(self.env.rail.grid, width=16) + llGrid = g2sGrid.tolist() + llAgents = self.agents_to_list_dict() + self.socketio.emit('grid', { + 'grid': llGrid, + 'agents_static': llAgents + }, + broadcast=False) + + def send_env_and_wait(self): + for iAttempt in range(30): + if self.is_renderer_ready(): + print("Background Render complete") + break + else: + print("Waiting for browser to signal that rendering complete") + time.sleep(1) + + @staticmethod + @socketio.on('renderEvent') + def handle_render_event(data): + cls=simple_flask_server + self = cls.instance + print('RenderEvent!!') + print('status: ' + data['status']) + print('message: ' + data['message']) + + if data['status'] == 'listening': + self.renderer_ready = True + + def is_renderer_ready(self): + return self.renderer_ready + + def agents_to_list_dict(self): + ''' Create a list of lists / dicts for serialisation + Maps from the internal representation in EnvAgent to + the schema used by the Javascript renderer. + ''' + llAgents = [] + for agent in self.env.agents: + if agent.position is None: + # the int()s are to convert from numpy int64 which causes problems in serialization + # to plain old python int + lPos = [int(agent.initial_position[0]), int(agent.initial_position[1])] + else: + lPos = [int(agent.position[0]), int(agent.position[1])] + + lAgent = [ + lPos, + int(agent.direction), + [int(agent.target[0]), int(agent.target[1])], 0, + { # dummy values: + "position_fraction": 0, + "speed": 1, + "transition_action_on_cellexit": 3 + }, + { + "malfunction": 0, + "malfunction_rate": 0, + "next_malfunction": 0, + "nr_malfunctions": 0 + } + ] + llAgents.append(lAgent) + return llAgents + + def agents_to_list(self): + llAgents = [] + for agent in self.env.agents: + if agent.position is None: + lPos = [int(agent.initial_position[0]), int(agent.initial_position[1])] + else: + lPos = [int(agent.position[0]), int(agent.position[1])] + iDir = int(agent.direction) + + lAgent = [*lPos, iDir] + + llAgents.append(lAgent) + return llAgents + + + +def main1(): + + print('Run Flask SocketIO Server') + server = simple_flask_server() + threading.Thread(target=server.run_flask_server).start() + # Open Browser + webbrowser.open('http://127.0.0.1:8080') + + print('Send Action') + for i in server.actions: + time.sleep(1) + print('send action') + server.socketio.emit('agentsAction', {'actions': i}) + + + + + +if __name__ == "__main__": + main1() \ No newline at end of file diff --git a/flatland/utils/graphics_layer.py b/flatland/utils/graphics_layer.py index 2f51df2a1486adb6159a809b6f49e18bc8c873d9..e38a3694417a4cbaffce3a8ee26ac3fbd47a521d 100644 --- a/flatland/utils/graphics_layer.py +++ b/flatland/utils/graphics_layer.py @@ -25,6 +25,13 @@ class GraphicsLayer(object): pass def pause(self, seconds=0.00001): + """ deprecated """ + pass + + def idle(self, seconds=0.00001): + """ process any display events eg redraw, resize. + Return only after the given number of seconds, ie idle / loop until that number. + """ pass def clf(self): diff --git a/flatland/utils/graphics_pgl.py b/flatland/utils/graphics_pgl.py new file mode 100644 index 0000000000000000000000000000000000000000..299459cc5106566f78e94e0b961bfe13d3ee6a7c --- /dev/null +++ b/flatland/utils/graphics_pgl.py @@ -0,0 +1,152 @@ + +import pyglet as pgl +import time + +from PIL import Image +# from numpy import array +# from pkg_resources import resource_string as resource_bytes + +# from flatland.utils.graphics_layer import GraphicsLayer +from flatland.utils.graphics_pil import PILSVG + + +class PGLGL(PILSVG): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.window_open = False # means the window has not yet been opened. + self.close_requested = False # user has clicked + self.closed = False # windows has been closed (currently, we leave the env still running) + + def open_window(self): + print("open_window - pyglet") + assert self.window_open is False, "Window is already open!" + self.window = pgl.window.Window(resizable=True, vsync=False) + #self.__class__.window.title("Flatland") + #self.__class__.window.configure(background='grey') + self.window_open = True + + + @self.window.event + def on_draw(): + #print("pyglet draw event") + self.window.clear() + self.show(from_event=True) + #print("pyglet draw event done") + + + @self.window.event + def on_resize(width, height): + #print(f"The window was resized to {width}, {height}") + self.show(from_event=True) + self.window.dispatch_event("on_draw") + #print("pyglet resize event done") + + @self.window.event + def on_close(): + self.close_requested = True + + + def close_window(self): + self.window.close() + self.closed=True + + def show(self, block=False, from_event=False): + if not self.window_open: + self.open_window() + + if self.close_requested: + if not self.closed: + self.close_window() + return + + #tStart = time.time() + self._processEvents() + + pil_img = self.alpha_composite_layers() + pil_img_resized = pil_img.resize((self.window.width, self.window.height), resample=Image.NEAREST) + + # convert our PIL image to pyglet: + bytes_image = pil_img_resized.tobytes() + pgl_image = pgl.image.ImageData(pil_img_resized.width, pil_img_resized.height, + #self.window.width, self.window.height, + 'RGBA', + bytes_image, pitch=-pil_img_resized.width * 4) + + pgl_image.blit(0,0) + #tEnd = time.time() + #print("show time: ", tEnd - tStart) + + def _processEvents(self): + """ This is the replacement for a custom event loop for Pyglet. + The lines below are typical of Pyglet examples. + Manually resizing the window is still very clunky. + """ + #print("process events...", end="") + pgl.clock.tick() + #for window in pgl.app.windows: + if not self.closed: + self.window.switch_to() + self.window.dispatch_events() + self.window.flip() + #print(" events done") + + + + def idle(self, seconds=0.00001): + tStart = time.time() + tEnd = tStart + seconds + while (time.time() < tEnd): + self._processEvents() + #self.show() + time.sleep(min(seconds, 0.1)) + + +def test_pyglet(): + oGL = PGLGL(400,300) + time.sleep(2) + + +def test_event_loop(): + """ Shows how it should work with the standard event loop + Resizing is fairly smooth (ie runs at least 10-20x a second) + """ + + + window = pgl.window.Window(resizable=True) + pil_img = Image.open("notebooks/simple_example_3.png") + + def show(): + pil_img_resized = pil_img.resize((window.width, window.height), resample=Image.NEAREST) + bytes_image = pil_img_resized.tobytes() + pgl_image = pgl.image.ImageData(pil_img_resized.width, pil_img_resized.height, + #self.window.width, self.window.height, + 'RGBA', + bytes_image, pitch=-pil_img_resized.width * 4) + pgl_image.blit(0,0) + + @window.event + def on_draw(): + print("pyglet draw event") + window.clear() + show() + print("pyglet draw event done") + + + @window.event + def on_resize(width, height): + print(f"The window was resized to {width}, {height}") + #show() + print("pyglet resize event done") + + @window.event + def on_close(): + #self.close_requested = True + print("close") + + pgl.app.run() + + +if __name__=="__main__": + #test_pyglet() + test_event_loop() \ No newline at end of file diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 0e464f03d5060b8d8db4462ed5d6c640a967b2e3..8ab3b662e0146efcf5a18e0e566ba21a62abbab8 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -1,10 +1,10 @@ import io import os import time -import tkinter as tk +#import tkinter as tk import numpy as np -from PIL import Image, ImageDraw, ImageTk, ImageFont +from PIL import Image, ImageDraw, ImageFont from numpy import array from pkg_resources import resource_string as resource_bytes @@ -32,7 +32,7 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions # noqa: E402 class PILGL(GraphicsLayer): # tk.Tk() must be a singleton! # https://stackoverflow.com/questions/26097811/image-pyimage2-doesnt-exist - window = tk.Tk() + # window = tk.Tk() RAIL_LAYER = 0 PREDICTION_PATH_LAYER = 1 @@ -85,7 +85,7 @@ class PILGL(GraphicsLayer): self.agent_colors = [self.rgb_s2i(sColor) for sColor in sColors.split("#")] self.n_agent_colors = len(self.agent_colors) - self.window_open = False + # self.window_open = False self.firstFrame = True self.old_background_image = (None, None, None) self.create_layers() @@ -164,15 +164,10 @@ class PILGL(GraphicsLayer): self.draw_image_xy(pil_img, xyPixLeftTop, layer=layer) def open_window(self): - assert self.window_open is False, "Window is already open!" - self.__class__.window.title("Flatland") - self.__class__.window.configure(background='grey') - self.window_open = True + pass def close_window(self): - self.panel.destroy() - # quit but not destroy! - self.__class__.window.quit() + pass def text(self, xPx, yPx, strText, layer=RAIL_LAYER): xyPixLeftTop = (xPx, yPx) @@ -194,28 +189,15 @@ class PILGL(GraphicsLayer): self.create_layer(iLayer=PILGL.PREDICTION_PATH_LAYER, clear=True) def show(self, block=False): - img = self.alpha_composite_layers() - - if not self.window_open: - self.open_window() - - tkimg = ImageTk.PhotoImage(img) - - if self.firstFrame: - # Do TK actions for a new panel (not sure what they really do) - self.panel = tk.Label(self.window, image=tkimg) - self.panel.pack(side="bottom", fill="both", expand="yes") - else: - # update the image in situ - self.panel.configure(image=tkimg) - self.panel.image = tkimg - - self.__class__.window.update() - self.firstFrame = False + #print("show() - ", self.__class__) + pass def pause(self, seconds=0.00001): pass + def idle(self, seconds=0.00001): + pass + def alpha_composite_layers(self): img = self.layers[0] for img2 in self.layers[1:]: @@ -316,7 +298,7 @@ class PILSVG(PILGL): return pil_img def load_buildings(self): - dBuildingFiles = [ + lBuildingFiles = [ "Buildings-Bank.svg", "Buildings-Bar.svg", "Buildings-Wohnhaus.svg", @@ -338,13 +320,17 @@ class PILSVG(PILGL): "Buildings-Fabrik_I.svg" ] - imgBg = self.pil_from_svg_file('svg', "Background_city.svg") + imgBg = self.pil_from_svg_file('flatland.svg', "Background_city.svg") + imgBg = imgBg.convert("RGBA") + #print("imgBg mode:", imgBg.mode) - self.dBuildings = [] - for sFile in dBuildingFiles: - img = self.pil_from_svg_file('svg', sFile) + self.lBuildings = [] + for sFile in lBuildingFiles: + #print("Loading:", sFile) + img = self.pil_from_svg_file('flatland.svg', sFile) + #print("img mode:", img.mode) img = Image.alpha_composite(imgBg, img) - self.dBuildings.append(img) + self.lBuildings.append(img) def load_scenery(self): scenery_files = [ @@ -371,31 +357,31 @@ class PILSVG(PILGL): "Scenery_Water.svg" ] - img_back_ground = self.pil_from_svg_file('svg', "Background_Light_green.svg") + img_back_ground = self.pil_from_svg_file('flatland.svg', "Background_Light_green.svg").convert("RGBA") - self.scenery_background_white = self.pil_from_svg_file('svg', "Background_white.svg") + self.scenery_background_white = self.pil_from_svg_file('flatland.svg', "Background_white.svg").convert("RGBA") self.scenery = [] for file in scenery_files: - img = self.pil_from_svg_file('svg', file) + img = self.pil_from_svg_file('flatland.svg', file) img = Image.alpha_composite(img_back_ground, img) self.scenery.append(img) self.scenery_d2 = [] for file in scenery_files_d2: - img = self.pil_from_svg_file('svg', file) + img = self.pil_from_svg_file('flatland.svg', file) img = Image.alpha_composite(img_back_ground, img) self.scenery_d2.append(img) self.scenery_d3 = [] for file in scenery_files_d3: - img = self.pil_from_svg_file('svg', file) + img = self.pil_from_svg_file('flatland.svg', file) img = Image.alpha_composite(img_back_ground, img) self.scenery_d3.append(img) self.scenery_water = [] for file in scenery_files_water: - img = self.pil_from_svg_file('svg', file) + img = self.pil_from_svg_file('flatland.svg', file) img = Image.alpha_composite(img_back_ground, img) self.scenery_water.append(img) @@ -448,10 +434,10 @@ class PILSVG(PILGL): whitefilter="Background_white_filter.svg") # Load station and recolorize them - station = self.pil_from_svg_file("svg", "Bahnhof_#d50000_target.svg") + station = self.pil_from_svg_file('flatland.svg', "Bahnhof_#d50000_target.svg") self.station_colors = self.recolor_image(station, [0, 0, 0], self.agent_colors, False) - cell_occupied = self.pil_from_svg_file("svg", "Cell_occupied.svg") + cell_occupied = self.pil_from_svg_file('flatland.svg', "Cell_occupied.svg") self.cell_occupied = self.recolor_image(cell_occupied, [0, 0, 0], self.agent_colors, False) # Merge them with the regular rails. @@ -480,14 +466,14 @@ class PILSVG(PILGL): transition_16_bit_string = "".join(transition_16_bit) binary_trans = int(transition_16_bit_string, 2) - pil_rail = self.pil_from_svg_file('svg', file) + pil_rail = self.pil_from_svg_file('flatland.svg', file).convert("RGBA") if background_image is not None: - img_bg = self.pil_from_svg_file('svg', background_image) + img_bg = self.pil_from_svg_file('flatland.svg', background_image).convert("RGBA") pil_rail = Image.alpha_composite(img_bg, pil_rail) if whitefilter is not None: - img_bg = self.pil_from_svg_file('svg', whitefilter) + img_bg = self.pil_from_svg_file('flatland.svg', whitefilter).convert("RGBA") pil_rail = Image.alpha_composite(pil_rail, img_bg) if rotate: @@ -535,13 +521,13 @@ class PILSVG(PILGL): if binary_trans == 0: if self.background_grid[col][row] <= 4 + np.ceil(((col * row + col) % 10) / city_size): a = int(self.background_grid[col][row]) - a = a % len(self.dBuildings) + a = a % len(self.lBuildings) if (col + row + col * row) % 13 > 11: pil_track = self.scenery[a % len(self.scenery)] else: if (col + row + col * row) % 3 == 0: - a = (a + (col + row + col * row)) % len(self.dBuildings) - pil_track = self.dBuildings[a] + a = (a + (col + row + col * row)) % len(self.lBuildings) + pil_track = self.lBuildings[a] elif ((self.background_grid[col][row] > 5 + ((col * row + col) % 3)) or ((col ** 3 + row ** 2 + col * row) % 10 == 0)): a = int(self.background_grid[col][row]) - 4 @@ -579,7 +565,7 @@ class PILSVG(PILGL): if target is not None: if is_selected: - svgBG = self.pil_from_svg_file("svg", "Selected_Target.svg") + svgBG = self.pil_from_svg_file('flatland.svg', "Selected_Target.svg") self.clear_layer(PILGL.SELECTED_TARGET_LAYER, 0) self.draw_image_row_col(svgBG, (row, col), layer=PILGL.SELECTED_TARGET_LAYER) @@ -619,7 +605,7 @@ class PILSVG(PILGL): for directions, path_svg in file_directory.items(): in_direction, out_direction = directions - pil_zug = self.pil_from_svg_file("svg", path_svg) + pil_zug = self.pil_from_svg_file('flatland.svg', path_svg) # Rotate both the directions and the image and save in the dict for rot_direction in range(4): @@ -649,7 +635,7 @@ class PILSVG(PILGL): self.draw_image_row_col(self.scenery_background_white, (row, col), layer=PILGL.RAIL_LAYER) if is_selected: - bg_svg = self.pil_from_svg_file("svg", "Selected_Agent.svg") + bg_svg = self.pil_from_svg_file('flatland.svg', "Selected_Agent.svg") self.clear_layer(PILGL.SELECTED_AGENT_LAYER, 0) self.draw_image_row_col(bg_svg, (row, col), layer=PILGL.SELECTED_AGENT_LAYER) if show_debug: diff --git a/flatland/utils/graphics_tkpil.py b/flatland/utils/graphics_tkpil.py new file mode 100644 index 0000000000000000000000000000000000000000..7e89e734a170e746c07a4779f4f12e2ff88cf42e --- /dev/null +++ b/flatland/utils/graphics_tkpil.py @@ -0,0 +1,52 @@ + +import tkinter as tk + +from PIL import ImageTk +# from numpy import array +# from pkg_resources import resource_string as resource_bytes + +# from flatland.utils.graphics_layer import GraphicsLayer +from flatland.utils.graphics_pil import PILSVG + + +class TKPILGL(PILSVG): + # tk.Tk() must be a singleton! + # https://stackoverflow.com/questions/26097811/image-pyimage2-doesnt-exist + window = tk.Tk() + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.window_open = False + + def open_window(self): + print("open_window - tk") + assert self.window_open is False, "Window is already open!" + self.__class__.window.title("Flatland") + self.__class__.window.configure(background='grey') + self.window_open = True + + def close_window(self): + self.panel.destroy() + # quit but not destroy! + self.__class__.window.quit() + + def show(self, block=False): + # print("show - ", self.__class__) + img = self.alpha_composite_layers() + + if not self.window_open: + self.open_window() + + tkimg = ImageTk.PhotoImage(img) + + if self.firstFrame: + # Do TK actions for a new panel (not sure what they really do) + self.panel = tk.Label(self.window, image=tkimg) + self.panel.pack(side="bottom", fill="both", expand="yes") + else: + # update the image in situ + self.panel.configure(image=tkimg) + self.panel.image = tkimg + + self.__class__.window.update() + self.firstFrame = False diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 1121a7106176aa932df587d90ecec219e8887dd8..0e980219855f5859940507be7f160b7fdfc05990 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -8,7 +8,7 @@ from numpy import array from recordtype import recordtype from flatland.utils.graphics_pil import PILGL, PILSVG - +from flatland.utils.flask_util import simple_flask_server # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! @@ -21,6 +21,160 @@ class AgentRenderVariant(IntEnum): class RenderTool(object): + """ RenderTool is a facade to a renderer, either local or browser + """ + def __init__(self, env, gl="BROWSER", jupyter=False, + agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, + show_debug=False, clear_debug_text=True, screen_width=800, screen_height=600, + host="localhost", port=None): + + self.env = env + self.frame_nr = 0 + self.start_time = time.time() + self.times_list = deque() + + self.agent_render_variant = agent_render_variant + + if gl in ["PIL", "PILSVG", "TKPIL", "TKPILSVG", "PGL"]: + self.renderer = RenderLocal(env, gl, jupyter, + agent_render_variant, + show_debug, clear_debug_text, screen_width, screen_height) + + # To support legacy access to the GraphicsLayer (gl) + # DEPRECATED - TODO: remove these calls! + self.gl = self.renderer.gl + + elif gl == "BROWSER": + self.renderer = RenderBrowser(env, host=host, port=port) + else: + print("[", gl, "] not found, switch to PILSVG or BROWSER") + + def render_env(self, + show=False, # whether to call matplotlib show() or equivalent after completion + show_agents=True, # whether to include agents + show_inactive_agents=False, # whether to show agents before they start + show_observations=True, # whether to include observations + show_predictions=False, # whether to include predictions + frames=False, # frame counter to show (intended since invocation) + episode=None, # int episode number to show + step=None, # int step number to show in image + selected_agent=None, # indicate which agent is "selected" in the editor): + return_image=False): # indicate if image is returned for use in monitor: + return self.renderer.render_env(show, show_agents, show_inactive_agents, show_observations, + show_predictions, frames, episode, step, selected_agent, return_image) + + def close_window(self): + self.renderer.close_window() + + def reset(self): + self.renderer.reset() + + def set_new_rail(self): + self.renderer.set_new_rail() + self.renderer.env = self.env # bit of a hack - copy our env to the delegate + + def update_background(self): + self.renderer.update_background() + + def get_endpoint_URL(self): + """ Returns a string URL for the root of the HTTP server + TODO: Need to update this work work on a remote server! May be tricky... + """ + #return "http://localhost:{}".format(self.renderer.get_port()) + if hasattr(self.renderer, "get_endpoint_url"): + return self.renderer.get_endpoint_url() + else: + print("Attempt to get_endpoint_url from RenderTool - only supported with BROWSER") + return None + + def get_image(self): + """ + """ + if hasattr(self.renderer, "gl"): + return self.renderer.gl.get_image() + else: + print("Attempt to retrieve image from RenderTool - not supported with BROWSER") + return None + + + +class RenderBase(object): + def __init__(self, env): + pass + + def render_env(self): + pass + + def close_window(self): + pass + + def reset(self): + pass + + def set_new_rail(self): + """ Signal to the renderer that the env has changed and will need re-rendering. + """ + pass + + def update_background(self): + """ A lesser version of set_new_rail? + TODO: can update_background be pruned for simplicity? + """ + pass + + +class RenderBrowser(RenderBase): + def __init__(self, env, host="localhost", port=None): + self.server = simple_flask_server(env) + self.server.run_flask_server_in_thread(host=host, port=port) + self.env = env + self.background_rendered = False + + def render_env(self, + show=False, # whether to call matplotlib show() or equivalent after completion + show_agents=True, # whether to include agents + show_inactive_agents=False, + show_observations=True, # whether to include observations + show_predictions=False, # whether to include predictions + frames=False, # frame counter to show (intended since invocation) + episode=None, # int episode number to show + step=None, # int step number to show in image + selected_agent=None, # indicate which agent is "selected" in the editor): + return_image=False): # indicate if image is returned for use in monitor: + + if not self.background_rendered: + self.server.send_env_and_wait() + self.background_rendered = True + + self.server.send_actions({}) + + if show_observations: + self.render_observation(range(self.env.get_num_agents()), self.env.dev_obs_dict) + + def render_observation(self, agent_handles, dict_observation): + # Change keys to strings, and OrderedSet to list (of tuples) + dict_obs2 = {str(item[0]): list(item[1]) for item in self.env.dev_obs_dict.items()} + # Convert any ranges into a list + list_handles = list(agent_handles) + self.server.send_observation(list_handles, dict_obs2) + + def get_port(self): + return self.server.port + + def get_endpoint_url(self): + return self.server.get_endpoint_url() + + def close_window(self): + pass + + def reset(self): + pass + + def set_new_rail(self): + pass + + +class RenderLocal(RenderBase): """ Class to render the RailEnv and agents. Uses two layers, layer 0 for rails (mostly static), layer 1 for agents etc (dynamic) The lower / rail layer 0 is only redrawn after set_new_rail() has been called. @@ -50,12 +204,24 @@ class RenderTool(object): self.agent_render_variant = agent_render_variant + self.gl_str = gl + if gl == "PIL": self.gl = PILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) elif gl == "PILSVG": self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) + elif gl in ["TKPILSVG", "TKPIL"]: + # Conditional import to avoid importing tkinter unless required. + print("Importing TKPILGL - requires a local display!") + from flatland.utils.graphics_tkpil import TKPILGL + self.gl = TKPILGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) + elif gl in ["PGL"]: + # Conditional import + from flatland.utils.graphics_pgl import PGLGL + self.gl = PGLGL(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) else: - print("[", gl, "] not found, switch to PILSVG") + print("[", gl, "] not found, switch to PGL, PILSVG, TKPIL (deprecated) or BROWSER") + print("Using PILSVG.") self.gl = PILSVG(env.width, env.height, jupyter, screen_width=screen_width, screen_height=screen_height) self.new_rail = True @@ -80,6 +246,7 @@ class RenderTool(object): for agent_idx, agent in enumerate(self.env.agents): if agent is None: continue + #print(f"updatebg: {agent_idx} {agent.target}") targets[tuple(agent.target)] = agent_idx self.gl.build_background_map(targets) @@ -401,31 +568,39 @@ class RenderTool(object): def render_env(self, show=False, # whether to call matplotlib show() or equivalent after completion show_agents=True, # whether to include agents + show_inactive_agents=False, show_observations=True, # whether to include observations show_predictions=False, # whether to include predictions frames=False, # frame counter to show (intended since invocation) episode=None, # int episode number to show step=None, # int step number to show in image - selected_agent=None): # indicate which agent is "selected" in the editor + selected_agent=None, # indicate which agent is "selected" in the editor + return_image=False): # indicate if image is returned for use in monitor: """ Draw the environment using the GraphicsLayer this RenderTool was created with. (Use show=False from a Jupyter notebook with %matplotlib inline) """ - if type(self.gl) is PILSVG: - self.render_env_svg(show=show, + + # if type(self.gl) is PILSVG: + if self.gl_str in ["PILSVG", "TKPIL", "TKPILSVG", "PGL"]: + return self.render_env_svg(show=show, show_observations=show_observations, show_predictions=show_predictions, selected_agent=selected_agent, - show_agents=show_agents + show_agents=show_agents, + show_inactive_agents=show_inactive_agents, + return_image=return_image ) else: - self.render_env_pil(show=show, + return self.render_env_pil(show=show, show_agents=show_agents, + show_inactive_agents=show_inactive_agents, show_observations=show_observations, show_predictions=show_predictions, frames=frames, episode=episode, step=step, - selected_agent=selected_agent + selected_agent=selected_agent, + return_image=return_image ) def _draw_square(self, center, size, color, opacity=255, layer=0): @@ -442,12 +617,14 @@ class RenderTool(object): show=False, # whether to call matplotlib show() or equivalent after completion # use false when calling from Jupyter. (and matplotlib no longer supported!) show_agents=True, # whether to include agents + show_inactive_agents=False, show_observations=True, # whether to include observations show_predictions=False, # whether to include predictions frames=False, # frame counter to show (intended since invocation) episode=None, # int episode number to show step=None, # int step number to show in image - selected_agent=None # indicate which agent is "selected" in the editor + selected_agent=None, # indicate which agent is "selected" in the editor + return_image=False # indicate if image is returned for use in monitor: ): if type(self.gl) is PILGL: @@ -495,11 +672,13 @@ class RenderTool(object): self.gl.pause(0.00001) + if return_image: + return self.get_image() return def render_env_svg( self, show=False, show_observations=True, show_predictions=False, selected_agent=None, - show_agents=True + show_agents=True, show_inactive_agents=False, return_image=False ): """ Renders the environment with SVG support (nice image) @@ -539,26 +718,55 @@ class RenderTool(object): self.gl.build_background_map(targets) + # label rows, cols + for iRow in range(env.height): + self.gl.text_rowcol((iRow, 0), str(iRow), layer=self.gl.RAIL_LAYER) + for iCol in range(env.width): + self.gl.text_rowcol((0, iCol), str(iCol), layer=self.gl.RAIL_LAYER) + + if show_agents: for agent_idx, agent in enumerate(self.env.agents): - if agent is None or agent.position is None: + if agent is None: + continue + + # Show an agent even if it hasn't already started + if show_inactive_agents and (agent.position is None): + # print("agent ", agent_idx, agent.position, agent.old_position, agent.initial_position) + self.gl.set_agent_at(agent_idx, *(agent.initial_position), + agent.initial_direction, agent.initial_direction, + is_selected=(selected_agent == agent_idx), + rail_grid=env.rail.grid, + show_debug=self.show_debug, clear_debug_text=self.clear_debug_text, + malfunction=False) continue is_malfunction = agent.malfunction_data["malfunction"] > 0 if self.agent_render_variant == AgentRenderVariant.BOX_ONLY: self.gl.set_cell_occupied(agent_idx, *(agent.position)) + elif self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND or \ self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: # noqa: E125 + + # Most common case - the agent has been running for >1 steps if agent.old_position is not None: position = agent.old_position direction = agent.direction old_direction = agent.old_direction - else: + + # the agent's first step - it doesn't have an old position yet + elif agent.position is not None: position = agent.position direction = agent.direction old_direction = agent.direction + + # When the editor has just added an agent + elif agent.initial_position is not None: + position = agent.initial_position + direction = agent.initial_direction + old_direction = agent.initial_direction # set_agent_at uses the agent index for the color if self.agent_render_variant == AgentRenderVariant.ONE_STEP_BEHIND_AND_BOX: @@ -592,12 +800,17 @@ class RenderTool(object): self.render_observation(range(env.get_num_agents()), env.dev_obs_dict) if show_predictions: self.render_prediction(range(env.get_num_agents()), env.dev_pred_dict) + + + if show: self.gl.show() for i in range(3): self.gl.process_events() self.frame_nr += 1 + if return_image: + return self.get_image() return def close_window(self): diff --git a/requirements_dev.txt b/requirements_dev.txt index b71ad94971c59c7a2ef18f038f9c498330a92b1a..4abb5acec55aab6af94cc7d553fc4d4afce0913a 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -6,7 +6,6 @@ Click>=7.0 crowdai-api>=0.1.21 numpy>=1.16.2 recordtype>=1.3 -xarray>=0.11.3 matplotlib>=3.0.2 Pillow>=5.4.1 CairoSVG>=2.3.1 @@ -21,3 +20,6 @@ timeout-decorator>=0.4.1 attrs ushlex gym==0.14.0 +flask +flask_cors +flask_socketio diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 351a326a19c078e7f525c4b96ad8662777f7dd2d..07cf778925864e2f1f871bf4621d9c4a65bca220 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -12,14 +12,17 @@ from flatland.envs.rail_generators import complex_rail_generator, rail_from_file from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, schedule_from_file from flatland.utils.simple_rail import make_simple_rail +from flatland.envs.persistence import RailEnvPersister """Tests for `flatland` package.""" def test_load_env(): - env = RailEnv(10, 10) - env.reset() - env.load_resource('env_data.tests', 'test-10x10.mpk') + #env = RailEnv(10, 10) + #env.reset() + # env.load_resource('env_data.tests', 'test-10x10.mpk') + #env, env_dict = RailEnvPersister.load_resource("env_data.tests", "test-10x10.mpk") + env, env_dict = RailEnvPersister.load_new("./env_data/tests/test-10x10.mpk") agent_static = EnvAgent((0, 0), 2, (5, 5), False) env.add_agent(agent_static) @@ -37,8 +40,13 @@ def test_save_load(): agent_2_pos = env.agents[1].position agent_2_dir = env.agents[1].direction agent_2_tar = env.agents[1].target - env.save("test_save.dat") - env.load("test_save.dat") + + env.save("test_save_2.pkl") + RailEnvPersister.save(env, "test_save.pkl") + + + #env.load("test_save.dat") + env, env_dict = RailEnvPersister.load_new("test_save.pkl") assert (env.width == 10) assert (env.height == 10) assert (len(env.agents) == 2) @@ -228,15 +236,20 @@ def test_rail_env_reset(): schedule_generator=random_schedule_generator(), number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() - env.save(file_name) + + #env.save(file_name) + RailEnvPersister.save(env, file_name) + dist_map_shape = np.shape(env.distance_map.get()) rails_initial = env.rail.grid agents_initial = env.agents - env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), - schedule_generator=schedule_from_file(file_name), number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) - env2.reset(False, False, False) + #env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), + # schedule_generator=schedule_from_file(file_name), number_of_agents=1, + # obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) + #env2.reset(False, False, False) + env2, env2_dict = RailEnvPersister.load_new(file_name) + rails_loaded = env2.rail.grid agents_loaded = env2.agents diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index effa1f866cda1c68d231e74cff7829e212cddba1..ef28266d6bf39446b8fb2c1f2e179e626e9eae3f 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -6,12 +6,13 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env_shortest_paths import get_shortest_paths, get_k_shortest_paths -from flatland.envs.rail_env_utils import load_flatland_environment_from_file +#from flatland.envs.rail_env_utils import load_flatland_environment_from_file from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.rail_trainrun_data_structures import Waypoint from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_disconnected_simple_rail, make_simple_rail_with_alternatives +from flatland.envs.persistence import RailEnvPersister def test_get_shortest_paths_unreachable(): @@ -41,7 +42,8 @@ def test_get_shortest_paths_unreachable(): # todo file test_002.pkl has to be generated automatically # see https://gitlab.aicrowd.com/flatland/flatland/issues/279 def test_get_shortest_paths(): - env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') + #env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') + env, env_dict = RailEnvPersister.load_new("./env_data/tests/test_002.mpk") env.reset() actual = get_shortest_paths(env.distance_map) @@ -96,7 +98,8 @@ def test_get_shortest_paths(): # todo file test_002.pkl has to be generated automatically # see https://gitlab.aicrowd.com/flatland/flatland/issues/279 def test_get_shortest_paths_max_depth(): - env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') + #env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') + env, _ = RailEnvPersister.load_new("./env_data/tests/test_002.mpk") env.reset() actual = get_shortest_paths(env.distance_map, max_depth=2) @@ -119,7 +122,8 @@ def test_get_shortest_paths_max_depth(): # todo file Level_distance_map_shortest_path.pkl has to be generated automatically # see https://gitlab.aicrowd.com/flatland/flatland/issues/279 def test_get_shortest_paths_agent_handle(): - env = load_flatland_environment_from_file('Level_distance_map_shortest_path.pkl', 'env_data.tests') + #env = load_flatland_environment_from_file('Level_distance_map_shortest_path.pkl', 'env_data.tests') + env, _ = RailEnvPersister.load_new("./env_data/tests/Level_distance_map_shortest_path.mpk") env.reset() actual = get_shortest_paths(env.distance_map, agent_handle=6) diff --git a/tests/test_generators.py b/tests/test_generators.py index 95f399fc59791a199ff48508cde47171d5bb6472..c723c194f179efcc191f80fb93a3e5370e5469c9 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -11,6 +11,7 @@ from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_fr from flatland.envs.schedule_generators import random_schedule_generator, complex_schedule_generator, \ schedule_from_file from flatland.utils.simple_rail import make_simple_rail +from flatland.envs.persistence import RailEnvPersister def test_empty_rail_generator(): @@ -108,7 +109,8 @@ def tests_rail_from_file(): schedule_generator=random_schedule_generator(), number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() - env.save(file_name) + #env.save(file_name) + RailEnvPersister.save(env, file_name) dist_map_shape = np.shape(env.distance_map.get()) rails_initial = env.rail.grid agents_initial = env.agents @@ -135,7 +137,8 @@ def tests_rail_from_file(): rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=3, obs_builder_object=GlobalObsForRailEnv()) env2.reset() - env2.save(file_name_2) + #env2.save(file_name_2) + RailEnvPersister.save(env2, file_name_2) rails_initial_2 = env2.rail.grid agents_initial_2 = env2.agents diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py index 51839babe563943a609492bcad64243d36105b5c..2593361e5922dd3078b614997e6306c1ab5549d5 100644 --- a/tests/test_malfunction_generators.py +++ b/tests/test_malfunction_generators.py @@ -4,7 +4,7 @@ from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.simple_rail import make_simple_rail2 - +from flatland.envs.persistence import RailEnvPersister def test_malfanction_from_params(): """ @@ -54,7 +54,9 @@ def test_malfanction_to_and_from_file(): malfunction_generator_and_process_data=malfunction_from_params(stochastic_data) ) env.reset() - env.save("./malfunction_saving_loading_tests.pkl") + #env.save("./malfunction_saving_loading_tests.pkl") + RailEnvPersister.save(env, "./malfunction_saving_loading_tests.pkl") + malfunction_generator, malfunction_process_data = malfunction_from_file("./malfunction_saving_loading_tests.pkl") env2 = RailEnv(width=25, diff --git a/tests/test_utils.py b/tests/test_utils.py index cae9b7fb814d63dcbd8c40678a6fb02e46d56a3b..99f731e47d488d01f281acbdc2f556b92dbf0b6d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,7 +11,7 @@ from flatland.envs.rail_env import RailEnvActions, RailEnv from flatland.envs.rail_generators import RailGenerator from flatland.envs.schedule_generators import ScheduleGenerator from flatland.utils.rendertools import RenderTool - +from flatland.envs.persistence import RailEnvPersister @attrs class Replay(object): @@ -150,4 +150,6 @@ def create_and_save_env(file_name: str, schedule_generator: ScheduleGenerator, r malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), remove_agents_at_target=True) env.reset(True, True) - env.save(file_name) + #env.save(file_name) + RailEnvPersister.save(env, file_name) +