# RailEnvRLLibWrapper.py
import numpy as np
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.seed import seed as set_seed
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator, random_rail_generator
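
# RLlib multi-agent wrapper around flatland's RailEnv. It instantiates the
# rail generator requested in the config, splits each agent's tree
# observation into (data, distance, agent_data) via the observation
# builder's split_tree(), and can stack the current and previous
# observations when step_memory == 2.
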
class RailEnvRLLibWrapper(MultiAgentEnv):
def __init__(self, config):
        super().__init__()
        # vector_index identifies this copy of the env when RLlib runs
        # several envs per worker (num_envs_per_worker > 1).
if hasattr(config, "vector_index"):
vector_index = config.vector_index
else:
vector_index = 1
self.predefined_env = False
if config['rail_generator'] == "complex_rail_generator":
self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'],
min_dist=config['min_dist'],
nr_extra=config['nr_extra'],
seed=config['seed'] * (1 + vector_index))
elif config['rail_generator'] == "random_rail_generator":
self.rail_generator = random_rail_generator()
elif config['rail_generator'] == "load_env":
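            # The generator is only a placeholder here; the actual scene is
            # loaded from file further below.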
self.predefined_env = True
self.rail_generator = random_rail_generator()
else:
            raise ValueError(f'Unknown rail generator: {config["rail_generator"]}')
set_seed(config['seed'] * (1 + vector_index))
self.env = RailEnv(width=config["width"], height=config["height"],
number_of_agents=config["number_of_agents"],
obs_builder_object=config['obs_builder'], rail_generator=self.rail_generator)
if self.predefined_env:
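            # Overwrite the generated rail with the pre-built scene shipped
            # in the torch_training.railway package.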
self.env.load_resource('torch_training.railway', 'complex_scene.pkl')
self.width = self.env.width
self.height = self.env.height
self.step_memory = config["step_memory"]
# needed for the renderer
self.rail = self.env.rail
self.agents = self.env.agents
self.agents_static = self.env.agents_static
self.dev_obs_dict = self.env.dev_obs_dict

    def reset(self):
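        # Handles of agents that have already finished; they no longer
        # receive observations, rewards, or dones.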
self.agents_done = []
if self.predefined_env:
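            # Reset without regenerating the rail or replacing the agents,
            # so the loaded scene is kept.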
obs = self.env.reset(False, False)
else:
obs = self.env.reset()
        # RLlib only receives observations of agents that are not done.
o = dict()
for i_agent in range(len(self.env.agents)):
data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
current_depth=0)
o[i_agent] = [data, distance, agent_data]
# needed for the renderer
self.rail = self.env.rail
self.agents = self.env.agents
self.agents_static = self.env.agents_static
self.dev_obs_dict = self.env.dev_obs_dict
        # If step_memory > 1, concatenate the current observation with the one
        # kept in memory. Only step_memory values of 1 and 2 are supported for now.
if self.step_memory < 2:
return o
else:
self.old_obs = o
oo = dict()
for i_agent in range(len(self.env.agents)):
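                # Right after reset there is no previous observation yet, so
                # the current one is duplicated.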
oo[i_agent] = [o[i_agent], o[i_agent]]
return oo

    def step(self, action_dict):
obs, rewards, dones, infos = self.env.step(action_dict)
d = dict()
r = dict()
o = dict()
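        # Skip agents that finished in an earlier step; RLlib must not
        # receive further observations, rewards, or dones for them.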
for i_agent in range(len(self.env.agents)):
if i_agent not in self.agents_done:
data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
current_depth=0)
o[i_agent] = [data, distance, agent_data]
r[i_agent] = rewards[i_agent]
d[i_agent] = dones[i_agent]
d['__all__'] = dones['__all__']
if self.step_memory >= 2:
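            # Pair each live agent's current observation with the previous one.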
oo = dict()
for i_agent in range(len(self.env.agents)):
if i_agent not in self.agents_done:
oo[i_agent] = [o[i_agent], self.old_obs[i_agent]]
self.old_obs = o
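        # Record the agents that finished during this step so that later
        # steps skip them.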
for agent, done in dones.items():
if done and agent != '__all__':
self.agents_done.append(agent)
if self.step_memory < 2:
return o, r, d, infos
else:
return oo, r, d, infos

    def get_agent_handles(self):
return self.env.get_agent_handles()

    def get_num_agents(self):
return self.env.get_num_agents()
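

# ---------------------------------------------------------------------------
# Usage sketch (an illustration, not part of the original file): registering
# the wrapper with RLlib and building one instance by hand. The config keys
# match those read in __init__ above; the concrete values and the
# TreeObsForRailEnv import are assumptions for the example. Note that the
# wrapper calls obs_builder.split_tree(), so the observation builder passed
# in must provide that method (the stock TreeObsForRailEnv may not).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from flatland.envs.observations import TreeObsForRailEnv
    from ray.tune.registry import register_env

    # Make the wrapper available to RLlib trainers under a custom env id.
    register_env("flatland_rail", lambda cfg: RailEnvRLLibWrapper(cfg))

    example_config = {
        "rail_generator": "complex_rail_generator",
        "number_of_agents": 3,
        "min_dist": 5,
        "nr_extra": 30,
        "seed": 1,
        "width": 20,
        "height": 20,
        "obs_builder": TreeObsForRailEnv(max_depth=2),  # assumed builder
        "step_memory": 1,
    }
    env = RailEnvRLLibWrapper(example_config)
    # env.reset() would return {agent_handle: [data, distance, agent_data]}.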