diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 66e7e7f5862c637d65049afed6b7cf54d458e18a..b228e10b6c146f5692166e179bb9f574a68c9134 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -30,10 +30,11 @@ class EnvAgentStatic(object): lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0}))) @classmethod - def from_lists(cls, positions, directions, targets, speeds=None): + def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None): """ Create a list of EnvAgentStatics from lists of positions, directions and targets """ speed_datas = [] + for i in range(len(positions)): speed_datas.append({'position_fraction': 0.0, 'speed': speeds[i] if speeds is not None else 1.0, @@ -41,10 +42,11 @@ class EnvAgentStatic(object): # TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set # some as broken? + malfunction_datas = [] for i in range(len(positions)): malfunction_datas.append({'malfunction': 0, - 'malfunction_rate': 0, + 'malfunction_rate': malfunction_rates[i] if malfunction_rates is not None else 0., 'next_malfunction': 0, 'nr_malfunctions': 0}) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6adcb305d29adbb4fc40e681a61ed3200ad622a0..280fd345d8c1db206c42dc30ba2d7b5fa2e8a69e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -238,7 +238,6 @@ class RailEnv(Environment): agents_hints = optionals['agents_hints'] self.agents_static = EnvAgentStatic.from_lists( *self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints)) - self.restart_agents() for i_agent in range(self.get_num_agents()): @@ -248,7 +247,6 @@ class RailEnv(Environment): if np.random.random() < self.proportion_malfunctioning_trains: agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate - agent.speed_data['position_fraction'] = 0.0 agent.malfunction_data['malfunction'] = 0 self._agent_malfunction(agent) @@ -510,9 +508,9 @@ class RailEnv(Environment): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] - msgpack.packb(grid_data) - msgpack.packb(agent_data) - msgpack.packb(agent_static_data) + msgpack.packb(grid_data, use_bin_type=True) + msgpack.packb(agent_data, use_bin_type=True) + msgpack.packb(agent_static_data, use_bin_type=True) msg_data = { "grid": grid_data, "agents_static": agent_static_data, @@ -526,11 +524,11 @@ class RailEnv(Environment): return msgpack.packb(msg_data, use_bin_type=True) def set_full_state_msg(self, msg_data): - data = msgpack.unpackb(msg_data, use_list=False) - self.rail.grid = np.array(data[b"grid"]) + data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') + self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving - self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] - self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9]) for d in data[b"agents"]] + self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]] + self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]] # setup with loaded data self.height, self.width = self.rail.grid.shape self.rail.height = self.height @@ -538,13 +536,13 @@ class RailEnv(Environment): self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) def set_full_state_dist_msg(self, msg_data): - data = msgpack.unpackb(msg_data, use_list=False) - self.rail.grid = np.array(data[b"grid"]) + data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') + self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving - self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] - self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9]) for d in data[b"agents"]] - if hasattr(self.obs_builder, 'distance_map') and b"distance_maps" in data.keys(): - self.obs_builder.distance_map = data[b"distance_maps"] + self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]] + self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]] + if hasattr(self.obs_builder, 'distance_map') and "distance_maps" in data.keys(): + self.obs_builder.distance_map = data["distance_maps"] # setup with loaded data self.height, self.width = self.rail.grid.shape self.rail.height = self.height @@ -555,12 +553,12 @@ class RailEnv(Environment): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] - msgpack.packb(grid_data) - msgpack.packb(agent_data) - msgpack.packb(agent_static_data) + msgpack.packb(grid_data, use_bin_type=True) + msgpack.packb(agent_data, use_bin_type=True) + msgpack.packb(agent_static_data, use_bin_type=True) if hasattr(self.obs_builder, 'distance_map'): distance_map_data = self.obs_builder.distance_map - msgpack.packb(distance_map_data) + msgpack.packb(distance_map_data, use_bin_type=True) msg_data = { "grid": grid_data, "agents_static": agent_static_data, diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 918b130ab14c05fd65d64a484907fd3bea11d270..8d1a04d094e78903156d554b297703e7078941be 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -224,16 +224,19 @@ def schedule_from_file(filename) -> ScheduleGenerator: def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: with open(filename, "rb") as file_in: load_data = file_in.read() - data = msgpack.unpackb(load_data, use_list=False) + data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') # agents are always reset as not moving - agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data[b"agents_static"]] + agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data["agents_static"]] + # setup with loaded data agents_position = [a.position for a in agents_static] agents_direction = [a.direction for a in agents_static] agents_target = [a.target for a in agents_static] - agents_speed = [a.speed_data[b'speed'] for a in agents_static] - return agents_position, agents_direction, agents_target, agents_speed + agents_speed = [a.speed_data['speed'] for a in agents_static] + agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents_static] + return agents_position, agents_direction, agents_target, agents_speed, agents_malfunction return generator +