Skip to content
Snippets Groups Projects
Commit c1b86131 authored by spiglerg's avatar spiglerg
Browse files

closing issue #93

parent 2c17b423
No related branches found
No related tags found
No related merge requests found
......@@ -80,6 +80,7 @@ class RailEnv(Environment):
rail_generator=random_rail_generator(),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2),
max_episode_steps = None
):
"""
Environment init.
......@@ -110,6 +111,8 @@ class RailEnv(Environment):
obs_builder_object: ObservationBuilder object
ObservationBuilder-derived object that takes builds observation
vectors for each agent.
max_episode_steps : int or None
file_name: you can load a pickle file.
"""
......@@ -123,6 +126,9 @@ class RailEnv(Environment):
self.obs_builder = obs_builder_object
self.obs_builder._set_env(self)
self._max_episode_steps = max_episode_steps
self._elapsed_steps = 0
self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
self.obs_dict = {}
......@@ -184,6 +190,7 @@ class RailEnv(Environment):
agent.speed_data['position_fraction'] = 0.0
self.num_resets += 1
self._elapsed_steps = 0
# TODO perhaps dones should be part of each agent.
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
......@@ -196,6 +203,8 @@ class RailEnv(Environment):
return self._get_observations()
def step(self, action_dict_):
self._elapsed_steps += 1
action_dict = action_dict_.copy()
alpha = 1.0
......@@ -323,6 +332,11 @@ class RailEnv(Environment):
self.dones["__all__"] = True
self.rewards_dict = {i: 0 * r + global_reward for i, r in self.rewards_dict.items()}
if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
self.dones["__all__"] = True
for k in self.dones.keys():
self.dones[k] = True
return self._get_observations(), self.rewards_dict, self.dones, {}
def _check_action_on_agent(self, action, agent):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment