From bd5009e0571973d4eef043b6b8a4a303068e0266 Mon Sep 17 00:00:00 2001
From: Shivam Khandelwal <skbly7@gmail.com>
Date: Thu, 6 May 2021 13:26:02 +0530
Subject: [PATCH] valid_encounters addition to Dataset class

---
 core/dataset.py | 38 ++++++++++++++++++++++++++++++++++++--
 core/flight.py  | 15 ++++++++++++++-
 2 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/core/dataset.py b/core/dataset.py
index 0f7d878..ebe625f 100644
--- a/core/dataset.py
+++ b/core/dataset.py
@@ -10,20 +10,50 @@ class Dataset:
     metadata = None
     flights = {}
 
-    def __init__(self, local_path, s3_path, download_if_required=True):
+    def __init__(self, local_path, s3_path, download_if_required=True, partial=False):
         self.file_handler = None
+        self.partial = partial
+        self.valid_encounter = {}
         self.add(local_path, s3_path, download_if_required)
 
     def load_gt(self):
         logger.info("Loading ground truth...")
         gt_content = self.file_handler.get_file_content(self.gt_loc)
         gt = json.loads(gt_content)
+
         self.metadata = gt["metadata"]
         for flight_id in gt["samples"].keys():
-            self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler)
+            if self.partial and flight_id not in self.valid_encounter:
+                logger.info("Skipping flight, not present in valid encounters: %s", flight_id)
+                continue
+            self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler, self.valid_encounter.get(flight_id))
+
+    def load_ve(self):
+        """Load the valid-encounter records when the dataset is partial.
+
+        The file content is read as a sequence of JSON objects separated by
+        whitespace; every record is grouped under its "flight_id" key in
+        self.valid_encounter. Does nothing when self.partial is false.
+        """
+        if not self.partial:
+            return
+
+        logger.info("Loading valid encounters...")
+        ve = self.file_handler.get_file_content(self.valid_encounters_loc)
+        decoder = json.JSONDecoder()
+        pos, end = 0, len(ve)
+        while pos < end:
+            # raw_decode() does not skip leading whitespace, so step over
+            # whatever separates two consecutive records.
+            if ve[pos].isspace():
+                pos += 1
+                continue
+            encounter, pos = decoder.raw_decode(ve, pos)
+            self.valid_encounter.setdefault(encounter["flight_id"], []).append(encounter)
 
     def add(self, local_path, s3_path, download_if_required=True):
         self.file_handler = FileHandler(local_path, s3_path, download_if_required)
+        # load_gt() consults self.valid_encounter, so it has to be filled first.
+        self.load_ve()
         self.load_gt()
 
     def get_flight_ids(self):
@@ -33,6 +63,10 @@ class Dataset:
     def gt_loc(self):
         return 'ImageSets/groundtruth.json'
 
+    @property
+    def valid_encounters_loc(self):
+        return 'ImageSets/valid_encounters_maxRange700_maxGap3_minEncLen30.json'
+
     def get_flight_by_id(self, flight_id):
         return self.flights[flight_id]
 
diff --git a/core/flight.py b/core/flight.py
index b9ffed8..f870957 100644
--- a/core/flight.py
+++ b/core/flight.py
@@ -57,15 +57,25 @@ class Flight:
     def location(self):
         return 'Images/' + self.id
 
-    def __init__(self, flight_id, flight_data: dict, file_handler):
+    def __init__(self, flight_id, flight_data: dict, file_handler, valid_encounter=None):
         self.id = flight_id
         self.frames = {}
         self.detected_objects = {}
         self.file_handler = file_handler
         self.metadata = FlightMetadata(flight_data['metadata'])
+        self.valid_encounter = valid_encounter
 
         for entity in flight_data['entities']:
             frame_id = entity['blob']['frame']
+            if self.valid_encounter is not None:
+                # Drop frames that lie outside every valid encounter window.
+                in_encounter = any(
+                    encounter["framemin"] <= int(frame_id) <= encounter["framemax"]
+                    for encounter in self.valid_encounter
+                )
+                if not in_encounter:
+                    continue
+
             if frame_id not in self.frames:
                 self.frames[frame_id] = Frame(entity, self.file_handler, self)
 
@@ -92,6 +102,9 @@ class Flight:
         return self.detected_objects.values()
 
     def get_frame(self, id):
+        if self.valid_encounter is not None and id not in self.frames:
+            logger.info("frame_id not present in partial dataset")
+            return None
         return self.frames[id]
 
     def get_metadata(self):
-- 
GitLab