Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Flatland
Flatland
Commits
e4399082
Commit
e4399082
authored
Sep 10, 2021
by
Dipam Chakraborty
Browse files
Change speed data to speed counter
parent
8a3a043c
Pipeline
#8453
failed with stages
in 5 minutes and 31 seconds
Changes
8
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
flatland/action_plan/action_plan.py
View file @
e4399082
...
...
@@ -150,7 +150,7 @@ class ControllerFromTrainruns():
def
_create_action_plan_for_agent
(
self
,
agent_id
,
trainrun
)
->
ActionPlan
:
action_plan
=
[]
agent
=
self
.
env
.
agents
[
agent_id
]
minimum_cell_time
=
int
(
np
.
ceil
(
1.0
/
agent
.
speed_data
[
'speed'
]))
minimum_cell_time
=
agent
.
speed_counter
.
max_count
for
path_loop
,
trainrun_waypoint
in
enumerate
(
trainrun
):
trainrun_waypoint
:
TrainrunWaypoint
=
trainrun_waypoint
...
...
flatland/action_plan/action_plan_player.py
View file @
e4399082
...
...
@@ -30,6 +30,8 @@ class ControllerFromTrainrunsReplayer():
assert
agent
.
position
==
waypoint
.
position
,
\
"before {}, agent {} at {}, expected {}"
.
format
(
i
,
agent_id
,
agent
.
position
,
waypoint
.
position
)
if
agent_id
==
1
:
print
(
env
.
_elapsed_steps
,
agent
.
position
,
agent
.
state
,
agent
.
speed_counter
)
actions
=
ctl
.
act
(
i
)
print
(
"actions for {}: {}"
.
format
(
i
,
actions
))
...
...
flatland/envs/agent_utils.py
View file @
e4399082
from
flatland.envs.rail_trainrun_data_structures
import
Waypoint
import
numpy
as
np
import
warnings
from
typing
import
Tuple
,
Optional
,
NamedTuple
,
List
...
...
@@ -21,7 +22,6 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]),
(
'moving'
,
bool
),
(
'earliest_departure'
,
int
),
(
'latest_arrival'
,
int
),
(
'speed_data'
,
dict
),
(
'malfunction_data'
,
dict
),
(
'handle'
,
int
),
(
'position'
,
Tuple
[
int
,
int
]),
...
...
@@ -49,13 +49,6 @@ class EnvAgent:
earliest_departure
=
attrib
(
default
=
None
,
type
=
int
)
# default None during _from_line()
latest_arrival
=
attrib
(
default
=
None
,
type
=
int
)
# default None during _from_line()
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default)
# N.B. we need to use factory since default arguments are not recreated on each call!
speed_data
=
attrib
(
default
=
Factory
(
lambda
:
dict
({
'position_fraction'
:
0.0
,
'speed'
:
1.0
,
'transition_action_on_cellexit'
:
0
})))
# if broken>0, the agent's actions are ignored for 'broken' steps
# number of times the agent had to stop, since the last time it broke down
malfunction_data
=
attrib
(
...
...
@@ -67,7 +60,7 @@ class EnvAgent:
# INIT TILL HERE IN _from_line()
# Env step facelift
speed_counter
=
attrib
(
default
=
None
,
type
=
SpeedCounter
)
speed_counter
=
attrib
(
default
=
Factory
(
lambda
:
SpeedCounter
(
1.0
))
,
type
=
SpeedCounter
)
action_saver
=
attrib
(
default
=
Factory
(
lambda
:
ActionSaver
()),
type
=
ActionSaver
)
state_machine
=
attrib
(
default
=
Factory
(
lambda
:
TrainStateMachine
(
initial_state
=
TrainState
.
WAITING
))
,
type
=
TrainStateMachine
)
...
...
@@ -94,10 +87,6 @@ class EnvAgent:
self
.
old_direction
=
None
self
.
moving
=
False
# Reset agent values for speed
self
.
speed_data
[
'position_fraction'
]
=
0.
self
.
speed_data
[
'transition_action_on_cellexit'
]
=
0.
# Reset agent malfunction values
self
.
malfunction_data
[
'malfunction'
]
=
0
self
.
malfunction_data
[
'nr_malfunctions'
]
=
0
...
...
@@ -115,7 +104,6 @@ class EnvAgent:
moving
=
self
.
moving
,
earliest_departure
=
self
.
earliest_departure
,
latest_arrival
=
self
.
latest_arrival
,
speed_data
=
self
.
speed_data
,
malfunction_data
=
self
.
malfunction_data
,
handle
=
self
.
handle
,
state
=
self
.
state
,
...
...
@@ -137,7 +125,7 @@ class EnvAgent:
distance
=
len
(
shortest_path
)
else
:
distance
=
0
speed
=
self
.
speed_
data
[
'
speed
'
]
speed
=
self
.
speed_
counter
.
speed
return
int
(
np
.
ceil
(
distance
/
speed
))
def
get_time_remaining_until_latest_arrival
(
self
,
elapsed_steps
:
int
)
->
int
:
...
...
@@ -161,11 +149,6 @@ class EnvAgent:
agent_list
=
[]
for
i_agent
in
range
(
num_agents
):
speed
=
line
.
agent_speeds
[
i_agent
]
if
line
.
agent_speeds
is
not
None
else
1.0
speed_data
=
{
'position_fraction'
:
0.0
,
'speed'
:
speed
,
'transition_action_on_cellexit'
:
0
}
if
line
.
agent_malfunction_rates
is
not
None
:
malfunction_rate
=
line
.
agent_malfunction_rates
[
i_agent
]
...
...
@@ -177,7 +160,6 @@ class EnvAgent:
'next_malfunction'
:
0
,
'nr_malfunctions'
:
0
}
agent
=
EnvAgent
(
initial_position
=
line
.
agent_positions
[
i_agent
],
initial_direction
=
line
.
agent_directions
[
i_agent
],
direction
=
line
.
agent_directions
[
i_agent
],
...
...
@@ -185,7 +167,6 @@ class EnvAgent:
moving
=
False
,
earliest_departure
=
None
,
latest_arrival
=
None
,
speed_data
=
speed_data
,
malfunction_data
=
malfunction_data
,
handle
=
i_agent
,
speed_counter
=
SpeedCounter
(
speed
=
speed
))
...
...
@@ -195,6 +176,7 @@ class EnvAgent:
@
classmethod
def
load_legacy_static_agent
(
cls
,
static_agents_data
:
Tuple
):
raise
NotImplementedError
(
"Not implemented for Flatland 3"
)
agents
=
[]
for
i
,
static_agent
in
enumerate
(
static_agents_data
):
if
len
(
static_agent
)
>=
6
:
...
...
@@ -205,16 +187,35 @@ class EnvAgent:
agent
=
EnvAgent
(
initial_position
=
static_agent
[
0
],
initial_direction
=
static_agent
[
1
],
direction
=
static_agent
[
1
],
target
=
static_agent
[
2
],
moving
=
False
,
speed_data
=
{
"speed"
:
1.
,
"position_fraction"
:
0.
,
"transition_action_on_cell_exit"
:
0.
},
malfunction_data
=
{
'malfunction'
:
0
,
'nr_malfunctions'
:
0
,
'moving_before_malfunction'
:
False
},
speed_counter
=
SpeedCounter
(
1.0
),
handle
=
i
)
agents
.
append
(
agent
)
return
agents
def _set_state(self, state):
    """Force the agent's state machine into ``state``.

    A warning is emitted on every call because setting the state directly
    is discouraged; the call is then forwarded to
    ``self.state_machine.set_state``.

    Args:
        state: The state to hand to ``self.state_machine.set_state``.
    """
    # stacklevel=2 attributes the warning to the caller of _set_state,
    # which is the code that actually needs attention, rather than to this
    # line.
    warnings.warn(
        "Not recommended to set the state with this function unless completely required",
        stacklevel=2,
    )
    self.state_machine.set_state(state)
