training_example.py

import numpy as np

from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.rail_env import RailEnv

np.random.seed(1)

# Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment
#

TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv())
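# complex_rail_generator arguments (roughly): nr_start_goal is the number of start/goal
# connections to generate, nr_extra adds extra rail connections, and min_dist/max_dist
# bound the distance between a start and its goal; see the Flatland docs for exact semantics.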
env = RailEnv(width=20,
              height=20,
              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
              obs_builder_object=TreeObservation,
              number_of_agents=2)

# Import your own Agent or use RLlib to train agents on Flatland
# As an example we use a random agent here


class RandomAgent:

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

    def act(self, state):
        """
        :param state: input is the observation of the agent
        :return: returns an action
        """
        return np.random.choice(np.arange(self.action_size))

    def step(self, memories):
        """
        Step function to improve agent by adjusting policy given the observations

        :param memories: SARS Tuple to be
        :return:
        """
        return

    def save(self, filename):
        # Store the current policy
        return

    def load(self, filename):
        # Load a policy
        return


# Initialize the agent with the parameters corresponding to the environment and observation_builder
agent = RandomAgent(218, 4)
n_trials = 5

# Empty dictionary for all agent actions
action_dict = dict()
print("Starting Training...")

for trials in range(1, n_trials + 1):

    # Reset environment and get initial observations for all agents
    obs = env.reset()
    # Here you can also further enhance the provided observation by means of normalization
    # See training navigation example in the baseline repository
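    # A minimal normalization sketch (hypothetical helper, not part of Flatland itself),
    # e.g. clip the raw tree features to a finite range and rescale, defined above the loop:
    #
    #   def normalize(observation, clip_max=100.0):
    #       flat = np.clip(np.array(observation, dtype=float), -clip_max, clip_max)
    #       return flat / clip_max
    #
    #   obs = {a: normalize(obs[a]) for a in range(env.get_num_agents())}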

    score = 0
    # Run episode
    for step in range(100):
        # Choose an action for each agent in the environment
        for a in range(env.get_num_agents()):
            action = agent.act(obs[a])
            action_dict.update({a: action})

        # Environment step which returns the observations for all agents, their corresponding
        # rewards and whether they are done
        next_obs, all_rewards, done, _ = env.step(action_dict)
        TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
        print(len(next_obs[0]))
        # Update replay buffer and train agent
        for a in range(env.get_num_agents()):
            agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
            score += all_rewards[a]

        obs = next_obs.copy()
        if done['__all__']:
            break
    print('Episode Nr. {}\t Score = {}'.format(trials, score))