Compare revisions
Commits on Source (1)
@@ -206,7 +206,8 @@ env = RailEnv(width=10,
               rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
               schedule_generator=complex_schedule_generator(),
               number_of_agents=3,
-              obs_builder_object=CustomObsBuilder)
+              obs_builder_object=CustomObsBuilder,
+              save_episodes=True)
 obs = env.reset()
 env_renderer = RenderTool(env, gl="PILSVG")
@@ -222,4 +223,8 @@ for step in range(100):
     obs, all_rewards, done, _ = env.step(action_dict)
     print("Rewards: ", all_rewards, " [done=", done, "]")
     env_renderer.render_env(show=True, frames=True, show_observations=True, show_predictions=False)
-    time.sleep(0.5)
+    time.sleep(0.01)
+
+sFilename = "saved_episode_{:}x{:}.mpk".format(*env.rail.grid.shape)
+env.save(sFilename)
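A saved episode file is plain msgpack, so it can be inspected without flatland. Here is a minimal sketch of reading it back, assuming env.save() writes the get_full_state_msg() payload shown in the RailEnv hunks below, and assuming the 10x10 grid from this example (the filename matches the format string above):

import msgpack

# Read back the file written by env.save(); raw=False pairs with the
# use_bin_type=True used on the packing side, so keys come back as str.
with open("saved_episode_10x10.mpk", "rb") as f:
    state = msgpack.unpackb(f.read(), raw=False)

episode = state["episodes"][0]        # the single recorded episode
print("timesteps recorded:", len(episode))
print("first timestep:", episode[0])  # one [row, col, direction] per agent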
@@ -108,7 +108,8 @@ class RailEnv(Environment):
                  number_of_agents=1,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2),
                  max_episode_steps=None,
-                 stochastic_data=None
+                 stochastic_data=None,
+                 save_episodes=False
                  ):
         """
         Environment init.
@@ -201,6 +202,11 @@ class RailEnv(Environment):
         self.valid_positions = None

+        # save episode timesteps, i.e. agent positions and orientations (not yet actions / observations)
+        self.save_episodes = save_episodes
+        self.episodes = []
+        self.cur_episode = []
+
     # no more agent_handles
     def get_agent_handles(self):
         return range(self.get_num_agents())
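Note: as of this commit, cur_episode accumulates timesteps but is never appended to episodes or reset between episodes; episodes appears to be reserved for recording multiple episodes later.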
@@ -291,7 +297,7 @@
         # If counter has come to zero --> Agent has malfunction
         # set next malfunction time and duration of current malfunction
         if agent.malfunction_data['malfunction_rate'] > 0 >= agent.malfunction_data['malfunction'] and \
-            agent.malfunction_data['next_malfunction'] <= 0:
+                agent.malfunction_data['next_malfunction'] <= 0:
             # Increase number of malfunctions
             agent.malfunction_data['nr_malfunctions'] += 1
@@ -310,6 +316,10 @@
     # TODO refactor to decrease length of this method!
     def step(self, action_dict_):
+
+        if self.save_episodes:
+            self.record_timestep()
+
         self._elapsed_steps += 1

         # Reset the step rewards
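Note that record_timestep() runs at the top of step(), before any actions are applied, so each entry in cur_episode is the state the agents were in when the step began; the final post-step state is only captured if step() is called once more.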
@@ -364,7 +374,7 @@
                     self.rewards_dict[i_agent] += self.stop_penalty

                 if not agent.moving and not (
-                    action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
+                        action == RailEnvActions.DO_NOTHING or action == RailEnvActions.STOP_MOVING):
                     # Allow agent to start with any forward or direction action
                     agent.moving = True
                     self.rewards_dict[i_agent] += self.start_penalty
@@ -435,8 +445,8 @@
                     # so we only have to check cell_free now!
                     # cell and transition validity was checked when we stored transition_action_on_cellexit!
-                    cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent(
-                        agent.speed_data['transition_action_on_cellexit'], agent)
+                    cell_free, new_cell_valid, new_direction, new_position, transition_valid = \
+                        self._check_action_on_agent(agent.speed_data['transition_action_on_cellexit'], agent)

                     if cell_free:
                         agent.position = new_position
@@ -475,6 +485,15 @@
         return self._get_observations(), self.rewards_dict, self.dones, info_dict

+    def record_timestep(self):
+        ''' Record the positions and orientations of all agents in memory.
+        '''
+        list_timestep = []
+        for i_agent in range(self.get_num_agents()):
+            agent = self.agents[i_agent]
+            list_timestep.append([*agent.position, int(agent.direction)])
+        self.cur_episode.append(list_timestep)
+
     def _check_action_on_agent(self, action, agent):

         # compute number of possible transitions in the current
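To make the recorded layout concrete, here is a self-contained sketch that mirrors the loop in record_timestep() using stand-in agents (FakeAgent and the sample positions are invented for illustration):

# Stand-ins for flatland agents, with just the fields record_timestep() reads.
class FakeAgent:
    def __init__(self, position, direction):
        self.position = position    # (row, col)
        self.direction = direction  # heading as an int in 0..3

agents = [FakeAgent((3, 5), 1), FakeAgent((7, 2), 0)]

# Same construction as record_timestep(): one [row, col, direction] per agent.
list_timestep = [[*agent.position, int(agent.direction)] for agent in agents]
print(list_timestep)  # [[3, 5, 1], [7, 2, 0]]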
@@ -540,13 +559,16 @@
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
         agent_data = [agent.to_list() for agent in self.agents]
+        episode_data = [self.cur_episode]

         msgpack.packb(grid_data, use_bin_type=True)
         msgpack.packb(agent_data, use_bin_type=True)
         msgpack.packb(agent_static_data, use_bin_type=True)
+        msgpack.packb(episode_data, use_bin_type=True)

         msg_data = {
             "grid": grid_data,
             "agents_static": agent_static_data,
-            "agents": agent_data}
+            "agents": agent_data,
+            "episodes": episode_data}
         return msgpack.packb(msg_data, use_bin_type=True)

     def get_agent_state_msg(self):
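Two notes on this hunk: episode_data = [self.cur_episode] wraps the current episode in a list, so the "episodes" field is a list of episodes that here always contains exactly one entry; and the bare msgpack.packb(...) calls discard their results, matching the pre-existing pattern in this method, so they appear to serve only as serializability smoke checks.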
@@ -585,9 +607,11 @@
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
         agent_data = [agent.to_list() for agent in self.agents]
+        episode_data = [self.cur_episode]

         msgpack.packb(grid_data, use_bin_type=True)
         msgpack.packb(agent_data, use_bin_type=True)
         msgpack.packb(agent_static_data, use_bin_type=True)
+        msgpack.packb(episode_data, use_bin_type=True)

         if hasattr(self.obs_builder, 'distance_map'):
             distance_map_data = self.obs_builder.distance_map
             msgpack.packb(distance_map_data, use_bin_type=True)
@@ -595,12 +619,14 @@
                 "grid": grid_data,
                 "agents_static": agent_static_data,
                 "agents": agent_data,
-                "distance_maps": distance_map_data}
+                "distance_maps": distance_map_data,
+                "episodes": episode_data}
         else:
             msg_data = {
                 "grid": grid_data,
                 "agents_static": agent_static_data,
-                "agents": agent_data}
+                "agents": agent_data,
+                "episodes": episode_data}

         return msgpack.packb(msg_data, use_bin_type=True)