Skip to content
Snippets Groups Projects
Commit 3764e1f6 authored by Shivam Khandelwal's avatar Shivam Khandelwal
Browse files

Flights can be present across training data split

parent bd5009e0
No related branches found
No related tags found
No related merge requests found
......@@ -10,10 +10,11 @@ class Dataset:
metadata = None
flights = {}
def __init__(self, local_path, s3_path, download_if_required=True, partial=False):
def __init__(self, local_path, s3_path, download_if_required=True, partial=False, prefix=None):
self.file_handler = None
self.partial = partial
self.valid_encounter = {}
self.prefix = prefix
self.add(local_path, s3_path, download_if_required)
def load_gt(self):
......@@ -26,7 +27,7 @@ class Dataset:
if self.partial and flight_id not in self.valid_encounter:
logger.info("Skipping flight, not present in valid encounters: %s" % flight_id)
continue
self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler, self.valid_encounter.get(flight_id))
self.flights[flight_id] = Flight(flight_id, gt["samples"][flight_id], self.file_handler, self.valid_encounter.get(flight_id), prefix=self.prefix)
def load_ve(self):
if self.partial:
......
......@@ -55,13 +55,16 @@ class Flight:
@property
def location(self):
if self.prefix:
return 'Images/' + self.prefix + self.id
return 'Images/' + self.id
def __init__(self, flight_id, flight_data: dict, file_handler, valid_encounter=None):
def __init__(self, flight_id, flight_data: dict, file_handler, valid_encounter=None, prefix=None):
self.id = flight_id
self.frames = {}
self.detected_objects = {}
self.file_handler = file_handler
self.prefix = prefix
self.metadata = FlightMetadata(flight_data['metadata'])
self.valid_encounter = valid_encounter
for entity in flight_data['entities']:
......
......@@ -51,6 +51,8 @@ class Frame:
return len(self.detected_objects)
def image_path(self):
if self.flight.prefix:
return os.path.join('Images', self.flight.prefix + self.flight.id, (str(self.timestamp) + str(self.flight.id) + '.png'))
return os.path.join('Images', self.flight.id, (str(self.timestamp) + str(self.flight.id) + '.png'))
def image(self, type='pil'):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment