"""
Definition of the RailEnv environment and related level-generation functions.

Generator functions take the width, height, agent handles and num_resets of a rail
environment as arguments and return a GridTransitionMap object together with the
agents' initial positions, directions, and targets.
"""
import numpy as np

from flatland.core.env import Environment
from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.envs.generators import random_rail_generator
from flatland.envs.env_utils import get_new_position
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent

# from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
# from flatland.core.transition_map import GridTransitionMap


class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:
        0: do nothing
        1: turn left and move to the next cell
        2: move to the next cell in front of the agent
        3: turn right and move to the next cell

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.

    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

    TODO: WRITE ABOUT THE REWARD FUNCTION, and possibly allow for alpha and
    beta to be passed as parameters to __init__().
    """

    def __init__(self,
                 width,
                 height,
                 rail_generator=random_rail_generator(),
                 number_of_agents=1,
                 obs_builder_object=TreeObsForRailEnv(max_depth=2)):
        """
        Environment init.

        Parameters
        ----------
        rail_generator : function
            The rail_generator function is a function that takes the width,
            height and agent handles of a rail environment, along with the number
            of times the env has been reset, and returns a GridTransitionMap object
            and a list of starting positions, targets, and initial orientations for
            each agent handle.
            Implemented functions are:
                random_rail_generator : generate a random rail of given size
                rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
                                        a GridTransitionMap object
                rail_from_manual_specifications_generator(rail_spec) : generate a rail from
                                        a rail specifications array
                TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object: ObservationBuilder object
            ObservationBuilder-derived object that builds observation
            vectors for each agent.
        """

        self.rail_generator = rail_generator
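        # A custom rail_generator can be any callable with this shape (a
        # sketch; GridTransitionMap is the class from
        # flatland.core.transition_map):
        #
        #     def my_rail_generator(width, height, agents_handles, num_resets):
        #         rail = ...  # build a GridTransitionMap
        #         return rail, agents_position, agents_direction, agents_target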
        self.rail = None
        self.width = width
        self.height = height

        self.number_of_agents = number_of_agents

        self.obs_builder = obs_builder_object
        self.obs_builder._set_env(self)

        self.actions = [0] * self.number_of_agents
        self.rewards = [0] * self.number_of_agents
        self.done = False

        self.dones = {"__all__": False}
        self.obs_dict = {}
        self.rewards_dict = {}

        self.agents_handles = list(range(self.number_of_agents))

        # self.agents_position = []
        # self.agents_target = []
        # self.agents_direction = []
        self.agents = []
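        # Build the rail and the agents once at construction time; num_resets
        # is zeroed again afterwards so that the first external call to
        # reset() still sees num_resets == 0.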
        self.num_resets = 0
        self.reset()
        self.num_resets = 0

        self.valid_positions = None

    def get_agent_handles(self):
        return self.agents_handles

    def reset(self, regen_rail=True, replace_agents=True):
        """
        TODO: replace_agents is ignored at the moment; agents will always be replaced.
        """
        if regen_rail or self.rail is None:
            self.rail, agents_position, agents_direction, agents_target = self.rail_generator(
                self.width,
                self.height,
                self.agents_handles,
                self.num_resets)

        if replace_agents:
            self.agents_static = EnvAgentStatic.from_lists(agents_position, agents_direction, agents_target)
            self.agents = EnvAgent.list_from_static(self.agents_static[:len(self.agents_handles)])

        self.num_resets += 1

        # perhaps dones should be part of each agent.
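        # The "__all__" key follows the common multi-agent RL convention
        # (e.g. RLlib) of a single flag that marks the end of the episode.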
        self.dones = {"__all__": False}
        for handle in self.agents_handles:
            self.dones[handle] = False

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()

        # Return the new observation vectors for each agent
        return self._get_observations()

    def step(self, action_dict):
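        # Reward function as implemented below: each agent that has not yet
        # reached its target receives step_penalty (-1 * alpha) per time
        # step, plus invalid_action_penalty (-2) if its movement action could
        # not be executed; global_reward (1 * beta) is added to every agent's
        # reward on the step in which all agents reach their targets.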
        alpha = 1.0
        beta = 1.0

        invalid_action_penalty = -2
        step_penalty = -1 * alpha
        global_reward = 1 * beta

        # Reset the step rewards
        self.rewards_dict = dict()
        for handle in self.agents_handles:
            self.rewards_dict[handle] = 0

        if self.dones["__all__"]:
            return self._get_observations(), self.rewards_dict, self.dones, {}

        for i in range(len(self.agents_handles)):
            handle = self.agents_handles[i]
            transition_isValid = None
            agent = self.agents[i]

            if handle not in action_dict:  # no action has been supplied for this agent
                continue

            if self.dones[handle]:  # this agent has already completed...
                continue
            action = action_dict[handle]

            if action < 0 or action > 3:
                print('ERROR: illegal action=', action,
                      'for agent with handle=', handle)
                return

            if action > 0:
                # pos = agent.position #  self.agents_position[i]
                # direction = agent.direction # self.agents_direction[i]

                # compute number of possible transitions in the current
                # cell used to check for invalid actions

                possible_transitions = self.rail.get_transitions((*agent.position, agent.direction))
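                # (get_transitions returns four binary flags, one per compass
                # direction, marking the exits allowed from this cell for the
                # agent's current heading)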
                num_transitions = np.count_nonzero(possible_transitions)

                movement = agent.direction
                # print(nbits,np.sum(possible_transitions))
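                # Directions are encoded 0..3 (N, E, S, W). Action 1 turns
                # the agent left and action 3 turns it right; a turn is
                # flagged invalid when the current cell offers at most one
                # outgoing transition (straight, curved or dead-end cells).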
                if action == 1:
                    movement = agent.direction - 1
                    if num_transitions <= 1:
                        transition_isValid = False

                elif action == 3:
                    movement = agent.direction + 1
                    if num_transitions <= 1:
                        transition_isValid = False

                movement %= 4

                if action == 2:
                    if num_transitions == 1:
                        # - dead-end, straight line or curved line;
                        # movement will be the only valid transition
                        # - take only available transition
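                        # (argmax over the 4-entry transition tuple returns
                        # the index, i.e. the direction, of that single move)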
                        movement = np.argmax(possible_transitions)
                        transition_isValid = True

                new_position = get_new_position(agent.position, movement)
                # Is it a legal move?
                # 1) transition allows the movement in the cell,
                # 2) the new cell is not empty (case 0),
                # 3) the cell is free, i.e., no agent is currently in that cell
                
                # if (
                #        new_position[1] >= self.width or
                #        new_position[0] >= self.height or
                #        new_position[0] < 0 or new_position[1] < 0):
                #    new_cell_isValid = False

                # if self.rail.get_transitions(new_position) == 0:
                #     new_cell_isValid = False

                new_cell_isValid = (
                        np.array_equal(  # Check the new position is still in the grid
                            new_position,
                            np.clip(new_position, [0, 0], [self.height-1, self.width-1]))
                        and  # check the new position has some transitions (ie is not an empty cell)
                        self.rail.get_transitions(new_position) > 0)

                # If transition validity hasn't been checked yet.
                if transition_isValid is None:
                    transition_isValid = self.rail.get_transition(
                        (*agent.position, agent.direction),
                        movement)

                # cell_isFree = True
                # for j in range(self.number_of_agents):
                #    if self.agents_position[j] == new_position:
                #        cell_isFree = False
                #        break
                # Check the new position is not the same as any of the existing agent positions
                # (including itself, for simplicity, since it is moving)
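                # np.equal broadcasts new_position against the (n_agents, 2)
                # array of agent positions; .all(1) marks the rows where both
                # coordinates match, i.e. agents occupying the target cell.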
                cell_isFree = not np.any(
                        np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1))

                if all([new_cell_isValid, transition_isValid, cell_isFree]):
                    # move and change direction to face the movement that was
                    # performed
                    # self.agents_position[i] = new_position
                    # self.agents_direction[i] = movement
                    agent.position = new_position
                    agent.direction = movement
                else:
                    # the action was not valid, add penalty
                    self.rewards_dict[handle] += invalid_action_penalty

            # if agent is not in target position, add step penalty
            # if self.agents_position[i][0] == self.agents_target[i][0] and \
            #        self.agents_position[i][1] == self.agents_target[i][1]:
            #    self.dones[handle] = True
            if np.equal(agent.position, agent.target).all():
                self.dones[handle] = True
            else:
                self.rewards_dict[handle] += step_penalty

        # Check for end of episode + add global reward to all rewards!
        # num_agents_in_target_position = 0
        # for i in range(self.number_of_agents):
        #    if self.agents_position[i][0] == self.agents_target[i][0] and \
        #            self.agents_position[i][1] == self.agents_target[i][1]:
        #        num_agents_in_target_position += 1
        # if num_agents_in_target_position == self.number_of_agents:
        if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
            self.dones["__all__"] = True
            # add the global reward to every agent's reward for this step
            # (rewards_dict is a dict keyed by agent handle)
            for handle in self.rewards_dict:
                self.rewards_dict[handle] += global_reward

        # Reset the step actions (in case some agent doesn't 'register_action'
        # on the next step)
        self.actions = [0] * self.number_of_agents
        return self._get_observations(), self.rewards_dict, self.dones, {}

    def _get_observations(self):
        self.obs_dict = {}
        for handle in self.agents_handles:
            self.obs_dict[handle] = self.obs_builder.get(handle)
        return self.obs_dict

    def render(self):
        # TODO:
        pass
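

# Minimal usage sketch (based on the defaults in this file; call signatures
# may differ in other versions of flatland):
#
#     env = RailEnv(width=20, height=20, number_of_agents=3)
#     obs = env.reset()
#     # action 2 = move forward; step() returns (obs, rewards, dones, info)
#     obs, rewards, dones, _ = env.step({h: 2 for h in env.get_agent_handles()})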