Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • flatland/flatland
  • stefan_otte/flatland
  • jiaodaxiaozi/flatland
  • sfwatergit/flatland
  • utozx126/flatland
  • ChenKuanSun/flatland
  • ashivani/flatland
  • minhhoa/flatland
  • pranjal_dhole/flatland
  • darthgera123/flatland
  • rivesunder/flatland
  • thomaslecat/flatland
  • joel_joseph/flatland
  • kchour/flatland
  • alex_zharichenko/flatland
  • yoogottamk/flatland
  • troye_fang/flatland
  • elrichgro/flatland
  • jun_jin/flatland
  • nimishsantosh107/flatland
20 results
Show changes
Commits on Source (1)
...@@ -206,7 +206,8 @@ env = RailEnv(width=10, ...@@ -206,7 +206,8 @@ env = RailEnv(width=10,
rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0), rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
schedule_generator=complex_schedule_generator(), schedule_generator=complex_schedule_generator(),
number_of_agents=3, number_of_agents=3,
obs_builder_object=CustomObsBuilder) obs_builder_object=CustomObsBuilder,
save_episodes=True)
obs = env.reset() obs = env.reset()
env_renderer = RenderTool(env, gl="PILSVG") env_renderer = RenderTool(env, gl="PILSVG")
...@@ -222,4 +223,8 @@ for step in range(100): ...@@ -222,4 +223,8 @@ for step in range(100):
obs, all_rewards, done, _ = env.step(action_dict) obs, all_rewards, done, _ = env.step(action_dict)
print("Rewards: ", all_rewards, " [done=", done, "]") print("Rewards: ", all_rewards, " [done=", done, "]")
env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False) env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
time.sleep(0.5) time.sleep(0.01)
sFilename = "saved_episode_{:}x{:}.mpk".format(*env.rail.grid.shape)
env.save(sFilename)
...@@ -108,7 +108,8 @@ class RailEnv(Environment): ...@@ -108,7 +108,8 @@ class RailEnv(Environment):
number_of_agents=1, number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2), obs_builder_object=TreeObsForRailEnv(max_depth=2),
max_episode_steps=None, max_episode_steps=None,
stochastic_data=None stochastic_data=None,
save_episodes=False
): ):
""" """
Environment init. Environment init.
...@@ -201,6 +202,11 @@ class RailEnv(Environment): ...@@ -201,6 +202,11 @@ class RailEnv(Environment):
self.valid_positions = None self.valid_positions = None
# Save episode timesteps, i.e. agent positions and orientations (not yet actions / observations).
self.save_episodes = save_episodes
self.episodes = []
self.cur_episode = []
# no more agent_handles # no more agent_handles
def get_agent_handles(self): def get_agent_handles(self):
return range(self.get_num_agents()) return range(self.get_num_agents())
...@@ -291,7 +297,7 @@ class RailEnv(Environment): ...@@ -291,7 +297,7 @@ class RailEnv(Environment):
# If counter has come to zero --> Agent has malfunction # If counter has come to zero --> Agent has malfunction
# set next malfunction time and duration of current malfunction # set next malfunction time and duration of current malfunction
if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction'] and \ if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction'] and \
agent.malfunction_data['next_malfunction'] <= 0: agent.malfunction_data['next_malfunction'] <= 0:
# Increase number of malfunctions # Increase number of malfunctions
agent.malfunction_data['nr_malfunctions'] += 1 agent.malfunction_data['nr_malfunctions'] += 1
...@@ -310,6 +316,10 @@ class RailEnv(Environment): ...@@ -310,6 +316,10 @@ class RailEnv(Environment):
# TODO refactor to decrease length of this method! # TODO refactor to decrease length of this method!
def step(self, action_dict_): def step(self, action_dict_):
if self.save_episodes:
self.record_timestep()
self._elapsed_steps += 1 self._elapsed_steps += 1
# Reset the step rewards # Reset the step rewards
...@@ -364,7 +374,7 @@ class RailEnv(Environment): ...@@ -364,7 +374,7 @@ class RailEnv(Environment):
self.rewards_dict[i_agent] += self.stop_penalty self.rewards_dict[i_agent] += self.stop_penalty
if not agent.moving and not ( if not agent.moving and not (
action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING): action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
# Allow agent to start with any forward or direction action # Allow agent to start with any forward or direction action
agent.moving = True agent.moving = True
self.rewards_dict[i_agent] += self.start_penalty self.rewards_dict[i_agent] += self.start_penalty
...@@ -435,8 +445,8 @@ class RailEnv(Environment): ...@@ -435,8 +445,8 @@ class RailEnv(Environment):
# so we only have to check cell_free now! # so we only have to check cell_free now!
# cell and transition validity was checked when we stored transition_action_on_cellexit! # cell and transition validity was checked when we stored transition_action_on_cellexit!
cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
agent.speed_data['transition_action_on_cellexit'], agent) self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)
if cell_free: if cell_free:
agent.position = new_position agent.position = new_position
...@@ -475,6 +485,15 @@ class RailEnv(Environment): ...@@ -475,6 +485,15 @@ class RailEnv(Environment):
return self._get_observations(), self.rewards_dict, self.dones, info_dict return self._get_observations(), self.rewards_dict, self.dones, info_dict
def record_timestep(self):
    """Record the positions and orientations of all agents in memory.

    For every agent index in ``range(self.get_num_agents())`` one entry is
    built from the unpacked coordinates of ``agent.position`` followed by
    ``agent.direction``. The list of entries for this timestep is then
    appended to ``self.cur_episode``.

    All stored values are converted to builtin ``int``. The recorded
    episode is later handed to ``msgpack.packb``, which only serializes
    builtin scalar types; the original code already cast the direction for
    that reason, and the position coordinates get the same treatment here.
    NOTE(review): assumes position coordinates are integral grid indices --
    confirm no caller stores fractional positions.

    Returns:
        None. The method only mutates ``self.cur_episode``.
    """
    timestep_states = []
    for i_agent in range(self.get_num_agents()):
        agent = self.agents[i_agent]
        # Normalize every coordinate to a plain int so that non-builtin
        # integer types (e.g. numpy scalars) never reach the serializer.
        coordinates = [int(coordinate) for coordinate in agent.position]
        timestep_states.append([*coordinates, int(agent.direction)])
    self.cur_episode.append(timestep_states)
def _check_action_on_agent(self, action, agent): def _check_action_on_agent(self, action, agent):
# compute number of possible transitions in the current # compute number of possible transitions in the current
...@@ -540,13 +559,16 @@ class RailEnv(Environment): ...@@ -540,13 +559,16 @@ class RailEnv(Environment):
grid_data = self.rail.grid.tolist() grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static] agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents] agent_data = [agent.to_list() for agent in self.agents]
episode_data = [self.cur_episode]
msgpack.packb(grid_data, use_bin_type=True) msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True) msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True) msgpack.packb(agent_static_data, use_bin_type=True)
msgpack.packb(episode_data, use_bin_type=True)
msg_data = { msg_data = {
"grid": grid_data, "grid": grid_data,
"agents_static": agent_static_data, "agents_static": agent_static_data,
"agents": agent_data} "agents": agent_data,
"episodes": episode_data}
return msgpack.packb(msg_data, use_bin_type=True) return msgpack.packb(msg_data, use_bin_type=True)
def get_agent_state_msg(self): def get_agent_state_msg(self):
...@@ -585,9 +607,11 @@ class RailEnv(Environment): ...@@ -585,9 +607,11 @@ class RailEnv(Environment):
grid_data = self.rail.grid.tolist() grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static] agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents] agent_data = [agent.to_list() for agent in self.agents]
episode_data = [self.cur_episode]
msgpack.packb(grid_data, use_bin_type=True) msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True) msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True) msgpack.packb(agent_static_data, use_bin_type=True)
msgpack.packb(episode_data, use_bin_type=True)
if hasattr(self.obs_builder, 'distance_map'): if hasattr(self.obs_builder, 'distance_map'):
distance_map_data = self.obs_builder.distance_map distance_map_data = self.obs_builder.distance_map
msgpack.packb(distance_map_data, use_bin_type=True) msgpack.packb(distance_map_data, use_bin_type=True)
...@@ -595,12 +619,14 @@ class RailEnv(Environment): ...@@ -595,12 +619,14 @@ class RailEnv(Environment):
"grid": grid_data, "grid": grid_data,
"agents_static": agent_static_data, "agents_static": agent_static_data,
"agents": agent_data, "agents": agent_data,
"distance_maps": distance_map_data} "distance_maps": distance_map_data,
"episodes": episode_data}
else: else:
msg_data = { msg_data = {
"grid": grid_data, "grid": grid_data,
"agents_static": agent_static_data, "agents_static": agent_static_data,
"agents": agent_data} "agents": agent_data,
"episodes": episode_data}
return msgpack.packb(msg_data, use_bin_type=True) return msgpack.packb(msg_data, use_bin_type=True)
......