Commit e6223ce1 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

remove list starmap init for agents

parent 53d7dcc1
......@@ -73,8 +73,6 @@ class EnvAgent:
state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
type=TrainStateMachine)
malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
state = attrib(default=TrainState.WAITING, type=TrainState)
position = attrib(default=None, type=Optional[Tuple[int, int]])
......@@ -134,35 +132,42 @@ class EnvAgent:
def from_line(cls, line: Line):
""" Create a list of EnvAgent from lists of positions, directions and targets
"""
speed_datas = []
speed_counters = []
for i in range(len(line.agent_positions)):
speed = line.agent_speeds[i] if line.agent_speeds is not None else 1.0
speed_datas.append({'position_fraction': 0.0,
'speed': speed,
'transition_action_on_cellexit': 0})
speed_counters.append( SpeedCounter(speed=speed) )
malfunction_datas = []
for i in range(len(line.agent_positions)):
malfunction_datas.append({'malfunction': 0,
'malfunction_rate': line.agent_malfunction_rates[
i] if line.agent_malfunction_rates is not None else 0.,
'next_malfunction': 0,
'nr_malfunctions': 0})
num_agents = len(line.agent_positions)
return list(starmap(EnvAgent, zip(line.agent_positions, # TODO : Dipam - Really want to change this way of loading agents
line.agent_directions,
line.agent_directions,
line.agent_targets,
[False] * len(line.agent_positions),
[None] * len(line.agent_positions), # earliest_departure
[None] * len(line.agent_positions), # latest_arrival
speed_datas,
malfunction_datas,
range(len(line.agent_positions)),
speed_counters,
)))
agent_list = []
for i_agent in range(num_agents):
speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
speed_data = {'position_fraction': 0.0,
'speed': speed,
'transition_action_on_cellexit': 0
}
if line.agent_malfunction_rates is not None:
malfunction_rate = line.agent_malfunction_rates[i_agent]
else:
malfunction_rate = 0.
malfunction_data = {'malfunction': 0,
'malfunction_rate': malfunction_rate,
'next_malfunction': 0,
'nr_malfunctions': 0
}
agent = EnvAgent(initial_position = line.agent_positions[i_agent],
initial_direction = line.agent_directions[i_agent],
direction = line.agent_directions[i_agent],
target = line.agent_targets[i_agent],
moving = False,
earliest_departure = None,
latest_arrival = None,
speed_data = speed_data,
malfunction_data = malfunction_data,
handle = i_agent,
speed_counter = SpeedCounter(speed=speed))
agent_list.append(agent)
return agent_list
@classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple):
......@@ -185,3 +190,7 @@ class EnvAgent:
handle=i)
agents.append(agent)
return agents
@property
def state(self):
return self.state_machine.state
......@@ -561,7 +561,6 @@ class RailEnv(Environment):
state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed)
agent.state_machine.set_transition_signals(state_transition_signals)
agent.state_machine.step()
agent.state = agent.state_machine.state # TODO : Make this a property instead?
# Remove agent is required
if self.remove_agents_at_target and agent.state == TrainState.DONE:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment