"""PettingZoo environment wrapper for the Flatland rail environment.

Adapted from
- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
- https://github.com/instadeepai/Mava/blob/HEAD/mava/wrappers/flatland.py
"""
from functools import partial

import numpy as np
import gym
from gym.utils import EzPickle
from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector, wrappers
from pettingzoo.utils.conversions import to_parallel_wrapper

from flatland.envs.rail_env import RailEnv
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv

from mava.wrappers.flatland import infer_observation_space, normalize_observation


def parallel_wrapper_fn(env_fn):
    def par_fn(**kwargs):
        env = env_fn(**kwargs)
        env = custom_parallel_wrapper(env)
        return env
    return par_fn

def env(**kwargs):
    env = raw_env(**kwargs)
    # env = wrappers.AssertOutOfBoundsWrapper(env)
    # env = wrappers.OrderEnforcingWrapper(env)
    return env


parallel_env = parallel_wrapper_fn(env)

class custom_parallel_wrapper(to_parallel_wrapper):

    def step(self, actions):
        """Step every agent once and aggregate the rewards accrued by each."""
        rewards = {agent: 0 for agent in self.aec_env.agents}

        for agent in self.aec_env.agents:
            assert agent == self.aec_env.agent_selection, (
                f"expected agent {agent} got agent {self.aec_env.agent_selection}, "
                "agent order is nontrivial"
            )
            self.aec_env.step(actions.get(agent, 0))
            # Rewards can be assigned to every agent on the step in which the
            # underlying flatland env advances, so accumulate across all agents.
            for rewarded_agent in self.aec_env.agents:
                rewards[rewarded_agent] += self.aec_env.rewards[rewarded_agent]

        dones = dict(**self.aec_env.dones)
        infos = dict(**self.aec_env.infos)
        self.agents = self.aec_env.agents
        observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}
        return observations, rewards, dones, infos

class raw_env(AECEnv, gym.Env):

    metadata = {'render.modes': ['human', "rgb_array"], 'name': "flatland_pettingzoo",
            'video.frames_per_second': 10,
            'semantics.autoreset': False }

    def __init__(self, environment=False, preprocessor=False, agent_info=False, use_renderer=False, *args, **kwargs):
        # EzPickle.__init__(self, *args, **kwargs)
        self._environment = environment
        self.use_renderer = use_renderer
        self.renderer = None
        if self.use_renderer:
            self.initialize_renderer()

        n_agents = self.num_agents
        self._agents = [get_agent_keys(i) for i in range(n_agents)]
        self._possible_agents = self.agents[:]
        self._reset_next_step = True

        self._agent_selector = agent_selector(self.agents)

        self.num_actions = 5

        self.action_spaces = {
            agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
        }              

        self.seed()
        # A preprocessor is only needed for observation builders other than
        # GlobalObsForRailEnv; TreeObsForRailEnv falls back to the default
        # normalizing preprocessor when none is supplied.
        self.preprocessor = self._obtain_preprocessor(preprocessor)

        self._include_agent_info = agent_info

        # Observation space:
        # Flatland does not define an observation space per agent, so we infer
        # one from an observation returned by the environment. All agents are
        # identical and therefore share the same observation space.
        obs, _ = self._environment.reset(regenerate_rail = False, regenerate_schedule = False)
        obs = self.preprocessor(obs)
        self.observation_spaces = {
            i: infer_observation_space(ob) for i, ob in obs.items()
        }

    
    @property
    def environment(self) -> RailEnv:
        """Returns the wrapped environment."""
        return self._environment

    @property
    def dones(self):
        dones = self._environment.dones
        # remove_all = dones.pop("__all__", None)
        return {get_agent_keys(key): value for key, value in dones.items()}    
    
    @property
    def obs_builder(self):    
        return self._environment.obs_builder    

    @property
    def width(self):    
        return self._environment.width  

    @property
    def height(self):    
        return self._environment.height  

    @property
    def agents_data(self):
        """Rail Env Agents data."""
        return self._environment.agents

    @property
    def num_agents(self) -> int:
        """Returns the number of trains/agents in the flatland environment"""
        return int(self._environment.number_of_agents)

    # def __getattr__(self, name):
    #     """Expose any other attributes of the underlying environment."""
    #     return getattr(self._environment, name)
   
    @property
    def agents(self):
        return self._agents

    @property
    def possible_agents(self):
        return self._possible_agents

    def env_done(self):
        return self._environment.dones["__all__"] or not self.agents
    
    def observe(self, agent):
        return self.obs.get(agent)
    
    def last(self, observe=True):
        '''
        Returns observation, reward, done, info for the current agent (specified by self.agent_selection).
        '''
        agent = self.agent_selection
        observation = self.observe(agent) if observe else None
        return observation, self.rewards.get(agent), self.dones.get(agent), self.infos.get(agent)
    
    def seed(self, seed: int = None) -> None:
        self._environment._seed(seed)

    def state(self):
        '''
        Returns an observation of the global environment
        '''
        return None

    def _clear_rewards(self):
        '''
        Clears all items in .rewards.
        '''
        for agent in self.rewards:
            self.rewards[agent] = 0
    
    def reset(self, *args, **kwargs):
        self._reset_next_step = False
        self._agents = self.possible_agents[:]
        if self.use_renderer:
            if self.renderer:  # TODO: RLlib errors when the renderer is None.
                self.renderer.reset()
        obs, info = self._environment.reset(*args, **kwargs)
        observations = self._collate_obs_and_info(obs, info)
        self._agent_selector.reinit(self.agents)
        self.agent_selection = self._agent_selector.next()
        self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
        self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents]))
        self.action_dict = {get_agent_handle(i):0 for i in self.possible_agents}

        return observations

    def step(self, action):

        if self.env_done():
            # All agents are done; flag a reset and return the last transition
            # for the currently selected agent.
            self._agents = []
            self._reset_next_step = True
            return self.last()

        agent = self.agent_selection
        self.action_dict[get_agent_handle(agent)] = action

        # Agents that are already done take the same path as live agents (the
        # original special-cased branch was equivalent); removing done agents
        # from the rotation is intentionally left disabled.
        if self._agent_selector.is_last():
            # Last agent in the rotation: forward the collected actions to flatland.
            observations, rewards, dones, infos = self._environment.step(self.action_dict)
            self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
            if observations:
                observations = self._collate_obs_and_info(observations, infos)
        else:
            self._clear_rewards()

        self._accumulate_rewards()

        obs, cumulative_reward, done, info = self.last()

        self.agent_selection = self._agent_selector.next()

        return obs, cumulative_reward, done, info


    def _collate_obs_and_info(self, observes, info):
        """Collate agent info with observations.

        When agent info is included, each agent's observation becomes a tuple
        of (env observation, agent info array); otherwise the preprocessed
        observation is returned unchanged.
        """
        observations = {}
        infos = {}
        observes = self.preprocessor(observes)
        for agent, obs in observes.items():
            all_infos = {k: info[k][get_agent_handle(agent)] for k in info.keys()}
            agent_info = np.array(
                list(all_infos.values()), dtype=np.float32
            )
            infos[agent] = all_infos
            obs = (obs, agent_info) if self._include_agent_info else obs
            observations[agent] = obs

        self.infos = infos
        self.obs = observations
        return observations   


    def render(self, mode='human'):
        """
        This methods provides the option to render the
        environment's behavior to a window which should be
        readable to the human eye if mode is set to 'human'.
        """
        if not self.use_renderer:
            return

        if not self.renderer:
            self.initialize_renderer(mode=mode)

        return self.update_renderer(mode=mode)

    def initialize_renderer(self, mode="human"):
        # Initiate the renderer
        from flatland.utils.rendertools import RenderTool, AgentRenderVariant
        self.renderer = RenderTool(self.environment, gl="PGL",  # gl="TKPILSVG",
                                       agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
                                       show_debug=False,
                                       screen_height=600,  # Adjust these parameters to fit your resolution
                                       screen_width=800)  # Adjust these parameters to fit your resolution
        self.renderer.show = False

    def update_renderer(self, mode='human'):
        image = self.renderer.render_env(show=False, show_observations=False, show_predictions=False,
                                             return_image=True)
        return image[:,:,:3]

    def set_renderer(self, renderer):
        self.use_renderer = renderer
        if self.use_renderer:
            self.initialize_renderer(mode=self.use_renderer)

    def close(self):
        # self._environment.close()
        if self.renderer:
            try:
                if self.renderer.show:
                    self.renderer.close_window()
            except Exception as e:
                print("Could Not close window due to:",e)
            self.renderer = None

    def _obtain_preprocessor(self, preprocessor):
        """Obtain the actual preprocessor to use, based on the supplied
        preprocessor and the env's obs_builder object."""
        if not isinstance(self.obs_builder, GlobalObsForRailEnv):
            _preprocessor = preprocessor if preprocessor else lambda x: x
            if isinstance(self.obs_builder, TreeObsForRailEnv):
                _preprocessor = (
                    partial(
                        normalize_observation, tree_depth=self.obs_builder.max_depth
                    )
                    if not preprocessor
                    else preprocessor
                )
            assert _preprocessor is not None
        else:
            # Global observations are used as-is.
            def _preprocessor(x):
                return x

        def returned_preprocessor(obs):
            temp_obs = {}
            for agent_id, ob in obs.items():
                temp_obs[get_agent_keys(agent_id)] = _preprocessor(ob)
            return temp_obs

        return returned_preprocessor

# Utility functions   
def convert_np_type(dtype, value):
    return np.dtype(dtype).type(value) 

def get_agent_handle(id):
    """Obtain an agent's integer handle (as used by flatland) given its id."""
    return int(id)

def get_agent_keys(id):
    """Obtain an agent's string key (as used by this wrapper) given its id."""
    return str(id)
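

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the wrapper itself). The RailEnv setup is
# an assumption: the generator imports and arguments follow the flatland-rl
# 2.x API (newer releases replaced schedule_generators with line_generators),
# so adjust them to the installed version.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from flatland.envs.rail_generators import sparse_rail_generator
    from flatland.envs.schedule_generators import sparse_schedule_generator

    rail_env = RailEnv(
        width=30,
        height=30,
        rail_generator=sparse_rail_generator(max_num_cities=3),
        schedule_generator=sparse_schedule_generator(),
        number_of_agents=2,
        obs_builder_object=TreeObsForRailEnv(max_depth=2),
    )

    pz_env = env(environment=rail_env)
    pz_env.reset()
    # Step each agent once with a random action; step() returns the
    # (obs, reward, done, info) tuple for the agent that just acted.
    for agent in pz_env.agents:
        action = pz_env.action_spaces[agent].sample()
        obs, reward, done, info = pz_env.step(action)
        print(agent, reward, done)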