Commit 9b82e9a4 authored by spiglerg

Merge branch 'issue93' into 'master'

Issue93

Closes #93

See merge request flatland/flatland!106
parents f05d3345 74eaa920
@@ -83,6 +83,7 @@ class RailEnv(Environment):
                  rail_generator=random_rail_generator(),
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
+                 max_episode_steps=None
                  ):
         """
         Environment init.
@@ -113,6 +114,8 @@ class RailEnv(Environment):
         obs_builder_object: ObservationBuilder object
             ObservationBuilder-derived object that takes builds observation
             vectors for each agent.
+        max_episode_steps : int or None
+
         file_name: you can load a pickle file.
         """
@@ -126,6 +129,9 @@ class RailEnv(Environment):
         self.obs_builder = obs_builder_object
         self.obs_builder._set_env(self)
 
+        self._max_episode_steps = max_episode_steps
+        self._elapsed_steps = 0
+
         self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
         self.obs_dict = {}
@@ -191,6 +197,7 @@ class RailEnv(Environment):
             agent.speed_data['position_fraction'] = 0.0
 
         self.num_resets += 1
+        self._elapsed_steps = 0
 
         # TODO perhaps dones should be part of each agent.
         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
@@ -203,6 +210,8 @@ class RailEnv(Environment):
         return self._get_observations()
 
     def step(self, action_dict_):
+        self._elapsed_steps += 1
+
         action_dict = action_dict_.copy()
 
         alpha = 1.0
@@ -330,6 +339,11 @@ class RailEnv(Environment):
             self.dones["__all__"] = True
             self.rewards_dict = {i: 0 * r + global_reward for i, r in self.rewards_dict.items()}
 
+        if (self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps):
+            self.dones["__all__"] = True
+            for k in self.dones.keys():
+                self.dones[k] = True
+
         return self._get_observations(), self.rewards_dict, self.dones, {}
 
     def _check_action_on_agent(self, action, agent):
...
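
The effect of the new argument is easiest to see in a short driver loop. The sketch below is not part of the commit; the import paths, the width/height values, and the constant action are assumptions made only for illustration.

# Minimal usage sketch for the new max_episode_steps argument (not part of
# this commit). Import paths, width/height and the hard-coded action are
# assumptions for illustration.
from flatland.envs.rail_env import RailEnv
from flatland.envs.generators import random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv

env = RailEnv(width=20,
              height=20,
              rail_generator=random_rail_generator(),
              number_of_agents=1,
              obs_builder_object=TreeObsForRailEnv(max_depth=2),
              max_episode_steps=50)   # force-end the episode after 50 steps

obs = env.reset()
dones = {"__all__": False}
n_steps = 0
while not dones["__all__"]:
    # Action 2 ("move forward") is an arbitrary placeholder policy.
    action_dict = {handle: 2 for handle in range(env.get_num_agents())}
    obs, rewards, dones, info = env.step(action_dict)
    n_steps += 1

# Because step() now sets every entry of env.dones to True once
# _elapsed_steps >= _max_episode_steps, this loop runs at most 50 iterations
# even if no agent ever reaches its target.
print(n_steps)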