Skip to content
Snippets Groups Projects
Commit 7ccb450c authored by Christian Eichenberger's avatar Christian Eichenberger :badminton:
Browse files

Merge branch '98_load_env_from_file' into 'master'

98 load env from file

See merge request flatland/flatland!101
parents 3729c9d2 b7020a0e
No related branches found
No related tags found
No related merge requests found
...@@ -172,18 +172,21 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -172,18 +172,21 @@ class TreeObsForRailEnv(ObservationBuilder):
if handles is None: if handles is None:
handles = [] handles = []
if self.predictor: if self.predictor:
self.max_prediction_depth = 0
self.predicted_pos = {} self.predicted_pos = {}
self.predicted_dir = {} self.predicted_dir = {}
self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map}) self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
for t in range(len(self.predictions[0])): if self.predictions:
pos_list = []
dir_list = [] for t in range(len(self.predictions[0])):
for a in handles: pos_list = []
pos_list.append(self.predictions[a][t][1:3]) dir_list = []
dir_list.append(self.predictions[a][t][3]) for a in handles:
self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) pos_list.append(self.predictions[a][t][1:3])
self.predicted_dir.update({t: dir_list}) dir_list.append(self.predictions[a][t][3])
self.max_prediction_depth = len(self.predicted_pos) self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
self.predicted_dir.update({t: dir_list})
self.max_prediction_depth = len(self.predicted_pos)
observations = {} observations = {}
for h in handles: for h in handles:
observations[h] = self.get(h) observations[h] = self.get(h)
......
...@@ -80,6 +80,7 @@ class RailEnv(Environment): ...@@ -80,6 +80,7 @@ class RailEnv(Environment):
rail_generator=random_rail_generator(), rail_generator=random_rail_generator(),
number_of_agents=1, number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2), obs_builder_object=TreeObsForRailEnv(max_depth=2),
file_name=None
): ):
""" """
Environment init. Environment init.
...@@ -110,6 +111,7 @@ class RailEnv(Environment): ...@@ -110,6 +111,7 @@ class RailEnv(Environment):
obs_builder_object: ObservationBuilder object obs_builder_object: ObservationBuilder object
ObservationBuilder-derived object that takes builds observation ObservationBuilder-derived object that takes builds observation
vectors for each agent. vectors for each agent.
file_name: you can load a pickle file.
""" """
self.rail_generator = rail_generator self.rail_generator = rail_generator
...@@ -117,14 +119,10 @@ class RailEnv(Environment): ...@@ -117,14 +119,10 @@ class RailEnv(Environment):
self.width = width self.width = width
self.height = height self.height = height
self.obs_builder = obs_builder_object
self.obs_builder._set_env(self)
self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets?
self.rewards = [0] * number_of_agents self.rewards = [0] * number_of_agents
self.done = False self.done = False
self.obs_builder = obs_builder_object
self.obs_builder._set_env(self)
self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False) self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
...@@ -135,6 +133,14 @@ class RailEnv(Environment): ...@@ -135,6 +133,14 @@ class RailEnv(Environment):
self.agents = [None] * number_of_agents # live agents self.agents = [None] * number_of_agents # live agents
self.agents_static = [None] * number_of_agents # static agent information self.agents_static = [None] * number_of_agents # static agent information
self.num_resets = 0 self.num_resets = 0
if file_name:
self.loaded_file = file_name
else:
self.loaded_file = None
self.action_space = [1]
self.observation_space = self.obs_builder.observation_space # updated on resets?
self.reset() self.reset()
self.num_resets = 0 # yes, set it to zero again! self.num_resets = 0 # yes, set it to zero again!
...@@ -175,6 +181,9 @@ class RailEnv(Environment): ...@@ -175,6 +181,9 @@ class RailEnv(Environment):
if replace_agents: if replace_agents:
self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5]) self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5])
if self.loaded_file:
self.load(self.loaded_file)
self.restart_agents() self.restart_agents()
for i_agent in range(self.get_num_agents()): for i_agent in range(self.get_num_agents()):
...@@ -425,6 +434,9 @@ class RailEnv(Environment): ...@@ -425,6 +434,9 @@ class RailEnv(Environment):
load_data = file_in.read() load_data = file_in.read()
self.set_full_state_msg(load_data) self.set_full_state_msg(load_data)
def load_pkl(self, pkl_data):
self.set_full_state_msg(pkl_data)
def load_resource(self, package, resource): def load_resource(self, package, resource):
from importlib_resources import read_binary from importlib_resources import read_binary
load_data = read_binary(package, resource) load_data = read_binary(package, resource)
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from tests.simple_rail import make_simple_rail
def test_load_pkl():
file_name = "test_pkl.pkl"
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.save(file_name)
# initialize agents_static
rails_initial = env.rail.grid
agents_initial = env.agents
env = RailEnv(width=1,
height=1,
rail_generator=empty_rail_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
file_name=file_name
)
rails_loaded = env.rail.grid
agents_loaded = env.agents
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
return
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment