Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Flatland
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pranjal_dhole
Flatland
Commits
fce9451b
Commit
fce9451b
authored
4 years ago
by
hagrid67
Browse files
Options
Downloads
Patches
Plain Diff
adding missing files, and fixed malfunction_generators
parent
48b99c0c
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
flatland/envs/malfunction_generators.py
+18
-10
18 additions, 10 deletions
flatland/envs/malfunction_generators.py
flatland/envs/persistence.py
+266
-0
266 additions, 0 deletions
flatland/envs/persistence.py
requirements_dev.txt
+1
-1
1 addition, 1 deletion
requirements_dev.txt
with
285 additions
and
11 deletions
flatland/envs/malfunction_generators.py
+
18
−
10
View file @
fce9451b
...
...
@@ -2,11 +2,11 @@
from
typing
import
Callable
,
NamedTuple
,
Optional
,
Tuple
import
msgpack
import
numpy
as
np
from
numpy.random.mtrand
import
RandomState
from
flatland.envs.agent_utils
import
EnvAgent
,
RailAgentStatus
from
flatland.envs
import
persistence
Malfunction
=
NamedTuple
(
'
Malfunction
'
,
[(
'
num_broken_steps
'
,
int
)])
MalfunctionParameters
=
NamedTuple
(
'
MalfunctionParameters
'
,
...
...
@@ -28,7 +28,7 @@ def _malfunction_prob(rate: float) -> float:
return
1
-
np
.
exp
(
-
(
1
/
rate
))
def
malfunction_from_file
(
filename
:
str
)
->
Tuple
[
MalfunctionGenerator
,
MalfunctionProcessData
]:
def
malfunction_from_file
(
filename
:
str
,
load_from_package
=
None
)
->
Tuple
[
MalfunctionGenerator
,
MalfunctionProcessData
]:
"""
Utility to load pickle file
...
...
@@ -40,18 +40,26 @@ def malfunction_from_file(filename: str) -> Tuple[MalfunctionGenerator, Malfunct
-------
generator, Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
"""
with
open
(
filename
,
"
rb
"
)
as
file_in
:
load_data
=
file_in
.
read
()
data
=
msgpack
.
unpackb
(
load_data
,
use_list
=
False
,
encoding
=
'
utf-8
'
)
# with open(filename, "rb") as file_in:
# load_data = file_in.read()
# if filename.endswith("mpk"):
# data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
# elif filename.endswith("pkl"):
# data = pickle.loads(load_data)
env_dict
=
persistence
.
RailEnvPersister
.
load_env_dict
(
filename
,
load_from_package
=
load_from_package
)
# TODO: make this better by using namedtuple in the pickle file. See issue 282
data
[
'
malfunction
'
]
=
MalfunctionProcessData
.
_make
(
data
[
'
malfunction
'
])
if
"
malfunction
"
in
data
:
if
"
malfunction
"
in
env_dict
:
env_dict
[
'
malfunction
'
]
=
oMPD
=
MalfunctionProcessData
.
_make
(
env_dict
[
'
malfunction
'
])
else
:
oMPD
=
None
if
oMPD
is
not
None
:
# Mean malfunction in number of time steps
mean_malfunction_rate
=
data
[
"
malfunction
"
]
.
malfunction_rate
mean_malfunction_rate
=
oMPD
.
malfunction_rate
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken
=
data
[
"
malfunction
"
]
.
min_duration
max_number_of_steps_broken
=
data
[
"
malfunction
"
]
.
max_duration
min_number_of_steps_broken
=
oMPD
.
min_duration
max_number_of_steps_broken
=
oMPD
.
max_duration
else
:
# Mean malfunction in number of time steps
mean_malfunction_rate
=
0.
...
...
This diff is collapsed.
Click to expand it.
flatland/envs/persistence.py
0 → 100644
+
266
−
0
View file @
fce9451b
import
pickle
import
msgpack
import
numpy
as
np
from
flatland.envs
import
rail_env
#from flatland.core.env import Environment
from
flatland.core.env_observation_builder
import
DummyObservationBuilder
#from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
#from flatland.core.grid.grid4_utils import get_new_position
#from flatland.core.grid.grid_utils import IntVector2D
from
flatland.core.transition_map
import
GridTransitionMap
from
flatland.envs.agent_utils
import
Agent
,
EnvAgent
,
RailAgentStatus
from
flatland.envs.distance_map
import
DistanceMap
#from flatland.envs.observations import GlobalObsForRailEnv
# cannot import objects / classes directly because of circular import
from
flatland.envs
import
malfunction_generators
as
mal_gen
from
flatland.envs
import
rail_generators
as
rail_gen
from
flatland.envs
import
schedule_generators
as
sched_gen
class
RailEnvPersister
(
object
):
@classmethod
def
save
(
cls
,
env
,
filename
,
save_distance_maps
=
False
):
"""
Saves environment and distance map information in a file
Parameters:
---------
filename: string
save_distance_maps: bool
"""
env_dict
=
cls
.
get_full_state
(
env
)
if
save_distance_maps
is
True
:
oDistMap
=
env
.
distance_map
.
get
()
if
oDistMap
is
not
None
:
if
len
(
oDistMap
)
>
0
:
env_dict
[
"
distance_map
"
]
=
oDistMap
else
:
print
(
"
[WARNING] Unable to save the distance map for this environment, as none was found !
"
)
else
:
print
(
"
[WARNING] Unable to save the distance map for this environment, as none was found !
"
)
with
open
(
filename
,
"
wb
"
)
as
file_out
:
if
filename
.
endswith
(
"
mpk
"
):
file_out
.
write
(
msgpack
.
packb
(
env_dict
))
elif
filename
.
endswith
(
"
pkl
"
):
pickle
.
dump
(
env_dict
,
file_out
)
@classmethod
def
save_episode
(
cls
,
env
,
filename
):
dict_env
=
cls
.
get_full_state
(
env
)
lAgents
=
dict_env
[
"
agents
"
]
print
(
"
Saving agents:
"
,
len
(
lAgents
))
print
(
"
Agent 0:
"
,
type
(
lAgents
[
0
]),
lAgents
[
0
])
dict_env
[
"
episode
"
]
=
env
.
cur_episode
dict_env
[
"
shape
"
]
=
(
env
.
width
,
env
.
height
)
with
open
(
filename
,
"
wb
"
)
as
file_out
:
if
filename
.
endswith
(
"
.mpk
"
):
file_out
.
write
(
msgpack
.
packb
(
dict_env
))
elif
filename
.
endswith
(
"
.pkl
"
):
pickle
.
dump
(
dict_env
,
file_out
)
@classmethod
def
load
(
cls
,
env
,
filename
,
load_from_package
=
None
):
"""
Load environment with distance map from a file
Parameters:
-------
filename: string
"""
env_dict
=
cls
.
load_env_dict
(
filename
,
load_from_package
=
load_from_package
)
cls
.
set_full_state
(
env
,
env_dict
)
@classmethod
def
load_new
(
cls
,
filename
,
load_from_package
=
None
):
env_dict
=
cls
.
load_env_dict
(
filename
,
load_from_package
=
load_from_package
)
# TODO: inefficient - each one of these generators loads the complete env file.
env
=
rail_env
.
RailEnv
(
width
=
1
,
height
=
1
,
rail_generator
=
rail_gen
.
rail_from_file
(
filename
),
schedule_generator
=
sched_gen
.
schedule_from_file
(
filename
),
malfunction_generator_and_process_data
=
mal_gen
.
malfunction_from_file
(
filename
),
obs_builder_object
=
DummyObservationBuilder
(),
record_steps
=
True
)
env
.
rail
=
GridTransitionMap
(
1
,
1
)
# dummy
cls
.
set_full_state
(
env
,
env_dict
)
return
env
,
env_dict
@classmethod
def
load_env_dict
(
cls
,
filename
,
load_from_package
=
None
):
if
load_from_package
is
not
None
:
from
importlib_resources
import
read_binary
load_data
=
read_binary
(
load_from_package
,
filename
)
else
:
with
open
(
filename
,
"
rb
"
)
as
file_in
:
load_data
=
file_in
.
read
()
if
filename
.
endswith
(
"
mpk
"
):
env_dict
=
msgpack
.
unpackb
(
load_data
,
use_list
=
False
,
encoding
=
"
utf-8
"
)
elif
filename
.
endswith
(
"
pkl
"
):
env_dict
=
pickle
.
loads
(
load_data
)
else
:
print
(
f
"
filename
{
filename
}
must end with either pkl or mpk
"
)
env_dict
=
{}
return
env_dict
@classmethod
def
load_resource
(
cls
,
package
,
resource
):
"""
Load environment (with distance map?) from a binary
"""
from
importlib_resources
import
read_binary
load_data
=
read_binary
(
package
,
resource
)
if
resource
.
endswith
(
"
pkl
"
):
env_dict
=
pickle
.
loads
(
load_data
)
elif
resource
.
endswith
(
"
mpk
"
):
env_dict
=
msgpack
.
unpackb
(
load_data
,
encoding
=
"
utf-8
"
)
cls
.
set_full_state
(
env
,
env_dict
)
@classmethod
def
set_full_state
(
cls
,
env
,
env_dict
):
"""
Sets environment state from env_dict
Parameters
-------
env_dict: dict
"""
env
.
rail
.
grid
=
np
.
array
(
env_dict
[
"
grid
"
])
# agents are always reset as not moving
if
"
agents_static
"
in
env_dict
:
# no idea if this still works
env
.
agents
=
EnvAgent
.
load_legacy_static_agent
(
env_dict
[
"
agents_static
"
])
else
:
env
.
agents
=
[
EnvAgent
(
*
d
[
0
:
12
])
for
d
in
env_dict
[
"
agents
"
]]
# setup with loaded data
env
.
height
,
env
.
width
=
env
.
rail
.
grid
.
shape
env
.
rail
.
height
=
env
.
height
env
.
rail
.
width
=
env
.
width
env
.
dones
=
dict
.
fromkeys
(
list
(
range
(
env
.
get_num_agents
()))
+
[
"
__all__
"
],
False
)
@classmethod
def
get_full_state
(
cls
,
env
):
"""
Returns state of environment in dict object, ready for serialization
"""
grid_data
=
env
.
rail
.
grid
.
tolist
()
# msgpack cannot persist EnvAgent so use the Agent namedtuple.
agent_data
=
[
agent
.
to_agent
()
for
agent
in
env
.
agents
]
malfunction_data
:
MalfunctionProcessData
=
env
.
malfunction_process_data
msg_data_dict
=
{
"
grid
"
:
grid_data
,
"
agents
"
:
agent_data
,
"
malfunction
"
:
malfunction_data
}
return
msg_data_dict
################################################################################################
# deprecated methods moved from RailEnv. Most likely broken.
def
deprecated_get_full_state_msg
(
self
)
->
msgpack
.
Packer
:
"""
Returns state of environment in msgpack object
"""
msg_data_dict
=
self
.
get_full_state_dict
()
return
msgpack
.
packb
(
msg_data_dict
,
use_bin_type
=
True
)
def
deprecated_get_agent_state_msg
(
self
)
->
msgpack
.
Packer
:
"""
Returns agents information in msgpack object
"""
agent_data
=
[
agent
.
to_agent
()
for
agent
in
self
.
agents
]
msg_data
=
{
"
agents
"
:
agent_data
}
return
msgpack
.
packb
(
msg_data
,
use_bin_type
=
True
)
def
deprecated_get_full_state_dist_msg
(
self
)
->
msgpack
.
Packer
:
"""
Returns environment information with distance map information as msgpack object
"""
grid_data
=
self
.
rail
.
grid
.
tolist
()
agent_data
=
[
agent
.
to_agent
()
for
agent
in
self
.
agents
]
# I think these calls do nothing - they create packed data and it is discarded
#msgpack.packb(grid_data, use_bin_type=True)
#msgpack.packb(agent_data, use_bin_type=True)
distance_map_data
=
self
.
distance_map
.
get
()
malfunction_data
:
MalfunctionProcessData
=
self
.
malfunction_process_data
#msgpack.packb(distance_map_data, use_bin_type=True) # does nothing
msg_data
=
{
"
grid
"
:
grid_data
,
"
agents
"
:
agent_data
,
"
distance_map
"
:
distance_map_data
,
"
malfunction
"
:
malfunction_data
}
return
msgpack
.
packb
(
msg_data
,
use_bin_type
=
True
)
def
deprecated_set_full_state_msg
(
self
,
msg_data
):
"""
Sets environment state with msgdata object passed as argument
Parameters
-------
msg_data: msgpack object
"""
data
=
msgpack
.
unpackb
(
msg_data
,
use_list
=
False
,
encoding
=
'
utf-8
'
)
self
.
rail
.
grid
=
np
.
array
(
data
[
"
grid
"
])
# agents are always reset as not moving
if
"
agents_static
"
in
data
:
self
.
agents
=
EnvAgent
.
load_legacy_static_agent
(
data
[
"
agents_static
"
])
else
:
self
.
agents
=
[
EnvAgent
(
*
d
[
0
:
12
])
for
d
in
data
[
"
agents
"
]]
# setup with loaded data
self
.
height
,
self
.
width
=
self
.
rail
.
grid
.
shape
self
.
rail
.
height
=
self
.
height
self
.
rail
.
width
=
self
.
width
self
.
dones
=
dict
.
fromkeys
(
list
(
range
(
self
.
get_num_agents
()))
+
[
"
__all__
"
],
False
)
def
deprecated_set_full_state_dist_msg
(
self
,
msg_data
):
"""
Sets environment grid state and distance map with msgdata object passed as argument
Parameters
-------
msg_data: msgpack object
"""
data
=
msgpack
.
unpackb
(
msg_data
,
use_list
=
False
,
encoding
=
'
utf-8
'
)
self
.
rail
.
grid
=
np
.
array
(
data
[
"
grid
"
])
# agents are always reset as not moving
if
"
agents_static
"
in
data
:
self
.
agents
=
EnvAgent
.
load_legacy_static_agent
(
data
[
"
agents_static
"
])
else
:
self
.
agents
=
[
EnvAgent
(
*
d
[
0
:
12
])
for
d
in
data
[
"
agents
"
]]
if
"
distance_map
"
in
data
.
keys
():
self
.
distance_map
.
set
(
data
[
"
distance_map
"
])
# setup with loaded data
self
.
height
,
self
.
width
=
self
.
rail
.
grid
.
shape
self
.
rail
.
height
=
self
.
height
self
.
rail
.
width
=
self
.
width
self
.
dones
=
dict
.
fromkeys
(
list
(
range
(
self
.
get_num_agents
()))
+
[
"
__all__
"
],
False
)
\ No newline at end of file
This diff is collapsed.
Click to expand it.
requirements_dev.txt
+
1
−
1
View file @
fce9451b
...
...
@@ -9,7 +9,7 @@ recordtype>=1.3
matplotlib>=3.0.2
Pillow>=5.4.1
CairoSVG>=2.3.1
msgpack>=
1.0.0
msgpack>=
0.6.1
msgpack-numpy>=0.4.4.0
svgutils>=0.3.1
pyarrow>=0.13.0
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment