Skip to content
Snippets Groups Projects
RailEnvRLLibWrapper.py 2.39 KiB
from flatland.envs.rail_env import RailEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.generators import random_rail_generator
from ray.rllib.utils.seed import seed as set_seed
import numpy as np


class RailEnvRLLibWrapper(MultiAgentEnv):
    """Expose a Flatland ``RailEnv`` through RLlib's ``MultiAgentEnv`` interface.

    Besides delegating to the wrapped environment, ``step`` drops agents that
    were already reported as done on an earlier step from the returned
    observation/reward/done dicts, so each agent reports ``done=True`` once.
    """

    def __init__(self, config):
        """Build the rail generator and the wrapped ``RailEnv`` from ``config``.

        Args:
            config: Mapping read with the keys ``"rail_generator"`` (a factory
                called with ``nr_start_goal``, ``min_dist``, ``nr_extra`` and
                ``seed`` keyword arguments), ``"number_of_agents"``,
                ``"seed"``, ``"width"``, ``"height"`` and ``"obs_builder"``.
                If it has a ``vector_index`` attribute (as an RLlib
                ``EnvContext`` does), that index is folded into the seed so
                each vectorized copy is seeded differently; otherwise index 0
                is used.
        """
        # Zero-argument super() runs MultiAgentEnv.__init__ too; the previous
        # super(MultiAgentEnv, self) form skipped it.
        super().__init__()

        # Per-copy seed: seed * (1 + vector_index), computed once and shared by
        # the rail generator and the global seeding call.
        vector_index = getattr(config, "vector_index", 0)
        seed = config["seed"] * (1 + vector_index)

        self.rail_generator = config["rail_generator"](
            nr_start_goal=config["number_of_agents"],
            min_dist=5,
            nr_extra=30,
            seed=seed,
        )
        set_seed(seed)
        self.env = RailEnv(
            width=config["width"],
            height=config["height"],
            rail_generator=self.rail_generator,
            number_of_agents=config["number_of_agents"],
            obs_builder_object=config["obs_builder"],
        )
        # Handles of agents already reported as done. Initialised here as well
        # as in reset() so that step() does not fail if called before reset().
        self.agents_done = []

    def reset(self):
        """Reset the wrapped environment and forget which agents are done.

        Returns:
            The observations returned by ``self.env.reset()``, unchanged.
        """
        self.agents_done = []
        return self.env.reset()

    def step(self, action_dict):
        """Advance the wrapped environment by one step.

        Agents recorded as done on a previous step are left out of all three
        returned dicts. The ``"__all__"`` entry of ``dones`` is always passed
        through to the returned dones, but never gets an observation or reward
        entry.

        Args:
            action_dict: Mapping from agent handle to action, forwarded as-is
                to ``self.env.step``.

        Returns:
            Tuple ``(obs, rewards, dones, infos)`` of the filtered dicts;
            ``infos`` is forwarded unchanged from the wrapped environment.
        """
        obs, rewards, dones, infos = self.env.step(action_dict)

        o = dict()
        r = dict()
        d = dict()
        for agent, done in dones.items():
            if agent in self.agents_done:
                continue
            d[agent] = done
            if agent == '__all__':
                continue
            o[agent] = obs[agent]
            r[agent] = rewards[agent]
            if done:
                # Record only on the transition to done, so each handle is
                # stored once and filtered out on every later step.
                self.agents_done.append(agent)

        return o, r, d, infos

    def get_agent_handles(self):
        """Return the agent handles of the wrapped ``RailEnv``."""
        return self.env.get_agent_handles()