flatland_2_0_example.py
import numpy as np

from flatland.envs.generators import sparse_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool

np.random.seed(1)

# Use the new sparse_rail_generator to generate feasible network configurations with corresponding tasks
# Training on simple small tasks is the best way to get familiar with the environment

# Use the malfunction generator to break agents from time to time
stochastic_data = {'prop_malfunction': 0.5,  # Fraction of agents that can malfunction
                   'malfunction_rate': 30,  # Rate of malfunction occurrence
                   'min_duration': 3,  # Minimum duration of a malfunction
                   'max_duration': 10  # Maximum duration of a malfunction
                   }

TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())

env = RailEnv(width=20,
              height=20,
              rail_generator=sparse_rail_generator(num_cities=2,  # Number of cities on the map (where train stations are)
                                                   num_intersections=1,  # Number of intersections (no start / target)
                                                   num_trainstations=15,  # Number of possible start/target locations on the map
                                                   min_node_dist=3,  # Minimum distance between nodes
                                                   node_radius=3,  # Proximity of stations to the city center
                                                   num_neighb=2,  # Number of connections to other cities/intersections
                                                   seed=15,  # Random seed
                                                   realistic_mode=True,
                                                   enhance_intersection=True
                                                   ),
              schedule_generator=sparse_schedule_generator(),
              number_of_agents=5,
              stochastic_data=stochastic_data,  # Malfunction data generator
              obs_builder_object=TreeObservation)

env_renderer = RenderTool(env, gl="PILSVG")


# Import your own Agent or use RLlib to train agents on Flatland
# As an example we use a random agent instead
class RandomAgent:

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

    def act(self, state):
        """
        :param state: observation of the agent
        :return: an action (chosen uniformly at random)
        """
        return np.random.choice(np.arange(self.action_size))

    def step(self, memories):
        """
        Step function to improve the agent by adjusting its policy given the observations

        :param memories: SARS tuple (state, action, reward, next_state, done) to learn from
        :return:
        """
        return

    def save(self, filename):
        # Store the current policy
        return

    def load(self, filename):
        # Load a policy
        return
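
# Any agent exposing the same act / step / save / load interface could be plugged into
# the episode loop below (e.g. a policy trained with RLlib or the Flatland baselines);
# the RandomAgent above is only a stand-in.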


# Initialize the agent with the parameters corresponding to the environment and observation_builder
# Set action space to 4 to remove stop action
agent = RandomAgent(218, 4)

# Empty dictionary for all agent actions
action_dict = dict()

print("Start episode...")
# Reset environment and get initial observations for all agents
obs = env.reset()
# Update/Set each agent's speed
for idx in range(env.get_num_agents()):
    speed = 1.0 / ((idx % 5) + 1.0)
    env.agents[idx].speed_data["speed"] = speed
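# Note: in Flatland a fractional speed of 1/n means the agent needs roughly n environment
# steps to advance by one cell, so the loop above assigns five different speed profiles.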

# Reset the rendering system
env_renderer.reset()

# Here you can also further enhance the provided observation by means of normalization
# See training navigation example in the baseline repository
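
# A minimal normalization sketch (not part of the original example): assuming the tree
# observation has already been flattened into a 1-D NumPy array (as done in the baseline
# repository), unreachable branches marked with +/- inf can be capped and the vector
# rescaled into [-1, 1] before feeding it to a learning agent. The helper name and the
# clip radius below are illustrative, not part of the Flatland API.
def normalize_flat_obs(flat_obs, clip_radius=100.0):
    arr = np.asarray(flat_obs, dtype=np.float64).copy()
    arr[np.isposinf(arr)] = clip_radius    # cap unreachable branches marked with +inf
    arr[np.isneginf(arr)] = -clip_radius   # cap -inf markers symmetrically
    arr = np.clip(arr, -clip_radius, clip_radius)
    return arr / clip_radius               # rescale into [-1, 1]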

score = 0
# Run episode
frame_step = 0
for step in range(500):
    # Choose an action for each agent in the environment
    for a in range(env.get_num_agents()):
        action = agent.act(obs[a])
        action_dict.update({a: action})

    # Environment step which returns the observations for all agents, their corresponding
    # rewards and whether they are done
    next_obs, all_rewards, done, _ = env.step(action_dict)
    env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
    frame_step += 1
    # Update replay buffer and train agent
    for a in range(env.get_num_agents()):
        agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
        score += all_rewards[a]

    obs = next_obs.copy()
    if done['__all__']:
        break

print('Episode: Steps {}\t Score = {}'.format(step, score))