Skip to content
Snippets Groups Projects
Commit de7c1afd authored by Erik Nygren's avatar Erik Nygren
Browse files

added pkl_load functionality to the rail_env.

It now sets all the info from pickle file before running all the prediction and observations.
Also added a test for file load
parent d8253ebb
No related branches found
No related tags found
No related merge requests found
...@@ -80,7 +80,7 @@ class RailEnv(Environment): ...@@ -80,7 +80,7 @@ class RailEnv(Environment):
rail_generator=random_rail_generator(), rail_generator=random_rail_generator(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2), obs_builder_object=TreeObsForRailEnv(max_depth=2),
pkl_load=None file_name=None
): ):
""" """
Environment init. Environment init.
...@@ -111,7 +111,7 @@ class RailEnv(Environment): ...@@ -111,7 +111,7 @@ class RailEnv(Environment):
obs_builder_object: ObservationBuilder object obs_builder_object: ObservationBuilder object
ObservationBuilder-derived object that takes builds observation ObservationBuilder-derived object that takes builds observation
vectors for each agent. vectors for each agent.
pkl_load: you can load a pickle file. file_name: you can load a pickle file.
""" """
self.rail_generator = rail_generator self.rail_generator = rail_generator
...@@ -121,6 +121,8 @@ class RailEnv(Environment): ...@@ -121,6 +121,8 @@ class RailEnv(Environment):
self.rewards = [0] * number_of_agents self.rewards = [0] * number_of_agents
self.done = False self.done = False
self.obs_builder = obs_builder_object
self.obs_builder._set_env(self)
self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False) self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
...@@ -131,13 +133,10 @@ class RailEnv(Environment): ...@@ -131,13 +133,10 @@ class RailEnv(Environment):
self.agents = [None] * number_of_agents # live agents self.agents = [None] * number_of_agents # live agents
self.agents_static = [None] * number_of_agents # static agent information self.agents_static = [None] * number_of_agents # static agent information
self.num_resets = 0 self.num_resets = 0
if pkl_load: if file_name:
self.loaded_data = pkl_load self.loaded_file = file_name
else: else:
self.loaded_data = None self.loaded_file = None
self.obs_builder = obs_builder_object
self.obs_builder._set_env(self)
self.action_space = [1] self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets? self.observation_space = self.obs_builder.observation_space # updated on resets?
...@@ -182,8 +181,8 @@ class RailEnv(Environment): ...@@ -182,8 +181,8 @@ class RailEnv(Environment):
if replace_agents: if replace_agents:
self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5]) self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
if self.loaded_data: if self.loaded_file:
self.load_pkl(self.loaded_data) self.load(self.loaded_file)
self.restart_agents() self.restart_agents()
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from tests.simple_rail import make_simple_rail
def test_load_pkl():
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.save("test_pkl.pkl")
# initialize agents_static
obs_0 = env.reset(False, False)
file_name = "test_pkl.pkl"
env = RailEnv(width=1,
height=1,
rail_generator=empty_rail_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
file_name=file_name
)
obs_1 = env.reset(False, False)
assert obs_0 == obs_1
return
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment