diff --git a/core/dataset.py b/core/dataset.py index 0f7d87898da6c657269d9bef6aeea32a00fa59d4..ebe625fc1f435b1547457a2de94cba6e4725efac 100644 --- a/core/dataset.py +++ b/core/dataset.py @@ -10,20 +10,37 @@ class Dataset: metadata = None flights = {} - def __init__(self, local_path, s3_path, download_if_required=True): + def __init__(self, local_path, s3_path, download_if_required=True, partial=False): self.file_handler = None + self.partial = partial + self.valid_encounter = {} self.add(local_path, s3_path, download_if_required) def load_gt(self): logger.info("Loading ground truth...") gt_content = self.file_handler.get_file_content(self.gt_loc) gt = json.loads(gt_content) + self.metadata = gt["metadata"] for flight_id in gt["samples"].keys(): - self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler) + if self.partial and flight_id not in self.valid_encounter: + logger.info("Skipping flight, not present in valid encounters: %s" % flight_id) + continue + self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler, self.valid_encounter.get(flight_id)) + + def load_ve(self): + if self.partial: + logger.info("Loading valid encounters...") + ve = self.file_handler.get_file_content(self.valid_encounters_loc) + for valid_encounter in ve.split('\n\n '): + valid_encounter = json.loads(valid_encounter) + if valid_encounter["flight_id"] not in self.valid_encounter: + self.valid_encounter[valid_encounter["flight_id"]] = [] + self.valid_encounter[valid_encounter["flight_id"]].append(valid_encounter) def add(self, local_path, s3_path, download_if_required=True): self.file_handler = FileHandler(local_path, s3_path, download_if_required) + self.load_ve() self.load_gt() def get_flight_ids(self): @@ -33,6 +50,10 @@ class Dataset: def gt_loc(self): return 'ImageSets/groundtruth.json' + @property + def valid_encounters_loc(self): + return 'ImageSets/valid_encounters_maxRange700_maxGap3_minEncLen30.json' + def get_flight_by_id(self, flight_id): return self.flights[flight_id] diff --git a/core/flight.py b/core/flight.py index b9ffed8e1d19f6d80c394a5ef89ec6842206250d..f8709573a6bf75e21c53ff00638d8856e0b5ae4a 100644 --- a/core/flight.py +++ b/core/flight.py @@ -57,15 +57,25 @@ class Flight: def location(self): return 'Images/' + self.id - def __init__(self, flight_id, flight_data: dict, file_handler): + def __init__(self, flight_id, flight_data: dict, file_handler, valid_encounter=None): self.id = flight_id self.frames = {} self.detected_objects = {} self.file_handler = file_handler self.metadata = FlightMetadata(flight_data['metadata']) + self.valid_encounter = valid_encounter for entity in flight_data['entities']: frame_id = entity['blob']['frame'] + if self.valid_encounter is not None: + valid = False + for encounter in self.valid_encounter: + if encounter["framemin"] <= int(frame_id) <= encounter["framemax"]: + valid = True + + if not valid: + continue + if frame_id not in self.frames: self.frames[frame_id] = Frame(entity, self.file_handler, self) @@ -92,6 +102,9 @@ class Flight: return self.detected_objects.values() def get_frame(self, id): + if self.valid_encounter is not None and id not in self.frames: + logger.info("frame_id not present in partial dataset") + return None return self.frames[id] def get_metadata(self):