From 3764e1f6d1b867ac44a125b5c2f86f0fdf3456fd Mon Sep 17 00:00:00 2001 From: Shivam Khandelwal <skbly7@gmail.com> Date: Sun, 9 May 2021 13:36:59 +0530 Subject: [PATCH] Flights can be present across training data split --- core/dataset.py | 5 +++-- core/flight.py | 5 ++++- core/frame.py | 2 ++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/core/dataset.py b/core/dataset.py index ebe625f..1c2d963 100644 --- a/core/dataset.py +++ b/core/dataset.py @@ -10,10 +10,11 @@ class Dataset: metadata = None flights = {} - def __init__(self, local_path, s3_path, download_if_required=True, partial=False): + def __init__(self, local_path, s3_path, download_if_required=True, partial=False, prefix=None): self.file_handler = None self.partial = partial self.valid_encounter = {} + self.prefix = prefix self.add(local_path, s3_path, download_if_required) def load_gt(self): @@ -26,7 +27,7 @@ class Dataset: if self.partial and flight_id not in self.valid_encounter: logger.info("Skipping flight, not present in valid encounters: %s" % flight_id) continue - self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler, self.valid_encounter.get(flight_id)) + self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler, self.valid_encounter.get(flight_id), prefix=self.prefix) def load_ve(self): if self.partial: diff --git a/core/flight.py b/core/flight.py index f870957..8a5b134 100644 --- a/core/flight.py +++ b/core/flight.py @@ -55,13 +55,16 @@ class Flight: @property def location(self): + if self.prefix: + return 'Images/' + self.prefix + self.id return 'Images/' + self.id - def __init__(self, flight_id, flight_data: dict, file_handler, valid_encounter=None): + def __init__(self, flight_id, flight_data: dict, file_handler, valid_encounter=None, prefix=None): self.id = flight_id self.frames = {} self.detected_objects = {} self.file_handler = file_handler + self.prefix = prefix self.metadata = FlightMetadata(flight_data['metadata']) self.valid_encounter = valid_encounter for entity in flight_data['entities']: diff --git a/core/frame.py b/core/frame.py index a1e3d38..831e36d 100644 --- a/core/frame.py +++ b/core/frame.py @@ -51,6 +51,8 @@ class Frame: return len(self.detected_objects) def image_path(self): + if self.flight.prefix: + return os.path.join('Images', self.flight.prefix + self.flight.id, (str(self.timestamp) + str(self.flight.id) + '.png')) return os.path.join('Images', self.flight.id, (str(self.timestamp) + str(self.flight.id) + '.png')) def image(self, type='pil'): -- GitLab