diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 3d97119e7f7289008525265d412ae5255dbafc69..2a00f2f07c2e0ea9679f8dc51d17d696f32f7ce6 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -20,17 +20,23 @@ from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap + +# Need to use circular imports for persistence. from flatland.envs import malfunction_generators as mal_gen from flatland.envs import rail_generators as rail_gen from flatland.envs import schedule_generators as sched_gen +from flatland.envs import persistence + # Direct import of objects / classes does not work with circular imports. # from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData # from flatland.envs.observations import GlobalObsForRailEnv # from flatland.envs.rail_generators import random_rail_generator, RailGenerator # from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator + from flatland.envs.observations import GlobalObsForRailEnv + import pickle m.patch() @@ -251,6 +257,8 @@ class RailEnv(Environment): agent.reset() self.active_agents = [i for i in range(len(self.agents))] + + def action_required(self, agent): """ Check if an agent needs to provide an action @@ -314,6 +322,8 @@ class RailEnv(Environment): if optionals and 'distance_map' in optionals: self.distance_map.set(optionals['distance_map']) + + if regenerate_schedule or regenerate_rail or self.get_num_agents() == 0: agents_hints = None if optionals and 'agents_hints' in optionals: @@ -792,6 +802,7 @@ class RailEnv(Environment): ------ Dict object """ + #print(f"_get_obs - num agents: {self.get_num_agents()} {list(range(self.get_num_agents()))}") self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict @@ -836,3 +847,9 @@ class RailEnv(Environment): """ return agent.malfunction_data['malfunction'] < 1 + + def save(self, filename): + print("deprecated call to env.save() - pls call RailEnvPersister.save()") + persistence.RailEnvPersister.save(self, filename) + +