def __str__(self):
    """Return a multi-line, human-readable summary of the agent.

    The summary starts with a newline and then lists, one group per line:
    handle, initial position/direction, current position/direction/target,
    earliest departure/latest arrival, state, malfunction data, action
    saver and speed counter.
    """
    summary_lines = (
        "",  # keeps the leading newline of the summary
        f"handle(agent index): {self.handle}",
        f"initial_position: {self.initial_position}   "
        f"initial_direction: {self.initial_direction}",
        # The direction field must show self.direction; it previously
        # repeated self.position.
        f"position: {self.position}   direction: {self.direction}   "
        f"target: {self.target}",
        f"earliest_departure: {self.earliest_departure}   "
        f"latest_arrival: {self.latest_arrival}",
        f"state: {str(self.state)}",
        f"malfunction_data: {self.malfunction_data}",
        f"action_saver: {self.action_saver}",
        f"speed_counter: {self.speed_counter}",
    )
    return "\n".join(summary_lines)
@property
def state(self):
    """Return the current state held by the agent's state machine.

    Read-only accessor for ``self.state_machine.state``.
    """
    return self.state_machine.state
flatland/envs/line_generators.py
View file @
e4399082
...
...
@@ -189,7 +189,7 @@ def line_from_file(filename, load_from_package=None) -> LineGenerator:
#agents_direction = [a.direction for a in agents]
agents_direction
=
[
a
.
initial_direction
for
a
in
agents
]
agents_target
=
[
a
.
target
for
a
in
agents
]
agents_speed
=
[
a
.
speed_
data
[
'
speed
'
]
for
a
in
agents
]
agents_speed
=
[
a
.
speed_
counter
.
speed
for
a
in
agents
]
# Malfunctions from here are not used. They have their own generator.
#agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents]
...
...
flatland/envs/observations.py
View file @
e4399082
...
...
@@ -98,7 +98,7 @@ class TreeObsForRailEnv(ObservationBuilder):
_agent
.
position
:
self
.
location_has_agent
[
tuple
(
_agent
.
position
)]
=
1
self
.
location_has_agent_direction
[
tuple
(
_agent
.
position
)]
=
_agent
.
direction
self
.
location_has_agent_speed
[
tuple
(
_agent
.
position
)]
=
_agent
.
speed_
data
[
'
speed
'
]
self
.
location_has_agent_speed
[
tuple
(
_agent
.
position
)]
=
_agent
.
speed_
counter
.
speed
self
.
location_has_agent_malfunction
[
tuple
(
_agent
.
position
)]
=
_agent
.
malfunction_data
[
'malfunction'
]
...
...
@@ -221,7 +221,7 @@ class TreeObsForRailEnv(ObservationBuilder):
agent
.
direction
)],
num_agents_same_direction
=
0
,
num_agents_opposite_direction
=
0
,
num_agents_malfunctioning
=
agent
.
malfunction_data
[
'malfunction'
],
speed_min_fractional
=
agent
.
speed_
data
[
'
speed
'
],
speed_min_fractional
=
agent
.
speed_
counter
.
speed
num_agents_ready_to_depart
=
0
,
childs
=
{})
#print("root node type:", type(root_node_observation))
...
...
@@ -275,7 +275,7 @@ class TreeObsForRailEnv(ObservationBuilder):
visited
=
OrderedSet
()
agent
=
self
.
env
.
agents
[
handle
]
time_per_cell
=
np
.
reciprocal
(
agent
.
speed_
data
[
"
speed
"
]
)
time_per_cell
=
np
.
reciprocal
(
agent
.
speed_
counter
.
speed
)
own_target_encountered
=
np
.
inf
other_agent_encountered
=
np
.
inf
other_target_encountered
=
np
.
inf
...
...
@@ -604,7 +604,7 @@ class GlobalObsForRailEnv(ObservationBuilder):
if
i
!=
handle
:
obs_agents_state
[
other_agent
.
position
][
1
]
=
other_agent
.
direction
obs_agents_state
[
other_agent
.
position
][
2
]
=
other_agent
.
malfunction_data
[
'malfunction'
]
obs_agents_state
[
other_agent
.
position
][
3
]
=
other_agent
.
speed_
data
[
'
speed
'
]
obs_agents_state
[
other_agent
.
position
][
3
]
=
other_agent
.
speed_
counter
.
speed
# fifth channel: all ready to depart on this position
if
other_agent
.
state
.
is_off_map_state
():
obs_agents_state
[
other_agent
.
initial_position
][
4
]
+=
1
...
...
flatland/envs/predictions.py
View file @
e4399082
...
...
@@ -141,7 +141,7 @@ class ShortestPathPredictorForRailEnv(PredictionBuilder):
continue
agent_virtual_direction
=
agent
.
direction
agent_speed
=
agent
.
speed_
data
[
"
speed
"
]
agent_speed
=
agent
.
speed_
counter
.
speed
times_per_cell
=
int
(
np
.
reciprocal
(
agent_speed
))
prediction
=
np
.
zeros
(
shape
=
(
self
.
max_depth
+
1
,
5
))
prediction
[
0
]
=
[
0
,
*
agent_virtual_position
,
agent_virtual_direction
,
0
]
...
...
flatland/envs/rail_env.py
View file @
e4399082
...
...
@@ -261,8 +261,7 @@ class RailEnv(Environment):
False: Agent cannot provide an action
"""
return
agent
.
state
==
TrainState
.
READY_TO_DEPART
or
\
(
agent
.
state
.
is_on_map_state
()
and
\
fast_isclose
(
agent
.
speed_data
[
'position_fraction'
],
0.0
,
rtol
=
1e-03
)
)
(
agent
.
state
.
is_on_map_state
()
and
agent
.
speed_counter
.
is_cell_entry
)
def
reset
(
self
,
regenerate_rail
:
bool
=
True
,
regenerate_schedule
:
bool
=
True
,
*
,
random_seed
:
bool
=
None
)
->
Tuple
[
Dict
,
Dict
]:
...
...
@@ -344,19 +343,6 @@ class RailEnv(Environment):
# Reset agents to initial states
self
.
reset_agents
()
# for agent in self.agents:
# # Induce malfunctions
# if activate_agents:
# self.set_agent_active(agent)
# self._break_agent(agent)
# if agent.malfunction_data["malfunction"] > 0:
# agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
# # Fix agents that finished their malfunction
# self._fix_agent_after_malfunction(agent)
self
.
num_resets
+=
1
self
.
_elapsed_steps
=
0
...
...
@@ -369,14 +355,7 @@ class RailEnv(Environment):
# Empty the episode store of agent positions
self
.
cur_episode
=
[]
info_dict
:
Dict
=
{
'action_required'
:
{
i
:
self
.
action_required
(
agent
)
for
i
,
agent
in
enumerate
(
self
.
agents
)},
'malfunction'
:
{
i
:
agent
.
malfunction_data
[
'malfunction'
]
for
i
,
agent
in
enumerate
(
self
.
agents
)
},
'speed'
:
{
i
:
agent
.
speed_data
[
'speed'
]
for
i
,
agent
in
enumerate
(
self
.
agents
)},
'state'
:
{
i
:
agent
.
state
for
i
,
agent
in
enumerate
(
self
.
agents
)}
}
info_dict
=
self
.
get_info_dict
()
# Return the new observation vectors for each agent
observation_dict
:
Dict
=
self
.
_get_observations
()
return
observation_dict
,
info_dict
...
...
@@ -469,10 +448,12 @@ class RailEnv(Environment):
def
get_info_dict
(
self
):
# TODO Important : Update this
info_dict
=
{
"action_required"
:
{},
"malfunction"
:
{},
"speed"
:
{},
"status"
:
{},
'action_required'
:
{
i
:
self
.
action_required
(
agent
)
for
i
,
agent
in
enumerate
(
self
.
agents
)},
'malfunction'
:
{
i
:
agent
.
malfunction_data
[
'malfunction'
]
for
i
,
agent
in
enumerate
(
self
.
agents
)
},
'speed'
:
{
i
:
agent
.
speed_counter
.
speed
for
i
,
agent
in
enumerate
(
self
.
agents
)},
'state'
:
{
i
:
agent
.
state
for
i
,
agent
in
enumerate
(
self
.
agents
)}
}
return
info_dict
...
...
flatland/envs/timetable_generators.py
View file @
e4399082
...
...
@@ -57,7 +57,7 @@ def timetable_generator(agents: List[EnvAgent], distance_map: DistanceMap,
shortest_paths_lengths
=
[
len_handle_none
(
v
)
for
k
,
v
in
shortest_paths
.
items
()]
# Find mean_shortest_path_time
agent_speeds
=
[
agent
.
speed_
data
[
'
speed
'
]
for
agent
in
agents
]
agent_speeds
=
[
agent
.
speed_
counter
.
speed
for
agent
in
agents
]
agent_shortest_path_times
=
np
.
array
(
shortest_paths_lengths
)
/
np
.
array
(
agent_speeds
)
mean_shortest_path_time
=
np
.
mean
(
agent_shortest_path_times
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment