Skip to content
Snippets Groups Projects
Commit e6223ce1 authored by Dipam Chakraborty's avatar Dipam Chakraborty
Browse files

remove list starmap init for agents

parent 53d7dcc1
No related branches found
No related tags found
No related merge requests found
...@@ -73,8 +73,6 @@ class EnvAgent: ...@@ -73,8 +73,6 @@ class EnvAgent:
state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) , state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) ,
type=TrainStateMachine) type=TrainStateMachine)
malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler) malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler)
state = attrib(default=TrainState.WAITING, type=TrainState)
position = attrib(default=None, type=Optional[Tuple[int, int]]) position = attrib(default=None, type=Optional[Tuple[int, int]])
...@@ -134,35 +132,42 @@ class EnvAgent: ...@@ -134,35 +132,42 @@ class EnvAgent:
def from_line(cls, line: Line): def from_line(cls, line: Line):
""" Create a list of EnvAgent from lists of positions, directions and targets """ Create a list of EnvAgent from lists of positions, directions and targets
""" """
speed_datas = [] num_agents = len(line.agent_positions)
speed_counters = []
for i in range(len(line.agent_positions)):
speed = line.agent_speeds[i] if line.agent_speeds is not None else 1.0
speed_datas.append({'position_fraction': 0.0,
'speed': speed,
'transition_action_on_cellexit': 0})
speed_counters.append( SpeedCounter(speed=speed) )
malfunction_datas = []
for i in range(len(line.agent_positions)):
malfunction_datas.append({'malfunction': 0,
'malfunction_rate': line.agent_malfunction_rates[
i] if line.agent_malfunction_rates is not None else 0.,
'next_malfunction': 0,
'nr_malfunctions': 0})
return list(starmap(EnvAgent, zip(line.agent_positions, # TODO : Dipam - Really want to change this way of loading agents agent_list = []
line.agent_directions, for i_agent in range(num_agents):
line.agent_directions, speed = line.agent_speeds[i_agent] if line.agent_speeds is not None else 1.0
line.agent_targets,
[False] * len(line.agent_positions), speed_data = {'position_fraction': 0.0,
[None] * len(line.agent_positions), # earliest_departure 'speed': speed,
[None] * len(line.agent_positions), # latest_arrival 'transition_action_on_cellexit': 0
speed_datas, }
malfunction_datas,
range(len(line.agent_positions)), if line.agent_malfunction_rates is not None:
speed_counters, malfunction_rate = line.agent_malfunction_rates[i_agent]
))) else:
malfunction_rate = 0.
malfunction_data = {'malfunction': 0,
'malfunction_rate': malfunction_rate,
'next_malfunction': 0,
'nr_malfunctions': 0
}
agent = EnvAgent(initial_position = line.agent_positions[i_agent],
initial_direction = line.agent_directions[i_agent],
direction = line.agent_directions[i_agent],
target = line.agent_targets[i_agent],
moving = False,
earliest_departure = None,
latest_arrival = None,
speed_data = speed_data,
malfunction_data = malfunction_data,
handle = i_agent,
speed_counter = SpeedCounter(speed=speed))
agent_list.append(agent)
return agent_list
@classmethod @classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple): def load_legacy_static_agent(cls, static_agents_data: Tuple):
...@@ -185,3 +190,7 @@ class EnvAgent: ...@@ -185,3 +190,7 @@ class EnvAgent:
handle=i) handle=i)
agents.append(agent) agents.append(agent)
return agents return agents
@property
def state(self):
return self.state_machine.state
...@@ -561,7 +561,6 @@ class RailEnv(Environment): ...@@ -561,7 +561,6 @@ class RailEnv(Environment):
state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed) state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed)
agent.state_machine.set_transition_signals(state_transition_signals) agent.state_machine.set_transition_signals(state_transition_signals)
agent.state_machine.step() agent.state_machine.step()
agent.state = agent.state_machine.state # TODO : Make this a property instead?
# Remove agent is required # Remove agent is required
if self.remove_agents_at_target and agent.state == TrainState.DONE: if self.remove_agents_at_target and agent.state == TrainState.DONE:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment