Skip to content
Snippets Groups Projects
Unverified Commit 08a11c17 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

add a jupyter notebook demo (#1158)

parent 63b9d104
No related branches found
No related tags found
No related merge requests found
......@@ -62,13 +62,13 @@ python tools/test.py configs/mask_rcnn_r50_fpn_1x.py \
We provide a webcam demo to illustrate the results.
```shell
python tools/webcam_demo.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--camera-id ${CAMERA-ID}] [--score-thr ${CAMERA-ID}]
python demo/webcam_demo.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--camera-id ${CAMERA-ID}] [--score-thr ${CAMERA-ID}]
```
Examples:
```shell
python tools/webcam_demo.py configs/faster_rcnn_r50_fpn_1x.py \
python demo/webcam_demo.py configs/faster_rcnn_r50_fpn_1x.py \
checkpoints/faster_rcnn_r50_fpn_1x_20181010-3d1b3351.pth
```
......@@ -103,6 +103,8 @@ for frame in video:
show_result(frame, result, model.CLASSES, wait_time=1)
```
A notebook demo can be found in [demo/inference_demo.ipynb](demo/inference_demo.ipynb).
## Train a model
......
demo/demo.jpg

254 KiB

This diff is collapsed.
File moved
from .env import get_root_logger, init_dist, set_random_seed
from .inference import inference_detector, init_detector, show_result
from .inference import (inference_detector, init_detector, show_result,
show_result_pyplot)
from .train import train_detector
__all__ = [
'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
'init_detector', 'inference_detector', 'show_result'
'init_detector', 'inference_detector', 'show_result', 'show_result_pyplot'
]
import warnings
import matplotlib.pyplot as plt
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
......@@ -105,6 +106,7 @@ def show_result(img,
class_names,
score_thr=0.3,
wait_time=0,
show=True,
out_file=None):
"""Visualize the detection results on the image.
......@@ -115,11 +117,17 @@ def show_result(img,
class_names (list[str] or tuple[str]): A list of class names.
score_thr (float): The threshold to visualize the bboxes and masks.
wait_time (int): Value of waitKey param.
show (bool, optional): Whether to show the image with opencv or not.
out_file (str, optional): If specified, the visualization result will
be written to the out file instead of shown in a window.
Returns:
np.ndarray or None: If neither `show` nor `out_file` is specified, the
visualized image is returned, otherwise None is returned.
"""
assert isinstance(class_names, (tuple, list))
img = mmcv.imread(img)
img = img.copy()
if isinstance(result, tuple):
bbox_result, segm_result = result
else:
......@@ -140,11 +148,36 @@ def show_result(img,
]
labels = np.concatenate(labels)
mmcv.imshow_det_bboxes(
img.copy(),
img,
bboxes,
labels,
class_names=class_names,
score_thr=score_thr,
show=out_file is None,
show=show,
wait_time=wait_time,
out_file=out_file)
if not (show or out_file):
return img
def show_result_pyplot(img,
result,
class_names,
score_thr=0.3,
fig_size=(15, 10)):
"""Visualize the detection results on the image.
Args:
img (str or np.ndarray): Image filename or loaded image.
result (tuple[list] or list): The detection result, can be either
(bbox, segm) or just bbox.
class_names (list[str] or tuple[str]): A list of class names.
score_thr (float): The threshold to visualize the bboxes and masks.
fig_size (tuple): Figure size of the pyplot figure.
out_file (str, optional): If specified, the visualization result will
be written to the out file instead of shown in a window.
"""
img = show_result(
img, result, class_names, score_thr=score_thr, show=False)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment