try_cell_graph.py 6.39 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import numpy as np
import time

# In Flatland you can use custom observation builders and predictors.
# Observation builders generate the observations needed by the controller.
# Predictors can be used to do short-term prediction, which can help in avoiding conflicts in the network.
from flatland.envs.observations import GlobalObsForRailEnv, ObservationBuilder
# First of all we import the Flatland rail environment
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env import RailEnvActions, RailAgentStatus
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
# We also include a renderer because we want to visualize what is going on in the environment
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.envs.malfunction_generators import malfunction_from_params

from libs.cell_graph_dispatcher import CellGraphDispatcher


# Wall-clock reference; the total run time is printed from it at the end of the script.
start_time = time.time()

# Scenario in use: a square 50x50 map.
# Alternative presets kept for reference (both on a 150x150 map):
#   nr_trains = 200, cities_in_map = 35,  seed = 5
#   nr_trains = 100, cities_in_map = 100, seed = 14
width = height = 50  # Width and height of the map
nr_trains = 200  # Number of trains that have an assigned task in the env
cities_in_map = 35  # Number of cities where agents can start or end
seed = 5  # Random seed

# Type of city distribution: if False, cities are placed randomly.
grid_distribution_of_cities = False
# Max number of tracks allowed between cities, i.e. the number of entry points to a city.
max_rails_between_cities = 2
# Max number of parallel tracks within a city, representing a realistic train station.
max_rail_in_cities = 6

# Rail layout generator, configured from the scenario parameters defined earlier in the script.
rail_generator = sparse_rail_generator(
    max_num_cities=cities_in_map,
    seed=seed,
    grid_mode=grid_distribution_of_cities,
    max_rails_between_cities=max_rails_between_cities,
    max_rails_in_city=max_rail_in_cities,
)

# The schedule generator makes basic schedules with a start point, an end point
# and a speed profile for each agent. The map below is the statistical
# distribution of speed profiles it is given: four train types with speeds
# 1, 1/2, 1/3 and 1/4 (fast passenger, fast freight, slow commuter, slow
# freight), each mapped to the same value 0.25.
speed_ration_map = {1. / divisor: 0.25 for divisor in (1., 2., 3., 4.)}

schedule_generator = sparse_schedule_generator(speed_ration_map)

# Stochastic data that is later passed (through malfunction_from_params) to the
# RailEnv constructor to allow for stochastic malfunctions during an episode.
stochastic_data = {
    'malfunction_rate': 500,  # Rate of malfunction occurrence of a single agent
    'prop_malfunction': 0.01,
    'min_duration': 20,  # Minimal duration of a malfunction
    'max_duration': 80,  # Maximal duration of a malfunction
}


# Custom observation builder without predictor


class DummyObservationBuilder(ObservationBuilder):
    """
    Observation builder that does not compute any real observation.

    Every query, for a single agent or for many, is answered with the
    constant ``True``, and resetting the builder does nothing.
    """

    def __init__(self):
        """Delegate construction entirely to the base observation builder."""
        super().__init__()

    def get(self, handle: int = 0) -> bool:
        """Return the dummy observation ``True``; ``handle`` is ignored."""
        return True

    def get_many(self, handles=None) -> bool:
        """Return the dummy observation ``True``; ``handles`` is ignored."""
        return True

    def reset(self):
        """Nothing to reset: the builder keeps no state of its own."""
        pass


# The dummy builder is the active observation builder. To try a custom
# observation builder with a predictor instead, swap in the line below:
# observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
observation_builder = DummyObservationBuilder()

# Malfunction generator data built from the stochastic settings
malfunction_setup = malfunction_from_params(stochastic_data)

# Construct the environment with the given observation builder, generators and malfunction data
env = RailEnv(
    width=width,
    height=height,
    rail_generator=rail_generator,
    schedule_generator=schedule_generator,
    number_of_agents=nr_trains,
    malfunction_generator_and_process_data=malfunction_setup,
    obs_builder_object=observation_builder,
    remove_agents_at_target=True,  # Removes agents at the end of their journey to make space for others
)
env.reset()

# Initiate the renderer; adjust the screen size to fit your resolution
env_renderer = RenderTool(
    env,
    gl="PILSVG",
    agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
    show_debug=False,
    screen_height=1920,
    screen_width=1080,
)

dispatcher = CellGraphDispatcher(env)

# Episode bookkeeping: accumulated reward and number of rendered frames
score, frame_step = 0, 0

def _count_finished(agents):
    """Return how many of the given agents have status DONE or DONE_REMOVED."""
    finished_states = (RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED)
    return np.sum([agent.status in finished_states for agent in agents])


# Per-step progress line: step number, accumulated score, finished agents
progress_template = 'Episode: Steps {}\t Score = {}\t Finished = {}'

# Run the episode until the environment reports that all agents are done
step = 0
episode_over = False
while not episode_over:
    step += 1

    # The dispatcher supplies the actions for this step; the environment step
    # returns the observations for all agents, their rewards and the done flags.
    action_dict = dispatcher.step(step)
    next_obs, all_rewards, done, _ = env.step(action_dict)

    # Draw the current state without observation or prediction overlays.
    # To also save each frame to disk:
    #   os.makedirs('./misc/Fames2/', exist_ok=True)
    #   env_renderer.gl.save_image('./misc/Fames2/flatland_frame_{:04d}.png'.format(step))
    env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
    frame_step += 1

    score += np.sum(list(all_rewards.values()))

    finished = _count_finished(env.agents)
    print(progress_template.format(step, score, finished))

    episode_over = done['__all__']


# Final report: share of finished trains and total wall-clock time
finished = _count_finished(env.agents)
print(f'Trains finished {finished}/{len(env.agents)} = {finished*100/len(env.agents):.2f}%')
print(f'Total time: {time.time()-start_time}s')