1
+ import os
2
+ import sys
3
+ import argparse
4
+ import ast
5
+ import cv2
6
+ import torch
7
+ from vidgear .gears import CamGear
8
+
9
+ from simpleHRnet import SimpleHRNet
10
+ from misc .visualization import draw_points , draw_skeleton , draw_points_and_skeleton , joints_dict
11
+
12
+ def main (camera_id , filename , hrnet_c , hrnet_j , hrnet_weights , hrnet_joints_set , image_resolution , single_person ,
13
+ max_batch_size , disable_vidgear , device ):
14
+
15
+ if device is not None :
16
+ device = torch .device (device )
17
+ else :
18
+ if torch .cuda .is_available () and True :
19
+ torch .backends .cudnn .deterministic = True
20
+ device = torch .device ('cuda:0' )
21
+ else :
22
+ device = torch .device ('cpu' )
23
+ print (device )
24
+
25
+ image_resolution = ast .literal_eval (image_resolution )
26
+ has_display = 'DISPLAY' in os .environ .keys () or sys .platform == 'win32'
27
+
28
+ if filename is not None :
29
+ video = cv2 .VideoCapture (filename )
30
+ assert video .isOpened ()
31
+ else :
32
+ if disable_vidgear :
33
+ video = cv2 .VideoCapture (camera_id )
34
+ assert video .isOpened ()
35
+ else :
36
+ video = CamGear (camera_id ).start ()
37
+
38
+ model = SimpleHRNet (
39
+ hrnet_c ,
40
+ hrnet_j ,
41
+ hrnet_weights ,
42
+ resolution = image_resolution ,
43
+ multiperson = not single_person ,
44
+ max_batch_size = max_batch_size ,
45
+ device = device
46
+ )
47
+
48
+ while True :
49
+ if filename is not None or disable_vidgear :
50
+ ret , frame = video .read ()
51
+ if not ret :
52
+ break
53
+ else :
54
+ frame = video .read ()
55
+ if frame is None :
56
+ break
57
+
58
+ pts = model .predict (frame )
59
+
60
+ for i , pt in enumerate (pts ):
61
+ frame = draw_points_and_skeleton (frame , pt , joints_dict ()[hrnet_joints_set ]['skeleton' ], person_index = i ,
62
+ points_color_palette = 'gist_rainbow' , skeleton_color_palette = 'jet' ,
63
+ points_palette_samples = 10 )
64
+
65
+ if has_display :
66
+ cv2 .imshow ('frame.png' , frame )
67
+ k = cv2 .waitKey (1 )
68
+ if k == 27 : # Esc button
69
+ if disable_vidgear :
70
+ video .release ()
71
+ else :
72
+ video .stop ()
73
+ break
74
+ else :
75
+ cv2 .imwrite ('frame.png' , frame )
76
+
77
+
78
+ if __name__ == '__main__' :
79
+ parser = argparse .ArgumentParser ()
80
+ parser .add_argument ("--camera_id" , "-d" , help = "open the camera with the specified id" , type = int , default = 0 )
81
+ parser .add_argument ("--filename" , "-f" , help = "open the specified video (overrides the --camera_id option)" ,
82
+ type = str , default = None )
83
+ parser .add_argument ("--hrnet_c" , "-c" , help = "hrnet parameters - number of channels" , type = int , default = 32 )
84
+ parser .add_argument ("--hrnet_j" , "-j" , help = "hrnet parameters - number of joints" , type = int , default = 17 )
85
+ parser .add_argument ("--hrnet_weights" , "-w" , help = "hrnet parameters - path to the pretrained weights" ,
86
+ type = str , default = "weights/mod_pose_hrnet_w32_256x192.pth" )
87
+ parser .add_argument ("--hrnet_joints_set" ,
88
+ help = "use the specified set of joints ('coco' and 'mpii' are currently supported)" ,
89
+ type = str , default = "coco" )
90
+ parser .add_argument ("--image_resolution" , "-r" , help = "image resolution" , type = str , default = '(256, 192)' )
91
+ parser .add_argument ("--single_person" ,
92
+ help = "disable the multiperson detection (YOLOv3 or an equivalen detector is required for"
93
+ "multiperson detection)" ,
94
+ action = "store_true" )
95
+ parser .add_argument ("--max_batch_size" , help = "maximum batch size used for inference" , type = int , default = 16 )
96
+ parser .add_argument ("--disable_vidgear" ,
97
+ help = "disable vidgear (which is used for slightly better realtime performance)" ,
98
+ action = "store_true" ) # see https://pypi.org/project/vidgear/
99
+ parser .add_argument ("--device" , help = "device to be used (default: cuda, if available)" , type = str , default = None )
100
+ args = parser .parse_args ()
101
+ import warnings
102
+
103
+ warnings .filterwarnings ("ignore" ,category = UserWarning )
104
+
105
+
106
+ main (** args .__dict__ )
0 commit comments