Commit b06c40e0 authored by Dipam Chakraborty's avatar Dipam Chakraborty

Replace malfunction_data with malfunction_handler

parent 2709f085
Pipeline #8532 failed
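The per-agent malfunction_data dict is replaced by the MalfunctionHandler object attached to each EnvAgent; accessing the old attribute (like the old speed_data) now raises a ValueError pointing at the replacement, and the Line named tuple drops agent_malfunction_rates because malfunctions have their own generator. A rough before/after sketch of the accessors touched in this diff (illustrative only, not part of the diff; `agent` stands for any EnvAgent after this commit):

def example_handler_usage(agent):
    # Old: agent.malfunction_data['malfunction']      -> steps left in the current malfunction
    remaining = agent.malfunction_handler.malfunction_down_counter
    # Old: agent.malfunction_data['nr_malfunctions']  -> malfunctions the agent has had so far
    total = agent.malfunction_handler.num_malfunctions
    # Old: zeroing the dict keys one by one           -> a single call
    agent.malfunction_handler.reset()
    return remaining, total

Legacy code that still reads agent.malfunction_data now fails loudly with the deprecation ValueError added below, instead of silently seeing stale values.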
......@@ -22,7 +22,6 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
('moving', bool),
('earliest_departure', int),
('latest_arrival', int),
('malfunction_data', dict),
('handle', int),
('position', Tuple[int, int]),
('arrival_time', int),
......@@ -68,13 +67,6 @@ class EnvAgent:
earliest_departure = attrib(default=None, type=int) # default None during _from_line()
latest_arrival = attrib(default=None, type=int) # default None during _from_line()
# if broken>0, the agent's actions are ignored for 'broken' steps
# number of times the agent had to stop since the last time it broke down
malfunction_data = attrib(
default=Factory(
lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0,
'moving_before_malfunction': False})))
handle = attrib(default=None)
# INIT TILL HERE IN _from_line()
......@@ -106,10 +98,7 @@ class EnvAgent:
self.old_direction = None
self.moving = False
# Reset agent malfunction values
self.malfunction_data['malfunction'] = 0
self.malfunction_data['nr_malfunctions'] = 0
self.malfunction_data['moving_before_malfunction'] = False
self.malfunction_handler.reset()
self.action_saver.clear_saved_action()
self.speed_counter.reset_counter()
......@@ -123,7 +112,6 @@ class EnvAgent:
moving=self.moving,
earliest_departure=self.earliest_departure,
latest_arrival=self.latest_arrival,
malfunction_data=self.malfunction_data,
handle=self.handle,
position=self.position,
old_direction=self.old_direction,
......@@ -169,16 +157,6 @@ class EnvAgent:
for i_agent in range(num_agents):
speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
if line.agent_malfunction_rates is not None:
malfunction_rate = line.agent_malfunction_rates[i_agent]
else:
malfunction_rate = 0.
malfunction_data = {'malfunction': 0,
'malfunction_rate': malfunction_rate,
'next_malfunction': 0,
'nr_malfunctions': 0
}
agent = EnvAgent(initial_position = line.agent_positions[i_agent],
initial_direction = line.agent_directions[i_agent],
direction = line.agent_directions[i_agent],
......@@ -186,7 +164,6 @@ class EnvAgent:
moving = False,
earliest_departure = None,
latest_arrival = None,
malfunction_data = malfunction_data,
handle = i_agent,
speed_counter = SpeedCounter(speed=speed))
agent_list.append(agent)
......@@ -200,17 +177,11 @@ class EnvAgent:
if len(static_agent) >= 6:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
speed_counter=SpeedCounter(static_agent[4]['speed']), malfunction_data=static_agent[5],
handle=i)
speed_counter=SpeedCounter(static_agent[4]['speed']), handle=i)
else:
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2],
moving=False,
malfunction_data={
'malfunction': 0,
'nr_malfunctions': 0,
'moving_before_malfunction': False
},
speed_counter=SpeedCounter(1.0),
handle=i)
agents.append(agent)
......@@ -219,10 +190,15 @@ class EnvAgent:
def __str__(self):
return f"\n \
handle(agent index): {self.handle} \n \
initial_position: {self.initial_position} initial_direction: {self.initial_direction} \n \
position: {self.position} direction: {self.direction} target: {self.target} \n \
old_position: {self.old_position} old_direction {self.old_direction} \n \
earliest_departure: {self.earliest_departure} latest_arrival: {self.latest_arrival} \n \
initial_position: {self.initial_position} \n \
initial_direction: {self.initial_direction} \n \
position: {self.position} \n \
direction: {self.direction} \n \
target: {self.target} \n \
old_position: {self.old_position} \n \
old_direction {self.old_direction} \n \
earliest_departure: {self.earliest_departure} \n \
latest_arrival: {self.latest_arrival} \n \
state: {str(self.state)} \n \
malfunction_handler: {self.malfunction_handler} \n \
action_saver: {self.action_saver} \n \
......@@ -240,6 +216,14 @@ class EnvAgent:
warnings.warn("Not recommended to set the state with this function unless completely required")
self.state_machine.set_state(state)
@property
def malfunction_data(self):
raise ValueError("agent.malfunction_data is deprecated, please use agent.malfunction_handler instead")
@property
def speed_data(self):
raise ValueError("agent.speed_data is deprecated, please use agent.speed_counter instead")
......@@ -148,7 +148,7 @@ class SparseLineGen(BaseLineGen):
timedelay_factor * alpha * (rail.width + rail.height + num_agents / len(city_positions)))
return Line(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None)
agent_targets=agents_target, agent_speeds=speeds)
def line_from_file(filename, load_from_package=None) -> LineGenerator:
......@@ -186,11 +186,7 @@ def line_from_file(filename, load_from_package=None) -> LineGenerator:
agents_target = [a.target for a in agents]
agents_speed = [a.speed_counter.speed for a in agents]
# Malfunctions from here are not used. They have their own generator.
#agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents]
return Line(agent_positions=agents_position, agent_directions=agents_direction,
agent_targets=agents_target, agent_speeds=agents_speed,
agent_malfunction_rates=None)
agent_targets=agents_target, agent_speeds=agents_speed)
return generator
......@@ -219,7 +219,7 @@ def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[Malfun
if reset:
return Malfunction(0)
if agent.malfunction_data['malfunction'] < 1:
if agent.malfunction_handler.malfunction_down_counter < 1:
if np_random.rand() < _malfunction_prob(mean_malfunction_rate):
num_broken_steps = np_random.randint(min_number_of_steps_broken,
max_number_of_steps_broken + 1) + 1
......
......@@ -99,8 +99,8 @@ class TreeObsForRailEnv(ObservationBuilder):
self.location_has_agent[tuple(_agent.position)] = 1
self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_counter.speed
self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
'malfunction']
self.location_has_agent_malfunction[tuple(_agent.position)] = \
_agent.malfunction_handler.malfunction_down_counter
# [NIMISH] WHAT IS THIS
if _agent.state.is_off_map_state() and \
......@@ -220,7 +220,7 @@ class TreeObsForRailEnv(ObservationBuilder):
(handle, *agent_virtual_position,
agent.direction)],
num_agents_same_direction=0, num_agents_opposite_direction=0,
num_agents_malfunctioning=agent.malfunction_data['malfunction'],
num_agents_malfunctioning=agent.malfunction_handler.malfunction_down_counter,
speed_min_fractional=agent.speed_counter.speed,
num_agents_ready_to_depart=0,
childs={})
......@@ -603,7 +603,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
# second channel only for other agents
if i != handle:
obs_agents_state[other_agent.position][1] = other_agent.direction
obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
obs_agents_state[other_agent.position][2] = other_agent.malfunction_handler.malfunction_down_counter
obs_agents_state[other_agent.position][3] = other_agent.speed_counter.speed
# fifth channel: all ready to depart on this position
if other_agent.state.is_off_map_state():
......
......@@ -253,7 +253,7 @@ class RailEnvPersister(object):
#msgpack.packb(agent_data, use_bin_type=True)
distance_map_data = self.distance_map.get()
malfunction_data: MalfunctionProcessData = self.malfunction_process_data
malfunction_data: mal_gen.MalfunctionProcessData = self.malfunction_process_data
#msgpack.packb(distance_map_data, use_bin_type=True) # does nothing
msg_data = {
"grid": grid_data,
......
......@@ -11,6 +11,10 @@ class MalfunctionHandler:
def __init__(self):
self._malfunction_down_counter = 0
self.num_malfunctions = 0
def reset(self):
self._malfunction_down_counter = 0
self.num_malfunctions = 0
@property
def in_malfunction(self):
......
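For reference, a minimal sketch of the reset() added above, assuming the module path flatland.envs.step_utils.malfunction_handler and the public malfunction_down_counter setter that the updated tests below rely on (both are assumptions, not shown in this hunk):

from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler  # module path assumed

handler = MalfunctionHandler()
handler.malfunction_down_counter = 3  # e.g. three steps left in an ongoing malfunction (setter assumed, as in the tests)
handler.num_malfunctions = 1
handler.reset()                       # new in this commit: clears both counters
assert handler.malfunction_down_counter == 0
assert handler.num_malfunctions == 0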
......@@ -6,8 +6,7 @@ from flatland.core.grid.grid_utils import IntVector2DArray
Line = NamedTuple('Line', [('agent_positions', IntVector2DArray),
('agent_directions', List[Grid4TransitionsEnum]),
('agent_targets', IntVector2DArray),
('agent_speeds', List[float]),
('agent_malfunction_rates', List[int])])
('agent_speeds', List[float])])
Timetable = NamedTuple('Timetable', [('earliest_departures', List[int]),
('latest_arrivals', List[int]),
......
......@@ -688,7 +688,7 @@ class RenderLocal(RenderBase):
malfunction=False)
continue
is_malfunction = agent.malfunction_data["malfunction"] > 0
is_malfunction = agent.malfunction_handler.malfunction_down_counter > 0
if self.agent_render_variant == AgentRenderVariant.BOX_ONLY:
self.gl.set_cell_occupied(agent_idx, *(agent.position))
......
......@@ -254,7 +254,7 @@ def test_initial_malfunction():
# reset to initialize agents_static
env.reset(False, False, random_seed=10)
env._max_episode_steps = 1000
print(env.agents[0].malfunction_data)
print(env.agents[0].malfunction_handler)
env.agents[0].target = (0, 5)
set_penalties_for_replay(env)
replay_config = ReplayConfig(
......@@ -568,8 +568,8 @@ def test_last_malfunction_step():
env.agents[a_idx].position = env.agents[a_idx].initial_position
env.agents[a_idx].state = TrainState.MOVING
# Force malfunction to be off at beginning and next malfunction to happen in 2 steps
env.agents[0].malfunction_data['next_malfunction'] = 2
env.agents[0].malfunction_data['malfunction'] = 0
# env.agents[0].malfunction_data['next_malfunction'] = 2
env.agents[0].malfunction_handler.malfunction_down_counter = 0
env_data = []
# Perform DO_NOTHING actions until all trains get to READY_TO_DEPART
......@@ -582,14 +582,14 @@ def test_last_malfunction_step():
# Go forward all the time
action_dict[agent.handle] = RailEnvActions(2)
if env.agents[0].malfunction_data['malfunction'] < 1:
if env.agents[0].malfunction_handler.malfunction_down_counter < 1:
agent_can_move = True
# Store the position before and after the step
pre_position = env.agents[0].speed_counter.counter
_, reward, _, _ = env.step(action_dict)
# Check if the agent is still allowed to move in this step
if env.agents[0].malfunction_data['malfunction'] > 0:
if env.agents[0].malfunction_handler.malfunction_down_counter > 0:
agent_can_move = False
post_position = env.agents[0].speed_counter.counter
# Assert that the agent moved while it was still allowed
......
......@@ -105,9 +105,9 @@ def test_get_global_observation():
for other_i, other_agent in enumerate(env.agents):
if other_agent.state in [TrainState.MOVING, TrainState.MALFUNCTION, TrainState.STOPPED,
TrainState.DONE] and other_agent.position == (r, c):
assert np.isclose(obs_agents_state[(r, c)][2], other_agent.malfunction_data['malfunction']), \
assert np.isclose(obs_agents_state[(r, c)][2], other_agent.malfunction_handler.malfunction_down_counter), \
"agent {} in state {} at {} should see agent malfunction {}, found = {}" \
.format(i, agent.state, (r, c), other_agent.malfunction_data['malfunction'],
.format(i, agent.state, (r, c), other_agent.malfunction_handler.malfunction_down_counter,
obs_agents_state[(r, c)][2])
assert np.isclose(obs_agents_state[(r, c)][3], other_agent.speed_counter.speed)
has_agent = True
......
......@@ -109,5 +109,5 @@ def test_single_malfunction_generator():
break
for agent in env.agents:
# Go forward all the time
tot_malfunctions += agent.malfunction_data['nr_malfunctions']
tot_malfunctions += agent.malfunction_handler.num_malfunctions
assert tot_malfunctions == 1
......@@ -133,8 +133,8 @@ def test_seeding_and_malfunction():
action = np.random.randint(4)
action_dict[a] = action
# print("----------------------")
# print(env.agents[a].malfunction_data, env.agents[a].status)
# print(env2.agents[a].malfunction_data, env2.agents[a].status)
# print(env.agents[a].malfunction_handler, env.agents[a].status)
# print(env2.agents[a].malfunction_handler, env2.agents[a].status)
_, reward1, done1, _ = env.step(action_dict)
_, reward2, done2, _ = env2.step(action_dict)
......