Skip to content
Snippets Groups Projects
Commit 8300a78d authored by hagrid67's avatar hagrid67
Browse files

adding Test_9_Level_1.pkl as test env in render-episode notebook - this one has malfunctions

parent 433f4a84
No related branches found
No related tags found
No related merge requests found
File added
%% Cell type:markdown id: tags:
# Render Episode
Render a stored episode. The env file needs to have "episode" and "actions" keys.
- checks that the agent actions match the saved steps (row, col, dir)
- creates a moving gif file of the episode
- displays the episode in a widget with a slider for the time steps.
%% Cell type:markdown id: tags:
# Setup
%% Cell type:code id: tags:
``` python
#!apt -qq install graphviz libgraphviz-dev pkg-config
#!pip install -qq git+
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
%% Cell type:code id: tags:
``` python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%% Cell type:code id: tags:
``` python
import PIL
from flatland.utils.rendertools import RenderTool
import imageio
import os
%% Cell type:code id: tags:
``` python
from IPython.display import clear_output
from IPython.core import display
display.display(display.HTML("<style>.container { width:95% !important; }</style>"))
%% Cell type:code id: tags:
``` python
def render_env(env_renderer):
    """Render the renderer's env and return the frame as a PIL image.

    Args:
        env_renderer: renderer object exposing
            ``render_env(show_rowcols=..., return_image=...)`` that returns an
            image array (a flatland ``RenderTool`` in this notebook).

    Returns:
        The ``PIL.Image`` built from the array returned by the renderer.
    """
    # NOTE: the former `ag0 = env_renderer.env.agents[0]` lookup only fed a
    # commented-out debug print and raised IndexError for envs without agents,
    # so it has been dropped.
    aImage = env_renderer.render_env(show_rowcols=True, return_image=True)
    return PIL.Image.fromarray(aImage)
%% Cell type:markdown id: tags:
# Experiments
This has been mostly changed to load envs using `importlib_resources`. It's getting them from the package `env_data.tests`.
%% Cell type:code id: tags:
``` python
#env_file = "Test_20_Level_0.pkl"
#env_file = "../../evaluation_visualization/round2/or-0827/Test_23/Level_1.pkl"
#env_file = "../../evaluation_visualization/round2/rl-0827/Test_23/Level_1.pkl"
%% Cell type:code id: tags:
``` python
#if os.path.exists("../env_data"):
# env_file = "../env_data/tests/Test_2_Level_0.pkl"
# env_file = "./env_data/tests/Test_2_Level_0.pkl"
%% Cell type:code id: tags:
``` python
# Package and resource name of the saved env/episode to load.
# Earlier choice, kept for reference:
#sPack, sResource = "env_data.tests", "Test_2_Level_0.pkl"
# Test_9_Level_1 is used because (per the commit message) it has malfunctions.
sPack, sResource = "env_data.tests", "Test_9_Level_1.pkl"
%% Cell type:code id: tags:
``` python
#env_file = "../../evaluation_visualization/round2/or-0827/Test_23/Level_1.pkl"
#env_file = "../../evaluation_visualization/round2/rl-0827/Test_23/Level_1.pkl"
%% Cell type:code id: tags:
``` python
import pickle
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.envs.malfunction_generators import malfunction_from_file, no_malfunction_generator
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
%% Cell type:code id: tags:
``` python
from flatland.envs.persistence import RailEnvPersister
%% Cell type:code id: tags:
``` python
env, env_dict = RailEnvPersister.load_new(sResource, load_from_package=sPack) # env_file)
oRT = RenderTool(env, show_debug=True)
aImg = oRT.render_env(show_rowcols=True, return_image=True, show_inactive_agents=True)
%% Cell type:code id: tags:
``` python
loAgs = env_dict["agents"]
lCols = "initial_direction,direction,initial_position,position".split(",")
pd.DataFrame([ [getattr(oAg, sCol) for sCol in lCols]
for oAg in loAgs], columns=lCols)
%% Cell type:code id: tags:
``` python
pd.DataFrame([ [getattr(oAg, sCol) for sCol in lCols]
for oAg in env.agents], columns=lCols)
%% Cell type:code id: tags:
``` python
pd.DataFrame([ vars(oAg) for oAg in env.agents])
%% Cell type:code id: tags:
``` python
# from
def get_agent_state(env):
    """Return one ``[row, col, direction, malfunction]`` list per agent.

    Args:
        env: object with an ``agents`` iterable; each agent must expose
            ``position`` (a 2-sequence or ``None``), ``direction`` and
            ``malfunction_data["malfunction"]``.

    Returns:
        A list with one 4-element list per agent, in ``env.agents`` order.
        Agents whose position is ``None`` are reported at ``(0, 0)``.
    """
    list_agents_state = []
    for oAg in env.agents:
        # the int cast is to avoid numpy types which may cause problems with msgpack
        # in env v2, agents may have position None, before starting
        if oAg.position is None:
            pos = (0, 0)
        else:
            # Only index into the position when there is one; without this
            # `else` the fallback above was overwritten and None[0] raised.
            pos = (int(oAg.position[0]), int(oAg.position[1]))
        # print("pos:", pos, type(pos[0]))
        list_agents_state.append(
            [*pos, int(oAg.direction), oAg.malfunction_data["malfunction"]])
    return list_agents_state
