diff --git a/examples/play_model.py b/examples/play_model.py
index 68530e6f694a3a43aa2f76307f83300df9ac7f6e..04aa55c1e5a713cf7c49a9ac87bc1ba6babd4406 100644
--- a/examples/play_model.py
+++ b/examples/play_model.py
@@ -10,7 +10,6 @@ import matplotlib.pyplot as plt
 import time
 
 
-
 class Player(object):
     def __init__(self, env):
         self.env = env
@@ -35,7 +34,9 @@ class Player(object):
         self.tStart = time.time()
 
         # Reset environment
-        self.obs = self.env.reset()
+        #self.obs = self.env.reset()
+        self.env.obs_builder.reset()
+        self.obs = self.env._get_observations()
         for a in range(self.env.number_of_agents):
             norm = max(1, max_lt(self.obs[a], np.inf))
             self.obs[a] = np.clip(np.array(self.obs[a]) / norm, -1, 1)
@@ -148,7 +149,7 @@ def main(render=True, delay=0.0):
     for trials in range(1, n_trials + 1):
 
         # Reset environment
-        obs = env.reset()
+        # obs = env.reset()
 
         for a in range(env.number_of_agents):
             norm = max(1, max_lt(obs[a],np.inf))
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 0e150c62a2c8ef0e5d5623abf1407c9570606e3b..6b0c64a17cd22b0d26073755e9d26509d53334cf 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -464,11 +464,90 @@ class RailEnv(Environment):
+        self.valid_positions = None
         self.reset()
         self.num_resets = 0
 
     def get_agent_handles(self):
         return self.agents_handles
 
-    def reset(self):
-        self.rail = self.rail_generator(self.width, self.height, self.num_resets)
+    def fill_valid_positions(self):
+        """Store every (row, col) cell whose transition value is > 0 in self.valid_positions."""
+        self.valid_positions = valid_positions = []
+        for r in range(self.height):
+            for c in range(self.width):
+                if self.rail.get_transitions((r, c)) > 0:
+                    valid_positions.append((r, c))
+
+    def check_agent_lists(self):
+        """Assert that the handle, position and direction lists each hold number_of_agents entries."""
+        for lAgents, name in zip(
+                [self.agents_handles, self.agents_position, self.agents_direction],
+                ["handles", "positions", "directions"]):
+            assert self.number_of_agents == len(lAgents), "Inconsistent agent list:"+name
+
+    def _valid_starting_directions(self, rcPos, rcTarget):
+        """List each direction at rcPos that has a move from which _path_exists reaches rcTarget."""
+        valid_movements = []
+        for direction in range(4):
+            moves = self.rail.get_transitions((rcPos[0], rcPos[1], direction))
+            for move_index in range(4):
+                if moves[move_index]:
+                    valid_movements.append((direction, move_index))
+
+        valid_starting_directions = []
+        for m in valid_movements:
+            new_position = self._new_position(rcPos, m[1])
+            if m[0] not in valid_starting_directions and \
+                    self._path_exists(new_position, m[0], rcTarget):
+                valid_starting_directions.append(m[0])
+        return valid_starting_directions
+
+    def check_agent_locdirpath(self, iAgent):
+        """Return True if agent iAgent has at least one starting direction with a path to its target."""
+        return len(self._valid_starting_directions(
+            self.agents_position[iAgent], self.agents_target[iAgent])) > 0
+
+    def pick_agent_direction(self, rcPos, rcTarget):
+        """Return a randomly chosen valid starting direction at rcPos, or None if there is none."""
+        valid_starting_directions = self._valid_starting_directions(rcPos, rcTarget)
+        if len(valid_starting_directions) == 0:
+            return None
+        return valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]
+
+    def add_agent(self, rcPos=None, rcTarget=None, iDir=None):
+        """Append one agent to the per-agent lists and return its index.
+
+        rcPos defaults to a random entry of self.valid_positions, rcTarget
+        defaults to rcPos and iDir defaults to 0.
+        """
+        self.check_agent_lists()
+
+        if rcPos is None:
+            rcPos = self.valid_positions[np.random.choice(len(self.valid_positions))]
+        if rcTarget is None:
+            rcTarget = rcPos
+        if iDir is None:
+            iDir = 0
+
+        iAgent = self.number_of_agents
+        self.number_of_agents += 1
+
+        self.agents_position.append(rcPos)
+        self.agents_handles.append(max(self.agents_handles + [-1]) + 1)  # max(handles) + 1, starting at 0
+        self.agents_direction.append(iDir)
+        self.agents_target.append(rcTarget)
+
+        self.check_agent_lists()
+        return iAgent
+
+    def reset(self, regen_rail=True, replace_agents=True):
+        """Reset the environment.
+
+        regen_rail: build a new rail with rail_generator (always done when self.rail is None).
+        replace_agents: re-sample agent positions, targets and starting directions.
+        """
+        if regen_rail or self.rail is None:
+            self.rail = self.rail_generator(self.width, self.height, self.num_resets)
+            self.fill_valid_positions()
+
         self.num_resets += 1
 
         self.dones = {"__all__": False}
@@ -477,50 +556,33 @@ class RailEnv(Environment):
 
         # Use a TreeObsForRailEnv to compute distance maps to each agent's target, to sample initial
         # agent's orientations that allow a valid solution.
