From bd5009e0571973d4eef043b6b8a4a303068e0266 Mon Sep 17 00:00:00 2001
From: Shivam Khandelwal <skbly7@gmail.com>
Date: Thu, 6 May 2021 13:26:02 +0530
Subject: [PATCH] valid_encounters addition to Dataset class

---
 core/dataset.py | 25 +++++++++++++++++++++++--
 core/flight.py  | 15 ++++++++++++++-
 2 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/core/dataset.py b/core/dataset.py
index 0f7d878..ebe625f 100644
--- a/core/dataset.py
+++ b/core/dataset.py
@@ -10,20 +10,37 @@ class Dataset:
     metadata = None
     flights = {}
 
-    def __init__(self, local_path, s3_path, download_if_required=True):
+    def __init__(self, local_path, s3_path, download_if_required=True, partial=False):
         self.file_handler = None
+        self.partial = partial  # True: load only flights that have valid-encounter records
+        self.valid_encounter = {}  # flight_id -> list of encounter records, filled by load_ve()
         self.add(local_path, s3_path, download_if_required)
 
     def load_gt(self):
         logger.info("Loading ground truth...")
         gt_content = self.file_handler.get_file_content(self.gt_loc)
         gt = json.loads(gt_content)
+        # In partial mode, flights with no entry in self.valid_encounter are skipped below.
         self.metadata = gt["metadata"]
         for flight_id in gt["samples"].keys():
-            self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler)
+            if self.partial and flight_id not in self.valid_encounter:
+                logger.info("Skipping flight, not present in valid encounters: %s", flight_id)
+                continue
+            self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler, self.valid_encounter.get(flight_id))
+
+    def load_ve(self):
+        """Index valid-encounter records by flight_id (no-op unless self.partial)."""
+        if not self.partial:
+            return
+        logger.info("Loading valid encounters...")
+        raw = self.file_handler.get_file_content(self.valid_encounters_loc)
+        for chunk in raw.split('\n\n    '):
+            record = json.loads(chunk)
+            self.valid_encounter.setdefault(record["flight_id"], []).append(record)
 
     def add(self, local_path, s3_path, download_if_required=True):
         self.file_handler = FileHandler(local_path, s3_path, download_if_required)
+        self.load_ve()  # must run before load_gt(), which filters flights via self.valid_encounter
         self.load_gt()
 
     def get_flight_ids(self):
@@ -33,6 +50,10 @@ class Dataset:
     def gt_loc(self):
         return 'ImageSets/groundtruth.json'
 
+    @property
+    def valid_encounters_loc(self):
+        return 'ImageSets/valid_encounters_maxRange700_maxGap3_minEncLen30.json'  # read through file_handler in load_ve()
+
     def get_flight_by_id(self, flight_id):
         return self.flights[flight_id]
 
diff --git a/core/flight.py b/core/flight.py
index b9ffed8..f870957 100644
--- a/core/flight.py
+++ b/core/flight.py
@@ -57,15 +57,25 @@ class Flight:
     def location(self):
         return 'Images/' + self.id
 
-    def __init__(self, flight_id, flight_data: dict, file_handler):
+    def __init__(self, flight_id, flight_data: dict, file_handler, valid_encounter=None):
         self.id = flight_id
         self.frames = {}
         self.detected_objects = {}
         self.file_handler = file_handler
         self.metadata = FlightMetadata(flight_data['metadata'])
+        self.valid_encounter = valid_encounter
         for entity in flight_data['entities']:
             frame_id = entity['blob']['frame']
 
+            if self.valid_encounter is not None:
+                valid = False
+                for encounter in self.valid_encounter:
+                    if encounter["framemin"] <= int(frame_id) <= encounter["framemax"]:
+                        valid = True
+
+                if not valid:
+                    continue
+
             if frame_id not in self.frames:
                 self.frames[frame_id] = Frame(entity, self.file_handler, self)
 
@@ -92,6 +102,9 @@ class Flight:
         return self.detected_objects.values()
 
     def get_frame(self, id):
+        if self.valid_encounter is not None and id not in self.frames:  # flight was built with encounter records
+            logger.info("frame_id not present in partial dataset")
+            return None  # without encounter records, a missing id reaches the lookup below and raises KeyError
         return self.frames[id]
 
     def get_metadata(self):
-- 
GitLab