Skip to content
Snippets Groups Projects
Commit bd5009e0 authored by Shivam Khandelwal's avatar Shivam Khandelwal
Browse files

valid_encounters addition to Dataset class

parent 7d3b2d4f
No related branches found
No related tags found
No related merge requests found
......@@ -10,20 +10,37 @@ class Dataset:
metadata = None
flights = {}
def __init__(self, local_path, s3_path, download_if_required=True):
def __init__(self, local_path, s3_path, download_if_required=True, partial=False):
self.file_handler = None
self.partial = partial
self.valid_encounter = {}
self.add(local_path, s3_path, download_if_required)
def load_gt(self):
logger.info("Loading ground truth...")
gt_content = self.file_handler.get_file_content(self.gt_loc)
gt = json.loads(gt_content)
self.metadata = gt["metadata"]
for flight_id in gt["samples"].keys():
self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler)
if self.partial and flight_id not in self.valid_encounter:
logger.info("Skipping flight, not present in valid encounters: %s" % flight_id)
continue
self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler, self.valid_encounter.get(flight_id))
def load_ve(self):
if self.partial:
logger.info("Loading valid encounters...")
ve = self.file_handler.get_file_content(self.valid_encounters_loc)
for valid_encounter in ve.split('\n\n '):
valid_encounter = json.loads(valid_encounter)
if valid_encounter["flight_id"] not in self.valid_encounter:
self.valid_encounter[valid_encounter["flight_id"]] = []
self.valid_encounter[valid_encounter["flight_id"]].append(valid_encounter)
def add(self, local_path, s3_path, download_if_required=True):
self.file_handler = FileHandler(local_path, s3_path, download_if_required)
self.load_ve()
self.load_gt()
def get_flight_ids(self):
......@@ -33,6 +50,10 @@ class Dataset:
def gt_loc(self):
return 'ImageSets/groundtruth.json'
@property
def valid_encounters_loc(self):
return 'ImageSets/valid_encounters_maxRange700_maxGap3_minEncLen30.json'
def get_flight_by_id(self, flight_id):
return self.flights[flight_id]
......
......@@ -57,15 +57,25 @@ class Flight:
def location(self):
return 'Images/' + self.id
def __init__(self, flight_id, flight_data: dict, file_handler):
def __init__(self, flight_id, flight_data: dict, file_handler, valid_encounter=None):
self.id = flight_id
self.frames = {}
self.detected_objects = {}
self.file_handler = file_handler
self.metadata = FlightMetadata(flight_data['metadata'])
self.valid_encounter = valid_encounter
for entity in flight_data['entities']:
frame_id = entity['blob']['frame']
if self.valid_encounter is not None:
valid = False
for encounter in self.valid_encounter:
if encounter["framemin"] <= int(frame_id) <= encounter["framemax"]:
valid = True
if not valid:
continue
if frame_id not in self.frames:
self.frames[frame_id] = Frame(entity, self.file_handler, self)
......@@ -92,6 +102,9 @@ class Flight:
return self.detected_objects.values()
def get_frame(self, id):
if self.valid_encounter is not None and id not in self.frames:
logger.info("frame_id not present in partial dataset")
return None
return self.frames[id]
def get_metadata(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment