import os
import random
import sys
from argparse import ArgumentParser, Namespace
from collections import deque
from datetime import datetime
from pathlib import Path
from pprint import pprint
from typing import Optional, List

import numpy as np
import psutil
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.envs.malfunction_generators import MalfunctionParameters, ParamMalfunctionGen
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from torch.utils.tensorboard import SummaryWriter

from reinforcement_learning.dddqn_policy import DDDQNPolicy
from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent
from reinforcement_learning.decision_point_agent import DecisionPointAgent
from reinforcement_learning.multi_decision_agent import MultiDecisionAgent
from reinforcement_learning.multi_policy import MultiPolicy
from reinforcement_learning.ppo_agent import FLATLandPPOPolicy
from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action, \
    set_action_size_reduced, set_action_size_full, convert_default_rail_env_action
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
from utils.deadlock_check import find_and_punish_deadlock
from utils.flatland_observation import FlatlandObservation, FlatlandTreeObservation, FlatlandFastTreeObservation
from utils.shortest_path_walking_agent import ShortestPathWalkingAgent

base_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(base_dir))

from utils.timer import Timer

try:
    import wandb

    wandb.init(sync_tensorboard=True)
except ImportError:
    print("Install wandb to log to Weights & Biases")

"""
This file shows how to train multiple agents using a reinforcement learning approach.
After training an agent, you can submit it straight away to the NeurIPS 2020 Flatland challenge!

Agent documentation: https://flatland.aicrowd.com/getting-started/rl/multi-agent.html
Submission documentation: https://flatland.aicrowd.com/getting-started/first-submission.html
"""

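# Example invocation (a sketch, not an official recipe: the flags are the ones defined by this
# script's ArgumentParser below, and the module path is assumed from the package imports above):
#
#   python reinforcement_learning/multi_agent_training.py \
#       --policy PPO --use_observation FlatlandObs \
#       -n 5000 -t 1 -e 3 --checkpoint_interval 200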

def create_rail_env(env_params, tree_observation):
    n_agents = env_params.n_agents
    x_dim = env_params.x_dim
    y_dim = env_params.y_dim
    n_cities = env_params.n_cities
    max_rails_between_cities = env_params.max_rails_between_cities
    max_rails_in_city = env_params.max_rails_in_city
    seed = env_params.seed

    # Break agents from time to time
    malfunction_parameters = MalfunctionParameters(
        malfunction_rate=env_params.malfunction_rate,
        min_duration=20,
        max_duration=50
    )

    return RailEnv(
        width=x_dim, height=y_dim,
        rail_generator=sparse_rail_generator(
            max_num_cities=n_cities,
            grid_mode=False,
            max_rails_between_cities=max_rails_between_cities,
            max_rails_in_city=max_rails_in_city
        ),
        schedule_generator=sparse_schedule_generator(),
        number_of_agents=n_agents,
        malfunction_generator=ParamMalfunctionGen(malfunction_parameters),
        obs_builder_object=tree_observation,
        random_seed=seed
    )


def train_agent(train_params, train_env_params, eval_env_params, obs_params):
    # Environment parameters
    n_agents = train_env_params.n_agents
    x_dim = train_env_params.x_dim
    y_dim = train_env_params.y_dim
    n_cities = train_env_params.n_cities
    max_rails_between_cities = train_env_params.max_rails_between_cities
    max_rails_in_city = train_env_params.max_rails_in_city
    seed = train_env_params.seed

    # Unique ID for this training
    now = datetime.now()
    training_id = now.strftime('%y%m%d%H%M%S')

    # Observation parameters
    observation_tree_depth = obs_params.observation_tree_depth
    observation_radius = obs_params.observation_radius
    observation_max_path_depth = obs_params.observation_max_path_depth

    # Training parameters
    eps_start = train_params.eps_start
    eps_end = train_params.eps_end
    eps_decay = train_params.eps_decay
    n_episodes = train_params.n_episodes
    checkpoint_interval = train_params.checkpoint_interval
    n_eval_episodes = train_params.n_evaluation_episodes
    restore_replay_buffer = train_params.restore_replay_buffer
    save_replay_buffer = train_params.save_replay_buffer
    skip_unfinished_agent = train_params.skip_unfinished_agent

    # Set the seeds
    random.seed(seed)
    np.random.seed(seed)

    # Observation builder
    print("------------------------------------- CREATE OBSERVATION -----------------------------------")
    if train_params.use_observation == 'TreeObs':
        print("Using FlatlandTreeObservation (standard)")
        tree_observation = FlatlandTreeObservation(max_depth=observation_tree_depth)
    elif train_params.use_observation == 'FastTreeObs':
        print("Using FlatlandFastTreeObservation")
        tree_observation = FlatlandFastTreeObservation()
    else:  # train_params.use_observation == 'FlatlandObs':
        print("Using FlatlandObservation")
        tree_observation = FlatlandObservation(max_depth=observation_tree_depth)
    # Get the state size
    state_size = tree_observation.observation_dim

    if train_params.policy == "DeadLockAvoidance":
        print("Using SimpleObservationBuilder")

        class SimpleObservationBuilder(ObservationBuilder):
            """
            DummyObservationBuilder class which returns dummy observations.
            This is used in the evaluation service.
            """

            def __init__(self):
                super().__init__()

            def reset(self):
                pass

            def get_many(self, handles: Optional[List[int]] = None):
                return super().get_many(handles)

            def get(self, handle: int = 0):
                return [handle]

        tree_observation = SimpleObservationBuilder()
        tree_observation.observation_dim = 1

    # Setup the environments
    train_env = create_rail_env(train_env_params, tree_observation)
    train_env.reset(regenerate_schedule=True, regenerate_rail=True)
    eval_env = create_rail_env(eval_env_params, tree_observation)
    eval_env.reset(regenerate_schedule=True, regenerate_rail=True)

    action_count = [0] * get_flatland_full_action_size()
    action_dict = dict()

    # Smoothed values used as target for hyperparameter tuning
    smoothed_eval_normalized_score = -1.0
    smoothed_eval_completion = 0.0

    scores_window = deque(maxlen=checkpoint_interval)  # todo smooth when rendering instead
    completion_window = deque(maxlen=checkpoint_interval)
    deadlocked_window = deque(maxlen=checkpoint_interval)

    if train_params.action_size == "reduced":
        set_action_size_reduced()
    else:
        set_action_size_full()

    print("---------------------------------------- CREATE AGENT --------------------------------------")
    print('Using', train_params.policy)
    if train_params.policy == "DDDQN":
        policy = DDDQNPolicy(state_size, get_action_size(), train_params,
                             enable_delayed_transition_push_at_episode_end=False,
                             skip_unfinished_agent=skip_unfinished_agent)
    elif train_params.policy == "ShortestPathWalkingAgent":
        policy = ShortestPathWalkingAgent(train_env)
    elif train_params.policy == "PPO":
        policy = FLATLandPPOPolicy(state_size, get_action_size(),
                                   use_replay_buffer=train_params.buffer_size > 0,
                                   enable_replay_curiosity_sampling=False,
                                   in_parameters=train_params,
                                   skip_unfinished_agent=skip_unfinished_agent,
                                   K_epoch=train_params.K_epoch)
    elif train_params.policy == "PPORCS":
        policy = FLATLandPPOPolicy(state_size, get_action_size(),
                                   use_replay_buffer=train_params.buffer_size > 0,
                                   enable_replay_curiosity_sampling=True,
                                   in_parameters=train_params,
                                   skip_unfinished_agent=skip_unfinished_agent,
                                   K_epoch=train_params.K_epoch)
    elif train_params.policy == "DeadLockAvoidance":
        policy = DeadLockAvoidanceAgent(train_env, get_action_size(), enable_eps=False)
    elif train_params.policy == "DeadLockAvoidanceWithDecisionAgent":
        policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(),
                                                    in_parameters=train_params)
    elif train_params.policy == "DecisionPointAgent":
        inter_policy = FLATLandPPOPolicy(state_size, get_action_size(),
                                         use_replay_buffer=train_params.buffer_size > 0,
                                         enable_replay_curiosity_sampling=True,
                                         in_parameters=train_params,
                                         skip_unfinished_agent=skip_unfinished_agent,
                                         K_epoch=train_params.K_epoch)
        policy = DecisionPointAgent(train_env, state_size, get_action_size(), inter_policy)
    elif train_params.policy == "DecisionPointAgent_DDDQN":
        inter_policy = DDDQNPolicy(state_size, get_action_size(), train_params,
                                   enable_delayed_transition_push_at_episode_end=False,
                                   skip_unfinished_agent=skip_unfinished_agent)
        policy = DecisionPointAgent(train_env, state_size, get_action_size(), inter_policy)
    elif train_params.policy == "MultiDecisionAgent":
        policy = MultiDecisionAgent(train_env, state_size, get_action_size(), train_params)
    elif train_params.policy == "MultiPolicy":
        policy = MultiPolicy(state_size, get_action_size())
    else:
        policy = FLATLandPPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)

    # make sure that at least one policy is set
    if policy is None:
        policy = DDDQNPolicy(state_size, get_action_size(), train_params)

    # Load existing policy
    if train_params.load_policy != "":
        policy.load(train_params.load_policy)

    # Loads existing replay buffer
    if restore_replay_buffer:
        try:
            policy.load_replay_buffer(restore_replay_buffer)
            policy.test()
        except RuntimeError as e:
            print("\n🛑 Couldn't load replay buffer, were the experiences generated using the same tree depth?")
            print(e)
            exit(1)

    print("\n💾 Replay buffer status: {}/{} experiences".format(len(policy.memory.memory), train_params.buffer_size))

    hdd = psutil.disk_usage('/')
    if save_replay_buffer and (hdd.free / (2 ** 30)) < 500.0:
        print("⚠️  Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f}gb left."
              .format(hdd.free / (2 ** 30)))

    # TensorBoard writer
    writer = SummaryWriter(comment="_" +
                                   train_params.policy + "_" +
                                   train_params.use_observation + "_" +
                                   train_params.action_size)

    training_timer = Timer()
    training_timer.start()

    print(
        "\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating {} trains on {} episodes every {} episodes. "
        "Training id '{}'.\n".format(
            train_env.get_num_agents(),
            x_dim, y_dim,
            n_episodes,
            eval_env.get_num_agents(),
            n_eval_episodes,
            checkpoint_interval,
            training_id
        ))

    for episode_idx in range(n_episodes + 1):
        reset_timer = Timer()
        policy_start_episode_timer = Timer()
        policy_start_step_timer = Timer()
        policy_act_timer = Timer()
        env_step_timer = Timer()
        policy_shape_reward_timer = Timer()
        policy_step_timer = Timer()
        policy_end_step_timer = Timer()
        policy_end_episode_timer = Timer()
        total_episode_timer = Timer()

        total_episode_timer.start()

        action_count = [0] * get_flatland_full_action_size()
        agent_prev_obs = [None] * n_agents
        agent_prev_action = [convert_default_rail_env_action(RailEnvActions.STOP_MOVING)] * n_agents
        update_values = [False] * n_agents

        # Reset environment
        reset_timer.start()
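        # Agent-count curriculum (as implemented below): unless --n_agent_fixed is set, the number
        # of trains grows by one every 200 episodes, capped at the configured n_agents; with
        # --n_agent_iterate the count additionally cycles through 1..current maximum each episode.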
        if train_params.n_agent_fixed:
            number_of_agents = n_agents
        else:
            number_of_agents = int(min(n_agents, 1 + np.floor(max(0, episode_idx - 1) / 200)))
        if train_params.n_agent_iterate:
            train_env_params.n_agents = episode_idx % number_of_agents + 1
        else:
            train_env_params.n_agents = number_of_agents

        train_env = create_rail_env(train_env_params, tree_observation)
        agent_obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
        policy.reset(train_env)
        reset_timer.end()

        if train_params.render:
            # Setup renderer
            env_renderer = RenderTool(train_env, gl="PGL",
                                      show_debug=True,
                                      agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS)

            env_renderer.set_new_rail()

        score = 0
        nb_steps = 0
        actions_taken = []

        # Build initial agent-specific observations
        for agent_handle in train_env.get_agent_handles():
            agent_prev_obs[agent_handle] = agent_obs[agent_handle].copy()

        # Max number of steps per episode
        # This is the official formula used during evaluations
        # See details in flatland.envs.schedule_generators.sparse_schedule_generator
        # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
        max_steps = train_env._max_episode_steps

        # Run episode
        policy_start_episode_timer.start()
        policy.start_episode(train=True)
        policy_start_episode_timer.end()
        for step in range(max_steps - 1):
            # policy.start_step ---------------------------------------------------------------------------------------
            policy_start_step_timer.start()
            policy.start_step(train=True)
            policy_start_step_timer.end()

            # policy.act ----------------------------------------------------------------------------------------------
            policy_act_timer.start()
            action_dict = {}
            for agent_handle in policy.get_agent_handles(train_env):
                if info['action_required'][agent_handle]:
                    update_values[agent_handle] = True
                    action = policy.act(agent_handle, agent_obs[agent_handle], eps=eps_start)
                    action_count[map_action(action)] += 1
                    actions_taken.append(map_action(action))
                else:
                    # An action is not required if the train hasn't joined the railway network,
                    # if it already reached its target, or if it is currently malfunctioning.
                    update_values[agent_handle] = False
                    action = convert_default_rail_env_action(RailEnvActions.DO_NOTHING)

                action_dict.update({agent_handle: action})
            policy_act_timer.end()

            # policy.end_step -----------------------------------------------------------------------------------------
            policy_end_step_timer.start()
            policy.end_step(train=True)
            policy_end_step_timer.end()

            # Environment step ----------------------------------------------------------------------------------------
            env_step_timer.start()
            next_obs, all_rewards, dones, info = train_env.step(map_actions(action_dict))
            env_step_timer.end()

            # policy.shape_reward -------------------------------------------------------------------------------------
            policy_shape_reward_timer.start()
            # Deadlock
            deadlocked_agents, all_rewards = find_and_punish_deadlock(train_env, all_rewards, -10.0)

            # The reward might require a policy-based transformation
            for agent_handle in train_env.get_agent_handles():
                all_rewards[agent_handle] = policy.shape_reward(agent_handle,
                                                                action_dict[agent_handle],
                                                                agent_obs[agent_handle],
                                                                all_rewards[agent_handle],
                                                                dones[agent_handle],
                                                                deadlocked_agents[agent_handle])

            policy_shape_reward_timer.end()

            # Render an episode at some interval
            if train_params.render:
                env_renderer.render_env(
                    show=True,
                    frames=False,
                    show_observations=True,
                    show_predictions=False
                )

            # Update replay buffer and train agent
            for agent_handle in train_env.get_agent_handles():
                if update_values[agent_handle] or dones['__all__'] or deadlocked_agents[agent_handle]:
                    # Only learn from timesteps where something happened
                    policy_step_timer.start()
                    policy.step(agent_handle,
                                agent_prev_obs[agent_handle],
                                agent_prev_action[agent_handle],
                                all_rewards[agent_handle],
                                agent_obs[agent_handle],
                                dones[agent_handle] or (deadlocked_agents[agent_handle] > 0))
                    policy_step_timer.end()

                    agent_prev_obs[agent_handle] = agent_obs[agent_handle].copy()
                    agent_prev_action[agent_handle] = action_dict[agent_handle]

                score += all_rewards[agent_handle]

                # update_observation (step)
                agent_obs[agent_handle] = next_obs[agent_handle].copy()

            nb_steps = step

            if dones['__all__']:
                break

            if deadlocked_agents['__all__']:
                if train_params.render_deadlocked is not None:
                    # Setup renderer
                    env_renderer = RenderTool(train_env,
                                              gl="PGL",
                                              show_debug=True,
                                              agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS,
                                              screen_width=2000,
                                              screen_height=1200)

                    env_renderer.set_new_rail()
                    env_renderer.render_env(
                        show=False,
                        frames=True,
                        show_observations=False,
                        show_predictions=False
                    )
                    env_renderer.gl.save_image("{}/flatland_{:04d}.png".format(
                        train_params.render_deadlocked,
                        episode_idx))
                    break

        # policy.end_episode
        policy_end_episode_timer.start()
        policy.end_episode(train=True)
        policy_end_episode_timer.end()

        # Epsilon decay
        eps_start = max(eps_end, eps_decay * eps_start)

        total_episode_timer.end()

        # Collect information about training
        tasks_finished = sum(dones[idx] for idx in train_env.get_agent_handles())
        tasks_deadlocked = sum(deadlocked_agents[idx] for idx in train_env.get_agent_handles())
        completion = tasks_finished / max(1, train_env.get_num_agents())
        deadlocked = tasks_deadlocked / max(1, train_env.get_num_agents())
        normalized_score = score / max(1, train_env.get_num_agents())
        action_probs = action_count / max(1, np.sum(action_count))

        scores_window.append(normalized_score)
        completion_window.append(completion)
        deadlocked_window.append(deadlocked)
        smoothed_normalized_score = np.mean(scores_window)
        smoothed_completion = np.mean(completion_window)
        smoothed_deadlocked = np.mean(deadlocked_window)

        if train_params.render:
            env_renderer.close_window()

        # Print logs
        if episode_idx % checkpoint_interval == 0 and episode_idx > 0:
            policy.save('./checkpoints/' + training_id + '-' + str(episode_idx) + '.pth')

            if save_replay_buffer:
                policy.save_replay_buffer('./replay_buffers/' + training_id + '-' + str(episode_idx) + '.pkl')

            # reset action count
            action_count = [0] * get_flatland_full_action_size()

        print(
            '\r🚂 Episode {}'
            '\t 🚉 nAgents {:2}/{:2}'
            ' 🏆 Score: {:7.3f}'
            ' Avg: {:7.3f}'
            '\t 💯 Done: {:6.2f}%'
            ' Avg: {:6.2f}%'
            '\t 🎲 Epsilon: {:.3f} '
            '\t 🔀 Action Probs: {}'.format(
                episode_idx,
                train_env_params.n_agents, number_of_agents,
                normalized_score,
                smoothed_normalized_score,
                100 * completion,
                100 * smoothed_completion,
                eps_start,
                format_action_prob(action_probs)
            ), end=" ")

        # Evaluate policy and log results at some interval
        if episode_idx % checkpoint_interval == 0 and n_eval_episodes > 0 and episode_idx > 0:
            scores, completions, nb_steps_eval = eval_policy(eval_env,
                                                             tree_observation,
                                                             policy,
                                                             train_params,
                                                             obs_params)

            writer.add_scalar("evaluation/scores_min", np.min(scores), episode_idx)
            writer.add_scalar("evaluation/scores_max", np.max(scores), episode_idx)
            writer.add_scalar("evaluation/scores_mean", np.mean(scores), episode_idx)
            writer.add_scalar("evaluation/scores_std", np.std(scores), episode_idx)
            writer.add_histogram("evaluation/scores", np.array(scores), episode_idx)
            writer.add_scalar("evaluation/completions_min", np.min(completions), episode_idx)
            writer.add_scalar("evaluation/completions_max", np.max(completions), episode_idx)
            writer.add_scalar("evaluation/completions_mean", np.mean(completions), episode_idx)
            writer.add_scalar("evaluation/completions_std", np.std(completions), episode_idx)
            writer.add_histogram("evaluation/completions", np.array(completions), episode_idx)
            writer.add_scalar("evaluation/nb_steps_min", np.min(nb_steps_eval), episode_idx)
            writer.add_scalar("evaluation/nb_steps_max", np.max(nb_steps_eval), episode_idx)
            writer.add_scalar("evaluation/nb_steps_mean", np.mean(nb_steps_eval), episode_idx)
            writer.add_scalar("evaluation/nb_steps_std", np.std(nb_steps_eval), episode_idx)
            writer.add_histogram("evaluation/nb_steps", np.array(nb_steps_eval), episode_idx)

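            # Exponential moving average (factor 0.9) of the evaluation metrics; these smoothed
            # values are the target signal for hyperparameter tuning (see comment above).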
            smoothing = 0.9
            smoothed_eval_normalized_score = smoothed_eval_normalized_score * smoothing + np.mean(scores) * (
                    1.0 - smoothing)
            smoothed_eval_completion = smoothed_eval_completion * smoothing + np.mean(completions) * (1.0 - smoothing)
            writer.add_scalar("evaluation/smoothed_score", smoothed_eval_normalized_score, episode_idx)
            writer.add_scalar("evaluation/smoothed_completion", smoothed_eval_completion, episode_idx)

        if episode_idx > 49:
            # Save logs to tensorboard
            writer.add_scalar("scene_done_training/completion_{}".format(train_env_params.n_agents),
                              np.mean(completion), episode_idx)
            writer.add_scalar("scene_dead_training/deadlocked_{}".format(train_env_params.n_agents),
                              np.mean(deadlocked), episode_idx)

            writer.add_scalar("training/score", normalized_score, episode_idx)
            writer.add_scalar("training/smoothed_score", smoothed_normalized_score, episode_idx)
            writer.add_scalar("training/completion", np.mean(completion), episode_idx)
            writer.add_scalar("training/deadlocked", np.mean(deadlocked), episode_idx)
            writer.add_scalar("training/smoothed_completion", np.mean(smoothed_completion), episode_idx)
            writer.add_scalar("training/smoothed_deadlocked", np.mean(smoothed_deadlocked), episode_idx)
            writer.add_scalar("training/nb_steps", nb_steps, episode_idx)
            writer.add_scalar("training/n_agents", train_env_params.n_agents, episode_idx)
            writer.add_histogram("actions/distribution", np.array(actions_taken), episode_idx)
            writer.add_scalar("actions/nothing", action_probs[RailEnvActions.DO_NOTHING], episode_idx)
            writer.add_scalar("actions/left", action_probs[RailEnvActions.MOVE_LEFT], episode_idx)
            writer.add_scalar("actions/forward", action_probs[RailEnvActions.MOVE_FORWARD], episode_idx)
            writer.add_scalar("actions/right", action_probs[RailEnvActions.MOVE_RIGHT], episode_idx)
            writer.add_scalar("actions/stop", action_probs[RailEnvActions.STOP_MOVING], episode_idx)
            writer.add_scalar("training/epsilon", eps_start, episode_idx)
            writer.add_scalar("training/buffer_size", len(policy.memory), episode_idx)
            writer.add_scalar("training/loss", policy.loss, episode_idx)

            writer.add_scalar("timer/00_reset", reset_timer.get(), episode_idx)
            writer.add_scalar("timer/01_policy_start_episode", policy_start_episode_timer.get(), episode_idx)
            writer.add_scalar("timer/02_policy_start_step", policy_start_step_timer.get(), episode_idx)
            writer.add_scalar("timer/03_policy_act", policy_act_timer.get(), episode_idx)
            writer.add_scalar("timer/04_env_step", env_step_timer.get(), episode_idx)
            writer.add_scalar("timer/05_policy_shape_reward", policy_shape_reward_timer.get(), episode_idx)
            writer.add_scalar("timer/06_policy_step", policy_step_timer.get(), episode_idx)
            writer.add_scalar("timer/07_policy_end_step", policy_end_step_timer.get(), episode_idx)
            writer.add_scalar("timer/08_policy_end_episode", policy_end_episode_timer.get(), episode_idx)
            writer.add_scalar("timer/09_total_episode", total_episode_timer.get_current(), episode_idx)
            writer.add_scalar("timer/10_total", training_timer.get_current(), episode_idx)

        writer.flush()


def format_action_prob(action_probs):
    action_probs = np.round(action_probs, 3)
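    # Symbols follow the RailEnvActions ordering: DO_NOTHING, MOVE_LEFT, MOVE_FORWARD, MOVE_RIGHT, STOP_MOVING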
    actions = ["↻", "←", "↑", "→", "◼"]

    buffer = ""
    for action, action_prob in zip(actions, action_probs):
        buffer += action + " " + "{:.3f}".format(action_prob) + " "

    return buffer


def eval_policy(env, tree_observation, policy, train_params, obs_params):
    n_eval_episodes = train_params.n_evaluation_episodes
    max_steps = env._max_episode_steps

    action_dict = dict()
    scores = []
    completions = []
    nb_steps = []

    for episode_idx in range(n_eval_episodes):
        score = 0.0

        agent_obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
        policy.reset(env)
        final_step = 0

        if train_params.eval_render:
            # Setup renderer
            env_renderer = RenderTool(env, gl="PGL",
                                      show_debug=True,
                                      agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS)
            env_renderer.set_new_rail()

        policy.start_episode(train=False)
        for step in range(max_steps - 1):
            policy.start_step(train=False)
            for agent in env.get_agent_handles():
                action = convert_default_rail_env_action(RailEnvActions.DO_NOTHING)
                if info['action_required'][agent]:
                    action = policy.act(agent, agent_obs[agent], eps=0.0)
                action_dict.update({agent: action})
            policy.end_step(train=False)
            agent_obs, all_rewards, done, info = env.step(map_actions(action_dict))

            for agent in env.get_agent_handles():
                score += all_rewards[agent]

            final_step = step

            if done['__all__']:
                break

            # Render an episode at some interval
            if train_params.eval_render:
                env_renderer.render_env(
                    show=True,
                    frames=False,
                    show_observations=True,
                    show_predictions=False
                )

        policy.end_episode(train=False)
        normalized_score = score / (max_steps * env.get_num_agents())
        scores.append(normalized_score)

        tasks_finished = sum(done[idx] for idx in env.get_agent_handles())
        completion = tasks_finished / max(1, env.get_num_agents())
        completions.append(completion)

        nb_steps.append(final_step)

        if train_params.eval_render:
            env_renderer.close_window()

    print(" ✅ Eval: score {:.3f} done {:.1f}%".format(np.mean(scores), np.mean(completions) * 100.0))

    return scores, completions, nb_steps


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5000, type=int)
    parser.add_argument("--n_agent_fixed", help="hold the number of agents fixed", action='store_true')
    parser.add_argument("--n_agent_iterate", help="cycle through the number of agents each episode", action='store_true')
    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
                        type=int)
    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=3,
                        type=int)
    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=10, type=int)
    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=200, type=int)
    parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
    parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
    parser.add_argument("--eps_decay", help="exploration decay", default=0.99975, type=float)
    parser.add_argument("--buffer_size", help="replay buffer size", default=int(32_000), type=int)
    parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
    parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
    parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False,
                        type=bool)
    parser.add_argument("--batch_size", help="minibatch size", default=1024, type=int)
    parser.add_argument("--gamma", help="discount factor", default=0.99, type=float)
    parser.add_argument("--tau", help="soft update of target parameters", default=0.5e-3, type=float)
    parser.add_argument("--learning_rate", help="learning rate", default=0.5e-4, type=float)
    parser.add_argument("--hidden_size", help="hidden size (2 fc layers)", default=128, type=int)
    parser.add_argument("--update_every", help="how often to update the network", default=200, type=int)
    parser.add_argument("--use_gpu", help="use GPU if available", default=False, type=bool)
    parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=4, type=int)

    parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
    parser.add_argument("--use_observation", help="observation name [TreeObs, FastTreeObs, FlatlandObs]",
                        default='FlatlandObs')
    parser.add_argument("--max_depth", help="max depth", default=2, type=int)
    parser.add_argument("--K_epoch", help="K_epoch", default=10, type=int)
    parser.add_argument("--skip_unfinished_agent", default=9999.0, type=float)
    parser.add_argument("--render", help="render while training", action='store_true')
    parser.add_argument("--eval_render", help="render evaluation", action='store_true')
    parser.add_argument("--render_deadlocked", default=None, type=str)
    parser.add_argument("--policy",
                        help="policy name [DDDQN, PPO, PPORCS, DecisionPointAgent, DecisionPointAgent_DDDQN,"
                             "DeadLockAvoidance, DeadLockAvoidanceWithDecisionAgent, MultiDecisionAgent, MultiPolicy]",
                        default="PPO")
    parser.add_argument("--action_size", help="define the action size [reduced,full]", default="full", type=str)

    training_params = parser.parse_args()
    env_params = [
        {
            # Test_0
            "n_agents": 1,
            "x_dim": 25,
            "y_dim": 25,
            "n_cities": 2,
            "max_rails_between_cities": 2,
            "max_rails_in_city": 3,
            "malfunction_rate": 1 / 50,
            "seed": 0
        },
        {
            # Test_1
            "n_agents": 2,
            "x_dim": 25,
            "y_dim": 25,
            "n_cities": 2,
            "max_rails_between_cities": 2,
            "max_rails_in_city": 3,
            "malfunction_rate": 1 / 50,
            "seed": 0
        },
        {
            # Test_2
            "n_agents": 5,
            "x_dim": 30,
            "y_dim": 30,
            "n_cities": 2,
            "max_rails_between_cities": 2,
            "max_rails_in_city": 3,
            "malfunction_rate": 0,
            "seed": 0
        },
        {
            # Test_3
            "n_agents": 10,
            "x_dim": 35,
            "y_dim": 35,
            "n_cities": 3,
            "max_rails_between_cities": 2,
            "max_rails_in_city": 3,
            "malfunction_rate": 1 / 200,
            "seed": 0
        },
        {
            # Test_4
            "n_agents": 20,
            "x_dim": 40,
            "y_dim": 40,
            "n_cities": 5,
            "max_rails_between_cities": 2,
            "max_rails_in_city": 3,
            "malfunction_rate": 1 / 200,
            "seed": 0
        },
    ]

obs_params = {
    "observation_tree_depth": training_params.max_depth,
    "observation_radius": 10,
    "observation_max_path_depth": 30
}


def check_env_config(id):
    if id >= len(env_params) or id < 0:
        print("\n🛑 Invalid environment configuration, only Test_0 to Test_{} are supported.".format(
            len(env_params) - 1))
        exit(1)


check_env_config(training_params.training_env_config)
check_env_config(training_params.evaluation_env_config)

training_env_params = env_params[training_params.training_env_config]
evaluation_env_params = env_params[training_params.evaluation_env_config]

# FIXME hard-coded for sweep search
# see https://wb-forum.slack.com/archives/CL4V2QE59/p1602931982236600 to implement properly
# training_params.use_fast_tree_observation = True

print("\nTraining parameters:")
pprint(vars(training_params))
print("\nTraining environment parameters (Test_{}):".format(training_params.training_env_config))
pprint(training_env_params)
print("\nEvaluation environment parameters (Test_{}):".format(training_params.evaluation_env_config))
pprint(evaluation_env_params)
print("\nObservation parameters:")
pprint(obs_params)

os.environ["OMP_NUM_THREADS"] = str(training_params.num_threads)
train_agent(training_params, Namespace(**training_env_params), Namespace(**evaluation_env_params),
            Namespace(**obs_params))