Skip to content
Snippets Groups Projects
Commit 8300a78d authored by hagrid67's avatar hagrid67
Browse files

adding Test_9_Level_1.pkl as test env in render-episode notebook - this one has malfunctions

parent 433f4a84
No related branches found
No related tags found
No related merge requests found
File added
%% Cell type:markdown id: tags:
# Render Episode
Render a stored episode. Env file needs to have "episode" and "action" keys.
- checks that the agent actions match the saved steps (row, col, dir)
- creates a moving gif file of the episode
- displays the episode in a widget with a slider for the time steps.
%% Cell type:markdown id: tags:
# Setup
%% Cell type:code id: tags:
``` python
#!apt -qq install graphviz libgraphviz-dev pkg-config
#!pip install -qq git+https://gitlab.aicrowd.com/flatland/flatland.git
```
%% Cell type:code id: tags:
``` python
%load_ext autoreload
%autoreload 2
```
%% Cell type:code id: tags:
``` python
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
```
%% Cell type:code id: tags:
``` python
import PIL
from flatland.utils.rendertools import RenderTool
import imageio
import os
```
%% Cell type:code id: tags:
``` python
from IPython.display import clear_output
from IPython.core import display
display.display(display.HTML("<style>.container { width:95% !important; }</style>"))
```
%% Cell type:code id: tags:
``` python
def render_env(env_renderer):
ag0= env_renderer.env.agents[0]
#print("render_env ag0: ",ag0.position, ag0.direction)
aImage = env_renderer.render_env(show_rowcols=True, return_image=True)
pil_image = PIL.Image.fromarray(aImage)
return pil_image
```
%% Cell type:markdown id: tags:
# Experiments
This has been mostly changed to load envs using `importlib_resources`. It's getting them from the package "envdata.tests`
%% Cell type:code id: tags:
``` python
# ENV FILE PATH
#env_file = "Test_20_Level_0.pkl"
#env_file = "../../evaluation_visualization/round2/or-0827/Test_23/Level_1.pkl"
#env_file = "../../evaluation_visualization/round2/rl-0827/Test_23/Level_1.pkl"
```
%% Cell type:code id: tags:
``` python
#if os.path.exists("../env_data"):
# env_file = "../env_data/tests/Test_2_Level_0.pkl"
#else:
# env_file = "./env_data/tests/Test_2_Level_0.pkl"
```
%% Cell type:code id: tags:
``` python
sPack, sResource = "env_data.tests", "Test_2_Level_0.pkl"
#sPack, sResource = "env_data.tests", "Test_2_Level_0.pkl"
sPack, sResource = "env_data.tests", "Test_9_Level_1.pkl"
```
%% Cell type:code id: tags:
``` python
#env_file = "../../evaluation_visualization/round2/or-0827/Test_23/Level_1.pkl"
#env_file = "../../evaluation_visualization/round2/rl-0827/Test_23/Level_1.pkl"
```
%% Cell type:code id: tags:
``` python
import pickle
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.envs.malfunction_generators import malfunction_from_file, no_malfunction_generator
from flatland.envs.rail_generators import rail_from_file
from flatland.envs.schedule_generators import schedule_from_file
```
%% Cell type:code id: tags:
``` python
from flatland.envs.persistence import RailEnvPersister
```
%% Cell type:code id: tags:
``` python
env, env_dict = RailEnvPersister.load_new(sResource, load_from_package=sPack) # env_file)
env.reset(random_seed=1001)
oRT = RenderTool(env, show_debug=True)
aImg = oRT.render_env(show_rowcols=True, return_image=True, show_inactive_agents=True)
print(env._max_episode_steps)
PIL.Image.fromarray(aImg)
```
%% Cell type:code id: tags:
``` python
loAgs = env_dict["agents"]
lCols = "initial_direction,direction,initial_position,position".split(",")
pd.DataFrame([ [getattr(oAg, sCol) for sCol in lCols]
for oAg in loAgs], columns=lCols)
```
%% Cell type:code id: tags:
``` python
pd.DataFrame([ [getattr(oAg, sCol) for sCol in lCols]
for oAg in env.agents], columns=lCols)
```
%% Cell type:code id: tags:
``` python
pd.DataFrame([ vars(oAg) for oAg in env.agents])
```
%% Cell type:code id: tags:
``` python
# from persistence.py
def get_agent_state(env):
list_agents_state = []
for iAg, oAg in enumerate(env.agents):
# the int cast is to avoid numpy types which may cause problems with msgpack
# in env v2, agents may have position None, before starting
if oAg.position is None:
pos = (0, 0)
else:
pos = (int(oAg.position[0]), int(oAg.position[1]))
# print("pos:", pos, type(pos[0]))
list_agents_state.append(
[*pos, int(oAg.direction), oAg.malfunction_data["malfunction"]])
return list_agents_state
```
%% Cell type:code id: tags:
``` python
expert_actions = env_dict['actions']
episode_states = env_dict['episode']
```
%% Cell type:code id: tags:
``` python
pd.DataFrame([ vars(oAg) for oAg in env.agents])
```
%% Cell type:code id: tags:
``` python
env_renderer = RenderTool(env, gl="PGL", show_debug=True)
n_agents = env.get_num_agents()
x_dim, y_dim = env.width, env.height
max_steps = env_dict['max_episode_steps']
action_dict = {}
frames = []
# log everything in original state
statuses = []
for a in range(n_agents):
statuses.append(env.agents[a].status)
pilImg = render_env(env_renderer)
frames.append({
'image': pilImg,
'statuses': statuses
})
step = 0
all_done = False
print("Processing episode steps:")
while not all_done and step < max_steps:
print(step, end=", ")
"""
for a in range(n_agents):
if info['action_required'][a]:
if step < len(expert_actions):
if a in expert_actions[step]:
action = expert_actions[step][a]
else:
print("Step {}: agent {} needs action but not provided! only got {}".format(step, a, expert_actions[step]))
else:
action = 0
action_dict.update({a: action})
"""
dAct = expert_actions[step]
#print(dAct)
if step < len(expert_actions):
next_obs, all_rewards, done, info = env.step(expert_actions[step])
if True:
# Check that agent states match recorded states
if get_agent_state(env) == episode_states[step]:
pass
else:
print("MISMATCH")
#print("env:", get_agent_state(env))
#print("epi:", episode_states[step])
llAgSt = get_agent_state(env)
llEpSt = episode_states[step]
for iAg, (lAgSt, lEpSt) in enumerate(zip(llAgSt, llEpSt)):
if lAgSt != lEpSt:
print("Ag:", iAg, "Env: ", lAgSt, "Epi:", lEpSt, end = "; ")
print("------")
# Force agent states from the recorded states
if False:
for idx, agent in enumerate(env.agents):
#print(episode_states[step][idx])
rcPos = episode_states[step][idx][0:2]
#print(idx, rcPos)
if rcPos == [0,0]:
agent.position = None
else:
agent.position = (*rcPos,) # episode_states[step][idx][0], episode_states[step][idx][1]#, episode_states[step][idx][2]
agent.malfunction_data["malfunction"] = episode_states[step][idx][3]
agent.direction = int(episode_states[step][idx][2])
agent.old_direction = int(episode_states[step-1][idx][2])
agent.old_position = episode_states[step-1][idx][:2]
statuses = []
for a in range(n_agents):
statuses.append(env.agents[a].status)
#clear_output(wait=True)
pilImg = render_env(env_renderer)
frames.append({
'image': pilImg,
'statuses': statuses
})
#print("Replaying {}/{}".format(step, max_steps))
if done['__all__']:
all_done = True
max_steps = step + 1
print("done")
step += 1
```
%% Cell type:code id: tags:
``` python
env.agents[0]
```
%% Cell type:code id: tags:
``` python
[ oAg.malfunction_data for oAg in env_dict["agents"] ][:3]
```
%% Cell type:code id: tags:
``` python
env_dict["malfunction"]
```
%% Cell type:code id: tags:
``` python
env._max_episode_steps
```
%% Cell type:code id: tags:
``` python
sfImg = sResource.replace("pkl", "gif")
imageio.mimsave(sfImg, [d["image"] for d in frames], subrectangles=True)
```
%% Cell type:code id: tags:
``` python
sfImg
```
%% Cell type:code id: tags:
``` python
display.Image(sfImg)
```
%% Cell type:code id: tags:
``` python
from ipywidgets import interact, interactive, fixed, interact_manual, Play
import ipywidgets as widgets
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from IPython.display import HTML
display.display(HTML('<link rel="stylesheet" href="//stackpath.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.min.css"/>'))
def plot_func(frame_idx):
frame = frames[int(frame_idx)]
display.display(frame['image'])
#print(frame['statuses'])
if True:
slider = widgets.FloatSlider(value=0, min=0, max=max_steps, step=1)
interact(plot_func, frame_idx = slider)
play = Play(
max=max_steps,
value=0,
step=1,
interval=250
)
widgets.link((play, 'value'), (slider, 'value'))
widgets.VBox([play])
```
%% Cell type:code id: tags:
``` python
import numpy as np
```
%% Cell type:code id: tags:
``` python
g3Ep = np.array(episode_states)
np.sum(g3Ep[:,:,3] > 0)
```
%% Cell type:code id: tags:
``` python
plt.plot(np.sum(g3Ep[:,:,3]>0, axis=1))
plt.title(sResource + "\nmalfunctioning agents by time step")
```
%% Cell type:code id: tags:
``` python
g3Ep.shape
```
%% Cell type:code id: tags:
``` python
g3Ep2 = np.array(env.cur_episode)
g3Ep2.shape
```
%% Cell type:code id: tags:
``` python
plt.step(range(len(g3Ep2)), np.sum(g3Ep2[:,:,4]==1, axis=1), label="Active")
plt.step(range(len(g3Ep2)), np.sum(g3Ep2[:,:,4]==0, axis=1), label="Ready to depart")
plt.title("env: "+ sResource +"\nActive Agents by timestep")
plt.legend()
```
%% Cell type:code id: tags:
``` python
```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment