diff --git a/siam_mot_test.py b/siam_mot_test.py index fc66a9f8affe17bc94494607446724543b0c4ca9..e4b92a5689a61abe4997bf7b04350e1769a933fc 100644 --- a/siam_mot_test.py +++ b/siam_mot_test.py @@ -48,28 +48,27 @@ class SiamMOTPredictor(AirbornePredictor): def flight_started(self): self.track_id_results = {} self.visited_frame = {} + self.track_len_so_far = {} def proxy_register_object_and_location(self, class_name, track_id, bbox, confidence, img_name): - if track_id not in self.track_id_results: - self.track_id_results[track_id] = [] + # MIN_TRACK_LEN check + if track_id not in self.track_len_so_far: + self.track_len_so_far[track_id] = 0 + self.track_len_so_far[track_id] += 1 + if self.track_len_so_far[track_id] <= MIN_TRACK_LEN: + return + + # MIN_SCORE check + if confidence < MIN_SCORE: + return + if img_name not in self.visited_frame: self.visited_frame[img_name] = [] - if track_id in self.visited_frame[img_name]: raise Exception('two entities within the same frame {} have the same track id'.format(img_name)) - - self.track_id_results[track_id].append([class_name, track_id, bbox, confidence, img_name]) self.visited_frame[img_name].append(track_id) - def flight_completed(self): - for track_id in self.track_id_results.keys(): - track_len = len(self.track_id_results[track_id]) - if track_len < MIN_TRACK_LEN: - continue - for entity in self.track_id_results[track_id][MIN_TRACK_LEN:]: - if entity[3] < MIN_SCORE: - continue - self.register_object_and_location(*entity) + self.register_object_and_location(class_name, track_id, bbox, confidence, img_name) """ PARTICIPANT_TODO: @@ -103,7 +102,6 @@ class SiamMOTPredictor(AirbornePredictor): self.proxy_register_object_and_location(class_name, int(track_id), bbox, float(confidence), frame_image) - self.flight_completed() if __name__ == "__main__": submission = SiamMOTPredictor()