%% Cell type:code id: tags:
``` python
expert_actions = env_dict['actions']
episode_states = env_dict['episode']
%% Cell type:code id: tags:
``` python
pd.DataFrame([ vars(oAg) for oAg in env.agents])
%% Cell type:code id: tags:
``` python
# Replay the recorded episode: step the env with the saved expert actions,
# compare the resulting agent states with the saved episode states, and render
# one image per step.
# NOTE(review): this view of the cell has lost its indentation and appears to
# be missing lines: the bare `'image': pilImg,` / `'statuses': statuses` lines
# have no enclosing statement, the `for a in range(n_agents):` loops that
# follow `statuses = []` have no visible body that fills `statuses`, and some
# `if` branches look like they lost an `else:`. Confirm against the original
# notebook before running.
env_renderer = RenderTool(env, gl="PGL", show_debug=True)
n_agents = env.get_num_agents()
x_dim, y_dim = env.width, env.height
max_steps = env_dict['max_episode_steps']
action_dict = {}
frames = []
# log everything in original state
statuses = []
for a in range(n_agents):
pilImg = render_env(env_renderer)
'image': pilImg,
'statuses': statuses
step = 0
all_done = False
print("Processing episode steps:")
# Stop when the env reports all agents done or the step budget is exhausted.
while not all_done and step < max_steps:
print(step, end=", ")
# NOTE(review): `info` is read here but its first visible assignment is the
# env.step() call further down — presumably it was initialised earlier (e.g. by
# an env reset); TODO confirm.
for a in range(n_agents):
if info['action_required'][a]:
if step < len(expert_actions):
if a in expert_actions[step]:
action = expert_actions[step][a]
print("Step {}: agent {} needs action but not provided! only got {}".format(step, a, expert_actions[step]))
# Falls back to action 0 for this agent.
action = 0
action_dict.update({a: action})
dAct = expert_actions[step]
# NOTE(review): the env is stepped with the recorded actions directly;
# `action_dict` and `dAct` built above are not passed to env.step() here.
if step < len(expert_actions):
next_obs, all_rewards, done, info = env.step(expert_actions[step])
if True:
# Check that agent states match recorded states
if get_agent_state(env) == episode_states[step]:
#print("env:", get_agent_state(env))
#print("epi:", episode_states[step])
llAgSt = get_agent_state(env)
llEpSt = episode_states[step]
# Report each agent whose live state differs from the recorded state.
for iAg, (lAgSt, lEpSt) in enumerate(zip(llAgSt, llEpSt)):
if lAgSt != lEpSt:
print("Ag:", iAg, "Env: ", lAgSt, "Epi:", lEpSt, end = "; ")
# Force agent states from the recorded states
# (disabled: the `if False:` guard means this never runs)
if False:
for idx, agent in enumerate(env.agents):
rcPos = episode_states[step][idx][0:2]
#print(idx, rcPos)
# A recorded position of [0,0] is mapped back to None, mirroring the
# (0, 0) placeholder that get_agent_state() writes for a None position.
if rcPos == [0,0]:
agent.position = None
agent.position = (*rcPos,) # episode_states[step][idx][0], episode_states[step][idx][1]#, episode_states[step][idx][2]
agent.malfunction_data["malfunction"] = episode_states[step][idx][3]
agent.direction = int(episode_states[step][idx][2])
agent.old_direction = int(episode_states[step-1][idx][2])
agent.old_position = episode_states[step-1][idx][:2]
statuses = []
for a in range(n_agents):
pilImg = render_env(env_renderer)
'image': pilImg,
'statuses': statuses
#print("Replaying {}/{}".format(step, max_steps))
if done['__all__']:
all_done = True
# Shrink max_steps to the number of steps actually replayed; it is reused
# later as the upper bound of the frame slider.
max_steps = step + 1
step += 1
%% Cell type:code id: tags:
``` python
%% Cell type:code id: tags:
``` python
[ oAg.malfunction_data for oAg in env_dict["agents"] ][:3]
%% Cell type:code id: tags:
``` python
%% Cell type:code id: tags:
``` python
%% Cell type:code id: tags:
``` python
# Name the gif after the env resource, swapping only the file extension.
# (str.replace("pkl", "gif") would also rewrite a "pkl" occurring earlier in
# the name, and would leave a name with any other extension unchanged.)
sfImg = os.path.splitext(sResource)[0] + ".gif"
# Write the "image" entry of every collected frame into one animated gif.
imageio.mimsave(sfImg, [d["image"] for d in frames], subrectangles=True)
%% Cell type:code id: tags:
``` python
%% Cell type:code id: tags:
``` python
%% Cell type:code id: tags:
``` python
from ipywidgets import interact, interactive, fixed, interact_manual, Play
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from IPython.display import HTML
display.display(HTML('<link rel="stylesheet" href="//"/>'))
# Interactive viewer: a slider (and a Play widget) select which entry of
# `frames` is shown.
# NOTE(review): indentation is lost in this view and lines appear to be missing
# — `if True:` has no visible body and the `Play(` call is cut off before its
# arguments; the trailing `), 'value'), (slider, 'value'))` looks like the tail
# of a widget-linking call. Confirm against the original notebook.
def plot_func(frame_idx):
# The slider below is a FloatSlider, hence the int() cast before indexing.
frame = frames[int(frame_idx)]
if True:
# Slider range runs from 0 to max_steps in whole-number steps.
slider = widgets.FloatSlider(value=0, min=0, max=max_steps, step=1)
# Re-run plot_func with the slider value whenever the slider moves.
interact(plot_func, frame_idx = slider)
play = Play(
), 'value'), (slider, 'value'))
%% Cell type:code id: tags:
``` python
import numpy as np
%% Cell type:code id: tags:
``` python
g3Ep = np.array(episode_states)
np.sum(g3Ep[:,:,3] > 0)
%% Cell type:code id: tags:
``` python
plt.plot(np.sum(g3Ep[:,:,3]>0, axis=1))
plt.title(sResource + "\nmalfunctioning agents by time step")
%% Cell type:code id: tags:
``` python
%% Cell type:code id: tags:
``` python
g3Ep2 = np.array(env.cur_episode)
%% Cell type:code id: tags:
``` python
# Per recorded time step, count the agents whose column-4 value in
# env.cur_episode is 1 resp. 0.
# NOTE(review): column 4 is presumably the agent status, with 1 = active and
# 0 = ready to depart as the labels suggest — confirm against RailEnv.
plt.step(range(len(g3Ep2)), np.sum(g3Ep2[:,:,4]==1, axis=1), label="Active")
plt.step(range(len(g3Ep2)), np.sum(g3Ep2[:,:,4]==0, axis=1), label="Ready to depart")
plt.title("env: "+ sResource +"\nActive Agents by timestep")
# The label= arguments above are only shown once a legend is drawn.
plt.legend()
%% Cell type:code id: tags:
``` python
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment