Skip to content
Snippets Groups Projects
Commit b6c4bb58 authored by hagrid67's avatar hagrid67
Browse files

added basic save episode functionality to rail_env

and hacked custom_observation_example.py to save an env with an episode
parent 77511dfc
No related branches found
No related tags found
No related merge requests found
......@@ -206,7 +206,8 @@ env = RailEnv(width=10,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
schedule_generator=complex_schedule_generator(),
number_of_agents=3,
obs_builder_object=CustomObsBuilder)
obs_builder_object=CustomObsBuilder,
save_episodes=True)
obs = env.reset()
env_renderer = RenderTool(env, gl="PILSVG")
......@@ -222,4 +223,8 @@ for step in range(100):
obs, all_rewards, done, _ = env.step(action_dict)
print("Rewards: ", all_rewards, " [done=", done, "]")
env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
time.sleep(0.5)
time.sleep(0.01)
sFilename = "saved_episode_{:}x{:}.mpk".format(*env.rail.grid.shape)
env.save(sFilename)
......@@ -108,7 +108,8 @@ class RailEnv(Environment):
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2),
max_episode_steps=None,
stochastic_data=None
stochastic_data=None,
save_episodes=False
):
"""
Environment init.
......@@ -201,6 +202,11 @@ class RailEnv(Environment):
self.valid_positions = None
# save episode timesteps ie agent positions, orientations. (not yet actions / observations)
self.save_episodes = save_episodes
self.episodes = []
self.cur_episode = []
# no more agent_handles
def get_agent_handles(self):
return range(self.get_num_agents())
......@@ -291,7 +297,7 @@ class RailEnv(Environment):
# If counter has come to zero --> Agent has malfunction
# set next malfunction time and duration of current malfunction
if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction'] and \
agent.malfunction_data['next_malfunction'] <= 0:
agent.malfunction_data['next_malfunction'] <= 0:
# Increase number of malfunctions
agent.malfunction_data['nr_malfunctions'] += 1
......@@ -310,6 +316,10 @@ class RailEnv(Environment):
# TODO refactor to decrease length of this method!
def step(self, action_dict_):
if self.save_episodes:
self.record_timestep()
self._elapsed_steps += 1
# Reset the step rewards
......@@ -364,7 +374,7 @@ class RailEnv(Environment):
self.rewards_dict[i_agent] += self.stop_penalty
if not agent.moving and not (
action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
# Allow agent to start with any forward or direction action
agent.moving = True
self.rewards_dict[i_agent] += self.start_penalty
......@@ -435,8 +445,8 @@ class RailEnv(Environment):
# so we only have to check cell_free now!
# cell and transition validity was checked when we stored transition_action_on_cellexit!
cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
agent.speed_data['transition_action_on_cellexit'], agent)
cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)
if cell_free:
agent.position = new_position
......@@ -475,6 +485,15 @@ class RailEnv(Environment):
return self._get_observations(), self.rewards_dict, self.dones, info_dict
def record_timestep(self):
''' Record the positions and orientations of all agents in memory.
'''
list_timestep = []
for i_agent in range(self.get_num_agents()):
agent = self.agents[i_agent]
list_timestep.append([*agent.position, int(agent.direction)])
self.cur_episode.append(list_timestep)
def _check_action_on_agent(self, action, agent):
# compute number of possible transitions in the current
......@@ -540,13 +559,16 @@ class RailEnv(Environment):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
episode_data = [self.cur_episode]
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True)
msgpack.packb(episode_data, use_bin_type=True)
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data}
"agents": agent_data,
"episodes": episode_data}
return msgpack.packb(msg_data, use_bin_type=True)
def get_agent_state_msg(self):
......@@ -585,9 +607,11 @@ class RailEnv(Environment):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
episode_data = [self.cur_episode]
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True)
msgpack.packb(episode_data, use_bin_type=True)
if hasattr(self.obs_builder, 'distance_map'):
distance_map_data = self.obs_builder.distance_map
msgpack.packb(distance_map_data, use_bin_type=True)
......@@ -595,12 +619,14 @@ class RailEnv(Environment):
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data,
"distance_maps": distance_map_data}
"distance_maps": distance_map_data,
"episodes": episode_data}
else:
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data}
"agents": agent_data,
"episodes": episode_data}
return msgpack.packb(msg_data, use_bin_type=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment