diff --git a/flatland/core/env.py b/flatland/core/env.py
index 950365fd52fffca27b3a1f118d910628533cb41d..a7e63fd4e2697996a4dbe0a684735fd46153fb1b 100644
--- a/flatland/core/env.py
+++ b/flatland/core/env.py
@@ -5,6 +5,8 @@ The base Environment class is adapted from rllib.env.MultiAgentEnv
 """
 import random
 
+from .env_observation_builder import TreeObsForRailEnv
+
 
 class Environment:
     """
@@ -118,7 +120,10 @@ class RailEnv:
     beta to be passed as parameters to __init__().
     """
 
-    def __init__(self, rail, number_of_agents=1):
+    def __init__(self,
+                 rail,
+                 number_of_agents=1,
+                 custom_observation_builder=TreeObsForRailEnv):
         """
         Environment init.
 
@@ -128,6 +133,9 @@ class RailEnv:
            The transition matrix that defines the environment.
        number_of_agents : int
            Number of agents to spawn on the map.
+       custom_observation_builder : ObservationBuilder class
+           ObservationBuilder-derived class that is instantiated with this
+           env object and provides observation vectors for each agent.
        """
 
        self.rail = rail
@@ -136,10 +144,16 @@ class RailEnv:
 
         self.number_of_agents = number_of_agents
 
+        self.obs_builder = custom_observation_builder(env=self)
+
         self.actions = [0]*self.number_of_agents
         self.rewards = [0]*self.number_of_agents
         self.done = False
 
+        self.dones = {"__all__": False}
+        self.obs_dict = {}
+        self.rewards_dict = {}
+
         self.agents_handles = list(range(self.number_of_agents))
 
     def get_agent_handles(self):
@@ -192,10 +206,11 @@ class RailEnv:
             self.agents_direction[i] = random.sample(
                 valid_starting_directions, 1)[0]
 
-        obs_dict = {}
-        for handle in self.agents_handles:
-            obs_dict[handle] = self._get_observation_for_agent(handle)
-        return obs_dict
+        # Reset the state of the observation builder with the new environment
+        self.obs_builder.reset()
+
+        # Return the new observation vectors for each agent
+        return self._get_observations()
 
     def step(self, action_dict):
         alpha = 1.0
@@ -206,15 +221,12 @@ class RailEnv:
         global_reward = 1 * beta
 
         # Reset the step rewards
-        rewards_dict = {}
+        self.rewards_dict = {}
         for handle in self.agents_handles:
-            rewards_dict[handle] = 0
+            self.rewards_dict[handle] = 0
 
         if self.dones["__all__"]:
-            obs_dict = {}
-            for handle in self.agents_handles:
-                obs_dict[handle] = self._get_observation_for_agent(handle)
-            return obs_dict, rewards_dict, self.dones, {}
+            return self._get_observations(), self.rewards_dict, self.dones, {}
 
         for i in range(len(self.agents_handles)):
             handle = self.agents_handles[i]
@@ -307,14 +319,14 @@ class RailEnv:
                     self.agents_direction[i] = movement
                 else:
                     # the action was not valid, add penalty
-                    rewards_dict[handle] += invalid_action_penalty
+                    self.rewards_dict[handle] += invalid_action_penalty
 
             # if agent is not in target position, add step penalty
             if self.agents_position[i][0] == self.agents_target[i][0] and \
                self.agents_position[i][1] == self.agents_target[i][1]:
                 self.dones[handle] = True
             else:
-                rewards_dict[handle] += step_penalty
+                self.rewards_dict[handle] += step_penalty
 
         # Check for end of episode + add global reward to all rewards!
         num_agents_in_target_position = 0
@@ -325,17 +337,13 @@ class RailEnv:
 
         if num_agents_in_target_position == self.number_of_agents:
             self.dones["__all__"] = True
-            rewards_dict = [r+global_reward for r in rewards_dict]
+            self.rewards_dict = {h: r + global_reward for h, r in self.rewards_dict.items()}
 
         # Reset the step actions (in case some agent doesn't 'register_action'
         # on the next step)
         self.actions = [0]*self.number_of_agents
 
-        obs_dict = {}
-        for handle in self.agents_handles:
-            obs_dict[handle] = self._get_observation_for_agent(handle)
-
-        return obs_dict, rewards_dict, self.dones, {}
+        return self._get_observations(), self.rewards_dict, self.dones, {}
 
     def _new_position(self, position, movement):
         if movement == 0:    # NORTH
@@ -376,9 +384,11 @@ class RailEnv:
 
         return 0
 
-    def _get_observation_for_agent(self, handle):
-        # TODO:
-        return None
+    def _get_observations(self):
+        self.obs_dict = {}
+        for handle in self.agents_handles:
+            self.obs_dict[handle] = self.obs_builder.get(handle)
+        return self.obs_dict
 
     def render(self):
         # TODO:
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1706185425e38acd56c29a52ca7d644fac6024a3
--- /dev/null
+++ b/flatland/core/env_observation_builder.py
@@ -0,0 +1,27 @@
+## TODO: add docstrings, pylint, etc...
+
+
+class ObservationBuilder:
+    def __init__(self, env):
+        self.env = env
+
+    def reset(self):
+        raise NotImplementedError()
+
+    def get(self, handle):
+        raise NotImplementedError()
+
+
+
+class TreeObsForRailEnv(ObservationBuilder):
+    def reset(self):
+        # TODO: precompute distances, etc...
+        # raise NotImplementedError()
+        pass
+
+    def get(self, handle):
+        # TODO: compute the observation for agent `handle'
+        # raise NotImplementedError()
+        return []
+
+
diff --git a/flatland/core/transitionmap.py b/flatland/core/transition_map.py
similarity index 100%
rename from flatland/core/transitionmap.py
rename to flatland/core/transition_map.py
diff --git a/flatland/utils/rail_env_generator.py b/flatland/utils/rail_env_generator.py
index 5b292f03833732699482728a7f5e3fe9358f3a07..69e5b831be67c6a808e22b5413255789acf90f27 100644
--- a/flatland/utils/rail_env_generator.py
+++ b/flatland/utils/rail_env_generator.py
@@ -6,7 +6,7 @@ import random
 import numpy as np
 
 from flatland.core.transitions import RailEnvTransitions
-from flatland.core.transitionmap import GridTransitionMap
+from flatland.core.transition_map import GridTransitionMap
 
 
 def generate_rail_from_manual_specifications(rail_spec):
diff --git a/tests/test_environments.py b/tests/test_environments.py
index 66e6bef404939ebefa9596d0661348da899ee16d..03544b08196a609bc6d4ed92c393f0a72bfbab8c 100644
--- a/tests/test_environments.py
+++ b/tests/test_environments.py
@@ -3,7 +3,7 @@
 
 from flatland.core.env import RailEnv
 from flatland.core.transitions import Grid4Transitions
-from flatland.core.transitionmap import GridTransitionMap
+from flatland.core.transition_map import GridTransitionMap
 import numpy as np
 
 """Tests for `flatland` package."""
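
The patch turns observation generation into a pluggable hook: RailEnv.__init__ now accepts an ObservationBuilder-derived class, instantiates it with the environment, and calls its reset() and get(handle) methods from env.reset() and env.step(). A minimal usage sketch follows; it assumes a `rail` GridTransitionMap has already been constructed (for example via flatland/utils/rail_env_generator.py), and AgentPositionObs is a hypothetical builder written only for illustration, not part of this patch.

# Sketch only: AgentPositionObs is hypothetical, and `rail` is assumed to be
# a GridTransitionMap built elsewhere (e.g. via rail_env_generator).
from flatland.core.env import RailEnv
from flatland.core.env_observation_builder import ObservationBuilder


class AgentPositionObs(ObservationBuilder):
    """Toy builder that observes each agent's (row, column) position."""

    def reset(self):
        # Nothing to precompute for this trivial observation.
        pass

    def get(self, handle):
        # RailEnv stores agent positions indexed by agent handle.
        return self.env.agents_position[handle]


# The builder is passed as a class; RailEnv instantiates it with env=self.
env = RailEnv(rail,
              number_of_agents=2,
              custom_observation_builder=AgentPositionObs)

obs = env.reset()                                # {handle: observation, ...}
obs, rewards, dones, _ = env.step({0: 0, 1: 0})  # 0 used as a placeholder action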