Skip to content

Commit 9d97cff

Browse files
authored
Example to use Mask R-CNN
1 parent 29bbece commit 9d97cff

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

maskrcnn_predict.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import mrcnn
2+
import mrcnn.config
3+
import mrcnn.model
4+
import mrcnn.visualize
5+
import cv2
6+
import os
7+
8+
# load the class label names from disk, one label per line
9+
# CLASS_NAMES = open("coco_labels.txt").read().strip().split("\n")
10+
11+
CLASS_NAMES = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
12+
13+
class SimpleConfig(mrcnn.config.Config):
14+
# Give the configuration a recognizable name
15+
NAME = "coco_inference"
16+
17+
# set the number of GPUs to use along with the number of images per GPU
18+
GPU_COUNT = 1
19+
IMAGES_PER_GPU = 1
20+
21+
# Number of classes = number of classes + 1 (+1 for the background). The background class is named BG
22+
NUM_CLASSES = len(CLASS_NAMES)
23+
24+
# Initialize the Mask R-CNN model for inference and then load the weights.
25+
# This step builds the Keras model architecture.
26+
model = mrcnn.model.MaskRCNN(mode="inference",
27+
config=SimpleConfig(),
28+
model_dir=os.getcwd())
29+
30+
# Load the weights into the model.
31+
model.load_weights(filepath="mask_rcnn_coco.h5",
32+
by_name=True)
33+
34+
# load the input image, convert it from BGR to RGB channel
35+
image = cv2.imread("test.jpg")
36+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
37+
38+
# Perform a forward pass of the network to obtain the results
39+
r = model.detect([image])
40+
41+
# Get the results for the first image.
42+
r = r[0]
43+
44+
# Visualize the detected objects.
45+
mrcnn.visualize.display_instances(image=image,
46+
boxes=r['rois'],
47+
masks=r['masks'],
48+
class_ids=r['class_ids'],
49+
class_names=CLASS_NAMES,
50+
scores=r['scores'])

0 commit comments

Comments
 (0)