Skip to content
Snippets Groups Projects
Forked from Flatland / baselines
68 commits behind the upstream repository.
RailEnvRLLibWrapper.py 4.56 KiB
import numpy as np
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.seed import seed as set_seed

from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator, random_rail_generator


class RailEnvRLLibWrapper(MultiAgentEnv):
    """Expose a Flatland ``RailEnv`` through RLlib's ``MultiAgentEnv`` interface.

    Each agent's raw observation is split with the observation builder's
    ``split_tree`` into a ``[data, distance, agent_data]`` triple. When
    ``config["step_memory"] >= 2`` every observation is paired with the
    previous step's one as ``[current, previous]`` (only a memory of 2 is
    supported).
    """

    def __init__(self, config):
        """Build the wrapped ``RailEnv`` from an RLlib env config.

        Args:
            config: Mapping (RLlib passes an ``EnvContext``) with the keys
                ``rail_generator``, ``number_of_agents``, ``seed``, ``width``,
                ``height``, ``obs_builder`` and ``step_memory``. The
                ``"complex_rail_generator"`` option also needs ``min_dist``
                and ``nr_extra``.

        Raises:
            ValueError: If ``config["rail_generator"]`` is not one of
                ``"complex_rail_generator"``, ``"random_rail_generator"`` or
                ``"load_env"``.
        """
        # Run MultiAgentEnv's own initializer. ``super(MultiAgentEnv, self)``
        # would start the lookup past MultiAgentEnv and skip it.
        super().__init__()

        # Environment ID if num_envs_per_worker > 1. RLlib's EnvContext exposes
        # it as an attribute; a plain dict does not, so fall back to 1.
        vector_index = getattr(config, "vector_index", 1)
        # Derive a distinct seed per vectorized sub-environment.
        seed = config['seed'] * (1 + vector_index)

        self.predefined_env = False

        generator_name = config['rail_generator']
        if generator_name == "complex_rail_generator":
            self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'],
                                                         min_dist=config['min_dist'],
                                                         nr_extra=config['nr_extra'],
                                                         seed=seed)
        elif generator_name == "random_rail_generator":
            self.rail_generator = random_rail_generator()
        elif generator_name == "load_env":
            # The scene is loaded from a resource file below. A generator is
            # still supplied because RailEnv's constructor is given one.
            self.predefined_env = True
            self.rail_generator = random_rail_generator()
        else:
            raise ValueError(f'Unknown rail generator: {config["rail_generator"]}')

        set_seed(seed)
        self.env = RailEnv(width=config["width"], height=config["height"],
                           number_of_agents=config["number_of_agents"],
                           obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator)

        if self.predefined_env:
            self.env.load_resource('torch_training.railway', 'complex_scene.pkl')

        self.width = self.env.width
        self.height = self.env.height
        self.step_memory = config["step_memory"]

        # Agents already reported as done. Initialized here so the attribute
        # always exists; reset() clears it at the start of every episode.
        self.agents_done = []

        self._sync_renderer_attributes()

    def _sync_renderer_attributes(self):
        """Mirror the wrapped env's rail and agent state onto the wrapper (needed by the renderer)."""
        self.rail = self.env.rail
        self.agents = self.env.agents
        self.agents_static = self.env.agents_static
        self.dev_obs_dict = self.env.dev_obs_dict

    def _split_observation(self, raw_obs):
        """Split one agent's raw tree observation into ``[data, distance, agent_data]``."""
        data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(raw_obs),
                                                                     current_depth=0)
        return [data, distance, agent_data]

    def reset(self):
        """Reset the wrapped env and return the initial per-agent observations.

        Returns:
            dict mapping each agent index to ``[data, distance, agent_data]``.
            When ``step_memory >= 2``, each agent maps instead to
            ``[current, previous]``, with both entries set to the initial
            observation.
        """
        self.agents_done = []
        if self.predefined_env:
            # NOTE(review): presumably (regen_rail=False, replace_agents=False)
            # so the loaded scene is kept -- confirm against RailEnv.reset.
            obs = self.env.reset(False, False)
        else:
            obs = self.env.reset()

        o = {i_agent: self._split_observation(obs[i_agent])
             for i_agent in range(len(self.env.agents))}

        self._sync_renderer_attributes()

        # If step_memory > 1, the observations in memory are concatenated.
        # Only step_memory = 1 or 2 is supported for the moment.
        if self.step_memory < 2:
            return o

        self.old_obs = o
        return {i_agent: [agent_obs, agent_obs] for i_agent, agent_obs in o.items()}

    def step(self, action_dict):
        """Advance the wrapped env by one step.

        RLlib only receives entries for agents that were not done before this
        step. An agent that finishes during this step is still reported once.

        Args:
            action_dict: Actions keyed by agent, passed straight to ``RailEnv.step``.

        Returns:
            ``(obs, rewards, dones, infos)``. ``obs``, ``rewards`` and ``dones``
            are keyed by the still-active agents, and ``dones`` also carries
            ``'__all__'``. With ``step_memory >= 2`` each observation is
            ``[current, previous]``. ``infos`` is returned unchanged from the
            wrapped env.
        """
        obs, rewards, dones, infos = self.env.step(action_dict)

        d = dict()
        r = dict()
        o = dict()

        for i_agent in range(len(self.env.agents)):
            if i_agent in self.agents_done:
                continue
            o[i_agent] = self._split_observation(obs[i_agent])
            r[i_agent] = rewards[i_agent]
            d[i_agent] = dones[i_agent]

        d['__all__'] = dones['__all__']

        if self.step_memory >= 2:
            # Keys of ``o`` are exactly the agents not yet done.
            oo = {i_agent: [o[i_agent], self.old_obs[i_agent]] for i_agent in o}
            self.old_obs = o

        # Record newly finished agents. Agents already recorded are skipped;
        # otherwise each would be re-appended on every later step, because
        # ``dones`` keeps reporting them as True.
        for agent, done in dones.items():
            if done and agent != '__all__' and agent not in self.agents_done:
                self.agents_done.append(agent)

        if self.step_memory < 2:
            return o, r, d, infos
        return oo, r, d, infos

    def get_agent_handles(self):
        """Return the wrapped env's agent handles."""
        return self.env.get_agent_handles()

    def get_num_agents(self):
        """Return the number of agents in the wrapped env."""
        return self.env.get_num_agents()