From 3764e1f6d1b867ac44a125b5c2f86f0fdf3456fd Mon Sep 17 00:00:00 2001
From: Shivam Khandelwal <skbly7@gmail.com>
Date: Sun, 9 May 2021 13:36:59 +0530
Subject: [PATCH] Flights can be present across training data split

---
 core/dataset.py | 5 +++--
 core/flight.py  | 5 ++++-
 core/frame.py   | 2 ++
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/core/dataset.py b/core/dataset.py
index ebe625f..1c2d963 100644
--- a/core/dataset.py
+++ b/core/dataset.py
@@ -10,10 +10,11 @@ class Dataset:
     metadata = None
     flights = {}
 
-    def __init__(self, local_path, s3_path, download_if_required=True, partial=False):
+    def __init__(self, local_path, s3_path, download_if_required=True, partial=False, prefix=None):
         self.file_handler = None
         self.partial = partial
         self.valid_encounter = {}
+        self.prefix = prefix
         self.add(local_path, s3_path, download_if_required)
 
     def load_gt(self):
@@ -26,7 +27,7 @@ class Dataset:
             if self.partial and flight_id not in self.valid_encounter:
                 logger.info("Skipping flight, not present in valid encounters: %s" % flight_id)
                 continue
-            self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler, self.valid_encounter.get(flight_id))
+            self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler, self.valid_encounter.get(flight_id), prefix=self.prefix)
 
     def load_ve(self):
         if self.partial:
diff --git a/core/flight.py b/core/flight.py
index f870957..8a5b134 100644
--- a/core/flight.py
+++ b/core/flight.py
@@ -55,13 +55,16 @@ class Flight:
 
     @property
     def location(self):
+        if self.prefix:
+            return 'Images/' + self.prefix + self.id
         return 'Images/' + self.id
 
-    def __init__(self, flight_id, flight_data: dict, file_handler, valid_encounter=None):
+    def __init__(self, flight_id, flight_data: dict, file_handler, valid_encounter=None, prefix=None):
         self.id = flight_id
         self.frames = {}
         self.detected_objects = {}
         self.file_handler = file_handler
+        self.prefix = prefix
         self.metadata = FlightMetadata(flight_data['metadata'])
         self.valid_encounter = valid_encounter
         for entity in flight_data['entities']:
diff --git a/core/frame.py b/core/frame.py
index a1e3d38..831e36d 100644
--- a/core/frame.py
+++ b/core/frame.py
@@ -51,6 +51,8 @@ class Frame:
         return len(self.detected_objects)
 
     def image_path(self):
+        if self.flight.prefix:
+            return os.path.join('Images', self.flight.prefix + self.flight.id, (str(self.timestamp) + str(self.flight.id) + '.png'))
         return os.path.join('Images', self.flight.id, (str(self.timestamp) + str(self.flight.id) + '.png'))
 
     def image(self, type='pil'):
-- 
GitLab