-        re_generate = True
-        while re_generate:
-            valid_positions = []
-            for r in range(self.height):
-                for c in range(self.width):
-                    if self.rail.get_transitions((r, c)) > 0:
-                        valid_positions.append((r, c))
-
-            # self.agents_position = random.sample(valid_positions,
-            #                                      self.number_of_agents)
-            self.agents_position = [
-                valid_positions[i] for i in
-                np.random.choice(len(valid_positions), self.number_of_agents)]
-            self.agents_target = [
-                valid_positions[i] for i in
-                np.random.choice(len(valid_positions), self.number_of_agents)]
-
-            # agents_direction must be a direction for which a solution is
-            # guaranteed.
-            self.agents_direction = [0] * self.number_of_agents
-            re_generate = False
-            for i in range(self.number_of_agents):
-                valid_movements = []
-                for direction in range(4):
-                    position = self.agents_position[i]
-                    moves = self.rail.get_transitions((position[0], position[1], direction))
-                    for move_index in range(4):
-                        if moves[move_index]:
-                            valid_movements.append((direction, move_index))
-
-                valid_starting_directions = []
-                for m in valid_movements:
-                    new_position = self._new_position(self.agents_position[i],
-                                                      m[1])
-                    if m[0] not in valid_starting_directions and \
-                            self._path_exists(new_position, m[0],
-                                              self.agents_target[i]):
-                        valid_starting_directions.append(m[0])
-
-                if len(valid_starting_directions) == 0:
-                    re_generate = True
-                else:
-                    self.agents_direction[i] = valid_starting_directions[
-                        np.random.choice(len(valid_starting_directions), 1)[0]]
+
+        self.fill_valid_positions()
+
+        if replace_agents:
+            re_generate = True
+            while re_generate:
+
+                # self.agents_position = random.sample(valid_positions,
+                #                                      self.number_of_agents)
+                self.agents_position = [
+                    self.valid_positions[i] for i in
+                    np.random.choice(len(self.valid_positions), self.number_of_agents)]
+                self.agents_target = [
+                    self.valid_positions[i] for i in
+                    np.random.choice(len(self.valid_positions), self.number_of_agents)]
+
+                # agents_direction must be a direction for which a solution is
+                # guaranteed.
+                self.agents_direction = [0] * self.number_of_agents
+                re_generate = False
+
+                for i in range(self.number_of_agents):
+                    direction = self.pick_agent_direction(self.agents_position[i], self.agents_target[i])
+                    if direction is None:
+                        re_generate = True
+                        break
+                    self.agents_direction[i] = direction
 
         # Reset the state of the observation builder with the new environment
         self.obs_builder.reset()
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index 16b0b029da535c235008dfdacf8c5eec4ecd619a..4c612b53c08798a7de0c4eab072cf540f8df0f20 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -3,6 +3,7 @@ from numpy import array
 import time
 from collections import deque
 from matplotlib import pyplot as plt
+import threading
 # from contextlib import redirect_stdout
 # import os
 # import sys
@@ -54,6 +55,7 @@ class JupEditor(object):
         self.set_env(env)
         self.iAgent = None
         self.player = None
+        self.thread = None
 
     def set_env(self, env):
         self.env = env
@@ -77,14 +79,9 @@ class JupEditor(object):
         rcCell = ((array([y, x]) - self.yxBase) / self.nPixCell).astype(int)
 
         if self.drawMode == "Origin":
-            self.iAgent = len(self.env.agents_position)
-            self.env.agents_position.append(rcCell)
-            self.env.agents_handles.append(max(self.env.agents_handles + [-1]) + 1)
-            self.env.agents_direction.append(0)
-            self.env.agents_target.append(rcCell)  # set the target to the origin initially
-            self.env.number_of_agents = self.iAgent + 1
+            self.iAgent = self.env.add_agent(rcCell, rcCell, 0)
             self.drawMode = "Destination"
-
+            self.player = None  # will need to start a new player
         elif self.drawMode == "Destination" and self.iAgent is not None:
             self.env.agents_target[self.iAgent] = rcCell
             self.drawMode = "Origin"
@@ -203,9 +200,9 @@ class JupEditor(object):
             plt.clf()
             plt.close()
 
-        if update:
-            self.wid_img.data = img
-        return img
+            if update:
+                self.wid_img.data = img
+            return img
 
     def redraw_event(self, event):
         img = self.redraw()
@@ -245,7 +242,7 @@ class JupEditor(object):
                            rail_generator=random_rail_generator(cell_type_relative_proportion=[1, 1] + [0.5] * 6),
                            number_of_agents=self.env.number_of_agents,
                            obs_builder_object=TreeObsForRailEnv(max_depth=2))
-        self.env.reset()
+        self.env.reset(regen_rail=True)
         self.set_env(self.env)
         self.player = Player(self.env)
         self.redraw()
@@ -256,9 +253,28 @@ class JupEditor(object):
     def step_event(self, event=None):
         if self.player is None:
             self.player = Player(self.env)
+            self.env.reset(regen_rail=False)
         self.player.step()
         self.redraw()
 
+    def start_run_event(self, event=None):
+        """Start bg_updater on a background thread unless one is already recorded in self.thread."""
+        if self.thread is None:
+            self.thread = threading.Thread(target=self.bg_updater, args=())
+            self.thread.start()
+        else:
+            self.log("thread already present")
+
+    def bg_updater(self):
+        """Call step_event 20 times, sleeping 0.2 s after each, then clear self.thread."""
+        try:
+            for i in range(20):
+                self.log("step ", i)
+                self.step_event()
+                time.sleep(0.2)
+        finally:
+            self.thread = None
+
     def fix_env(self):
         self.env.width = self.env.rail.width
         self.env.height = self.env.rail.height