cell_graph_dispatcher.py 8.78 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import traceback
from copy import deepcopy
from typing import Dict

from flatland.envs.rail_env import RailEnv, RailAgentStatus, RailEnvActions
from libs import cell_graph_rescheduling, cell_graph_partial_rescheduling, cell_graph_rescheduling_data
from libs.cell_graph import CellGraph
from libs.cell_graph_agent import CellGraphAgent
from libs.cell_graph_locker import CellGraphLocker


class CellGraphDispatcher:
    def __init__(self, env: RailEnv, sort_function=None):
        self.env = env

        self.graph = CellGraph(env)
        self.locker = CellGraphLocker(self.graph)

        max_steps = env._max_episode_steps
        self.controllers = [CellGraphAgent(agent, self.graph, self.locker, i, max_steps) for i, agent in
                            enumerate(env.agents)]

        self.action_dict = {}

        if sort_function is None:
            sort_function = lambda idx: self.controllers[idx].dist_to_target[
                                            self.graph._vertex_idx_from_point(env.agents[idx].initial_position),
                                            env.agents[idx].initial_direction] \
                                        - 10000 * env.agents[idx].speed_data['speed']
        else:
            sort_function = sort_function(self)

        self.agents_order = sorted(range(len(env.agents)), key=sort_function)
        self.agent_locked_by_malfunction = []
        for agent in env.agents:
            self.agent_locked_by_malfunction.append(agent.malfunction_data['malfunction'] > 0)
        self.crashed = False
        self.blocked_agents = set()

    def step(self, step) -> Dict[int, RailEnvActions]:
        try:
            has_new_malfunctions = False
            for i, agent in enumerate(self.env.agents):
                is_locked = agent.malfunction_data['malfunction']

                if agent.status == RailAgentStatus.ACTIVE:
                    if (not self.agent_locked_by_malfunction[i]) and is_locked:
                        has_new_malfunctions = True

                self.agent_locked_by_malfunction[i] = is_locked

            updated = set()

            full_recalc_needed = False
            # old_locker = None
            try:
                if has_new_malfunctions:
                    # print('new malfunction at step', step)
                    # old_locker = deepcopy(self.locker)
                    cached_ways, vertex_agent_order, agent_way_position, agent_position_duration = \
                        cell_graph_rescheduling_data.get_rescheduling_data(self.env, step, self.controllers, self.graph,
                                                                           self.locker)

                    vertex_agent_order2 = deepcopy(vertex_agent_order)
                    agent_way_position2 = deepcopy(agent_way_position)
                    agent_position_duration2 = deepcopy(agent_position_duration)

                    new_way, full_recalc_needed = cell_graph_rescheduling.reschedule(cached_ways, vertex_agent_order,
                                                                                     agent_way_position,
                                                                                     agent_position_duration,
                                                                                     self.env, step, self.controllers,
                                                                                     self.graph, self.locker)
                    for i in self.agents_order:
                        if len(new_way[i]):
                            changed = cell_graph_rescheduling.recover_agent_way(self.controllers[i], self.env.agents[i],
                                                                                self.graph, new_way[i])
                            if changed:
                                updated.add(i)
            # resheduling failed, try to make a partial rescheduling
            except Exception as e:
                print("-----------------Rescheduling Exception----------------")
                print("Step: ", step)
                # traceback.print_exc()
                print("-----------------Rescheduling Exception----------------")

                updated.clear()
                full_recalc_needed = False
                # if old_locker is not None:
                #     self.locker.data = old_locker.data

                self.partial_resheduling(cached_ways, vertex_agent_order2, agent_way_position2,
                                         agent_position_duration2, step)
                self.limit_max_visited()

            for i in self.agents_order:
                try:
                    agent = self.env.agents[i]

                    # if agent.speed_data['position_fraction'] >= 1.0:
                    #     print('agent', i, 'blocked by some another agent, fraction:', agent.speed_data['position_fraction'])

                    force_new_path = full_recalc_needed or self.crashed or i in updated
                    # force_new_path = full_recalc_needed or i in updated
                    # if (force_new_path and i in self.blocked_agents):
                    #     # self.action_dict.update({i: RailEnvActions.DO_NOTHING})
                    #     force_new_path = False
                    #     # continue
                    if i in self.blocked_agents:
                        force_new_path = True

                    if agent.speed_data['position_fraction'] > 0.0 and not force_new_path:
                        self.action_dict.update({i: RailEnvActions.DO_NOTHING})
                        continue

                    # action = self.controllers[i].act(agent, step, force_new_path=has_new_malfunctions)
                    action = self.controllers[i].act(agent, step, force_new_path=force_new_path)
                    self.action_dict.update({i: action})
                # act crashed tor one agent
                except Exception as e:
                    print("-----------------Agent step Exception----------------", i)
                    print("Step: ", step)
                    # traceback.print_exc()
                    print("-----------------Agent step Exception----------------")

                    self.action_dict.update({i: RailEnvActions.DO_NOTHING})
                    self.limit_max_visited()
                    # pass

            self.blocked_agents.clear()
            self.crashed = False

        # global step exception handling, no idea what to do here
        except Exception as e:
            # except ArithmeticError:
            self.crashed = True
            print("-----------------Step Exception----------------")
            print("Step: ", step)
            traceback.print_exc()
            print("-----------------Step Exception----------------")

            # hit_problem = False
            # for j in self.agents_order:
            #     if j == i:
            #         hit_problem = True
            #     if hit_problem:
            #         self.action_dict.update({j: RailEnvActions.STOP_MOVING })
            self.action_dict = {i: RailEnvActions.STOP_MOVING for i in self.agents_order}
            self.limit_max_visited()

            # raise e

        return self.action_dict

    def partial_resheduling(self, cached_ways, vertex_agent_order2, agent_way_position2, agent_position_duration2,
                            step):
        print('partial_resheduling')
        try:
            new_way, blocked_agents = cell_graph_partial_rescheduling.partial_reschedule(cached_ways,
                                                                                         vertex_agent_order2,
                                                                                         agent_way_position2,
                                                                                         agent_position_duration2,
                                                                                         self.env, step,
                                                                                         self.controllers, self.graph,
                                                                                         self.locker)

            for i in self.agents_order:
                if len(new_way[i]):
                    cell_graph_rescheduling.recover_agent_way(self.controllers[i], self.env.agents[i], self.graph,
                                                              new_way[i])

            self.blocked_agents.update(blocked_agents)

            print('blocked agents', self.blocked_agents)

        except Exception as e:
            self.crashed = True
            print("-----------------Partial rescheduing Exception----------------")
            traceback.print_exc()
            print("-----------------Partial rescheduing Exception----------------")
            self.limit_max_visited()

    def limit_max_visited(self):
        for c in self.controllers:
            c.set_max_visited(100)