Chengrui Wang committed
Commit cf14767 · 0 Parent(s):

Add a model for facial expression recognition (#100)

Files changed (3)
  1. README.md +40 -0
  2. demo.py +131 -0
  3. facial_fer_model.py +178 -0
README.md ADDED
@@ -0,0 +1,40 @@
# Progressive Teacher

Progressive Teacher: [Boosting Facial Expression Recognition by A Semi-Supervised Progressive Teacher](https://scholar.google.com/citations?view_op=view_citation&hl=zh-CN&user=OCwcfAwAAAAJ&citation_for_view=OCwcfAwAAAAJ:u5HHmVD_uO8C)

Note:
- Progressive Teacher is contributed by [Jing Jiang](https://scholar.google.com/citations?user=OCwcfAwAAAAJ&hl=zh-CN).
- [MobileFaceNet](https://link.springer.com/chapter/10.1007/978-3-319-97909-0_46) is used as the backbone, and the model can classify seven basic facial expressions (angry, disgust, fearful, happy, neutral, sad, surprised); the index-to-label mapping is shown in the snippet below.
- [facial_expression_recognition_mobilefacenet_2022july.onnx](https://github.com/opencv/opencv_zoo/raw/master/models/facial_expression_recognition/facial_expression_recognition_mobilefacenet_2022july.onnx) is implemented thanks to [Chengrui Wang](https://github.com/opencv).
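
A quick check of that index-to-label mapping (an illustrative snippet, assuming it is run from this directory so `facial_fer_model.py` is importable):

```python
from facial_fer_model import FacialExpressionRecog

# Class indices 0..6 as returned by FacialExpressionRecog.infer()
print([FacialExpressionRecog.getDesc(i) for i in range(7)])
# ['angry', 'disgust', 'fearful', 'happy', 'neutral', 'sad', 'surprised']
```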

Results of accuracy evaluation on [RAF-DB](http://whdeng.cn/RAF/model1.html):

| Models              | Accuracy |
|---------------------|----------|
| Progressive Teacher | 88.27%   |

## Demo

***NOTE***: This demo uses [../face_detection_yunet](../face_detection_yunet) as the face detector, which supports 5-landmark detection for now (2021sep).

Run the following command to try the demo:
```shell
# recognize the facial expression on images
python demo.py --input /path/to/image
```
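
The demo's argument parser also supports running on the default camera and saving the visualized result (see demo.py below):

```shell
# run on frames from the default camera (omit --input)
python demo.py

# recognize on an image and save the visualized result to result.jpg
python demo.py --input /path/to/image --save true
```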

### Example outputs

Note: Zoom in to see the recognized facial expression in the top-left corner of each face box.

![fer demo](./examples/selfie.jpg)

## License

All files in this directory are licensed under the [Apache 2.0 License](./LICENSE).

## Reference

- https://ieeexplore.ieee.org/abstract/document/9629313
demo.py ADDED
@@ -0,0 +1,131 @@
import sys
import argparse
import datetime

import numpy as np
import cv2 as cv

from facial_fer_model import FacialExpressionRecog

sys.path.append('../face_detection_yunet')
from yunet import YuNet


def str2bool(v):
    if v.lower() in ['on', 'yes', 'true', 'y', 't']:
        return True
    elif v.lower() in ['off', 'no', 'false', 'n', 'f']:
        return False
    else:
        raise NotImplementedError


backends = [cv.dnn.DNN_BACKEND_OPENCV, cv.dnn.DNN_BACKEND_CUDA]
targets = [cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_CUDA, cv.dnn.DNN_TARGET_CUDA_FP16]
help_msg_backends = "Choose one of the computation backends: {:d}: OpenCV implementation (default); {:d}: CUDA"
help_msg_targets = "Choose one of the target computation devices: {:d}: CPU (default); {:d}: CUDA; {:d}: CUDA fp16"
try:
    backends += [cv.dnn.DNN_BACKEND_TIMVX]
    targets += [cv.dnn.DNN_TARGET_NPU]
    help_msg_backends += "; {:d}: TIMVX"
    help_msg_targets += "; {:d}: NPU"
except:
    print('This version of OpenCV does not support TIM-VX and NPU. Visit https://github.com/opencv/opencv/wiki/TIM-VX-Backend-For-Running-OpenCV-On-NPU for more information.')

parser = argparse.ArgumentParser(description='Facial Expression Recognition')
parser.add_argument('--input', '-i', type=str, help='Path to the input image. Omit for using default camera.')
parser.add_argument('--model', '-fm', type=str, default='./facial_expression_recognition_mobilefacenet_2022july.onnx', help='Path to the facial expression recognition model.')
parser.add_argument('--backend', '-b', type=int, default=backends[0], help=help_msg_backends.format(*backends))
parser.add_argument('--target', '-t', type=int, default=targets[0], help=help_msg_targets.format(*targets))
parser.add_argument('--save', '-s', type=str2bool, default=False, help='Set true to save results. This flag is invalid when using camera.')
parser.add_argument('--vis', '-v', type=str2bool, default=True, help='Set true to open a window for result visualization. This flag is invalid when using camera.')
args = parser.parse_args()


def visualize(image, det_res, fer_res, box_color=(0, 255, 0), text_color=(0, 0, 255)):
    print('%s %3d faces detected.' % (datetime.datetime.now(), len(det_res)))

    output = image.copy()
    landmark_color = [
        (255, 0, 0),    # right eye
        (0, 0, 255),    # left eye
        (0, 255, 0),    # nose tip
        (255, 0, 255),  # right mouth corner
        (0, 255, 255)   # left mouth corner
    ]

    for ind, (det, fer_type) in enumerate(zip(det_res, fer_res)):
        bbox = det[0:4].astype(np.int32)
        fer_type = FacialExpressionRecog.getDesc(fer_type)
        print("Face %2d: %d %d %d %d %s." % (ind, bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3], fer_type))
        cv.rectangle(output, (bbox[0], bbox[1]), (bbox[0]+bbox[2], bbox[1]+bbox[3]), box_color, 2)
        cv.putText(output, fer_type, (bbox[0], bbox[1]+12), cv.FONT_HERSHEY_DUPLEX, 0.5, text_color)
        landmarks = det[4:14].astype(np.int32).reshape((5, 2))
        for idx, landmark in enumerate(landmarks):
            cv.circle(output, landmark, 2, landmark_color[idx], 2)
    return output


def process(detect_model, fer_model, frame):
    h, w, _ = frame.shape
    detect_model.setInputSize([w, h])
    dets = detect_model.infer(frame)

    if dets is None:
        return False, None, None

    fer_res = np.zeros(0, dtype=np.int8)
    for face_points in dets:
        fer_res = np.concatenate((fer_res, fer_model.infer(frame, face_points[:-1])), axis=0)
    return True, dets, fer_res


if __name__ == '__main__':
    detect_model = YuNet(modelPath='../face_detection_yunet/face_detection_yunet_2022mar.onnx')

    fer_model = FacialExpressionRecog(modelPath=args.model,
                                      backendId=args.backend,
                                      targetId=args.target)

    # If input is an image
    if args.input is not None:
        image = cv.imread(args.input)

        # Get detection and fer results
        status, dets, fer_res = process(detect_model, fer_model, image)

        if status:
            # Draw results on the input image
            image = visualize(image, dets, fer_res)

        # Save results
        if args.save:
            cv.imwrite('result.jpg', image)
            print('Results saved to result.jpg\n')

        # Visualize results in a new window
        if args.vis:
            cv.namedWindow(args.input, cv.WINDOW_AUTOSIZE)
            cv.imshow(args.input, image)
            cv.waitKey(0)
    else:  # Omit input to call default camera
        deviceId = 0
        cap = cv.VideoCapture(deviceId)

        while cv.waitKey(1) < 0:
            hasFrame, frame = cap.read()
            if not hasFrame:
                print('No frames grabbed!')
                break

            # Get detection and fer results
            status, dets, fer_res = process(detect_model, fer_model, frame)

            if status:
                # Draw results on the input image
                frame = visualize(frame, dets, fer_res)

            # Visualize results in a new window
            cv.imshow('FER Demo', frame)
facial_fer_model.py ADDED
@@ -0,0 +1,178 @@
# This file is part of OpenCV Zoo project.
# It is subject to the license terms in the LICENSE file found in the same directory.
#
# Copyright (C) 2022, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved.
# Third party copyrights are property of their respective owners.

import numpy as np
import cv2 as cv


class FacialExpressionRecog:
    def __init__(self, modelPath, backendId=0, targetId=0):
        self._modelPath = modelPath
        self._backendId = backendId
        self._targetId = targetId

        self._model = cv.dnn.readNet(self._modelPath)
        self._model.setPreferableBackend(self._backendId)
        self._model.setPreferableTarget(self._targetId)

        self._align_model = FaceAlignment()

        self._inputNames = 'data'
        self._outputNames = ['label']
        self._inputSize = [112, 112]
        self._mean = np.array([0.5, 0.5, 0.5])[np.newaxis, np.newaxis, :]
        self._std = np.array([0.5, 0.5, 0.5])[np.newaxis, np.newaxis, :]

    @property
    def name(self):
        return self.__class__.__name__

    def setBackend(self, backend_id):
        self._backendId = backend_id
        self._model.setPreferableBackend(self._backendId)

    def setTarget(self, target_id):
        self._targetId = target_id
        self._model.setPreferableTarget(self._targetId)

    def _preprocess(self, image, bbox):
        if bbox is not None:
            image = self._align_model.get_align_image(image, bbox[4:].reshape(-1, 2))
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
        image = image.astype(np.float32, copy=False) / 255.0
        image -= self._mean
        image /= self._std
        return cv.dnn.blobFromImage(image)

    def infer(self, image, bbox=None):
        # Preprocess
        inputBlob = self._preprocess(image, bbox)

        # Forward
        self._model.setInput(inputBlob, self._inputNames)
        outputBlob = self._model.forward(self._outputNames)

        # Postprocess
        results = self._postprocess(outputBlob)

        return results

    def _postprocess(self, outputBlob):
        result = np.argmax(outputBlob[0], axis=1).astype(np.uint8)
        return result

    @staticmethod
    def getDesc(ind):
        _expression_enum = ["angry", "disgust", "fearful", "happy", "neutral", "sad", "surprised"]
        return _expression_enum[ind]


class FaceAlignment:
    def __init__(self, reflective=False):
        self._std_points = np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], [41.5493, 92.3655], [70.7299, 92.2041]])
        self.reflective = reflective

    def __tformfwd(self, trans, uv):
        uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
        xy = np.dot(uv, trans)
        xy = xy[:, 0:-1]
        return xy

    def __tforminv(self, trans, uv):
        Tinv = np.linalg.inv(trans)
        xy = self.__tformfwd(Tinv, uv)
        return xy

    def __findNonreflectiveSimilarity(self, uv, xy, options=None):
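        # Fit a 4-parameter (non-reflective) similarity transform by linear least squares.
        # With r = [sc, ss, tx, ty], each destination point (x, y) in xy is mapped onto
        # its source point (u, v) in uv by
        #     u =  sc*x + ss*y + tx
        #     v = -ss*x + sc*y + ty
        # Stacking two such rows per point gives the system X @ r = U solved below;
        # the fitted matrix is Tinv (xy -> uv), and T = inv(Tinv) maps uv -> xy.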
        if options is None:
            options = {"K": 2}

        K = options["K"]
        M = xy.shape[0]
        x = xy[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
        y = xy[:, 1].reshape((-1, 1))  # use reshape to keep a column vector

        tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
        tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
        X = np.vstack((tmp1, tmp2))

        u = uv[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
        v = uv[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
        U = np.vstack((u, v))

        # We know that X * r = U
        if np.linalg.matrix_rank(X) >= 2 * K:
            r, _, _, _ = np.linalg.lstsq(X, U, rcond=-1)
            r = np.squeeze(r)
        else:
            raise Exception("cp2tform:twoUniquePointsReq")

        sc = r[0]
        ss = r[1]
        tx = r[2]
        ty = r[3]

        Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
        T = np.linalg.inv(Tinv)
        T[:, 2] = np.array([0, 0, 1])

        return T, Tinv

    def __findSimilarity(self, uv, xy, options=None):
        if options is None:
            options = {"K": 2}

        # Solve for trans1
        trans1, trans1_inv = self.__findNonreflectiveSimilarity(uv, xy, options)

        # manually reflect the xy data across the Y-axis (on a copy, so xy stays intact)
        xyR = xy.copy()
        xyR[:, 0] = -1 * xyR[:, 0]
        # Solve for trans2
        trans2r, trans2r_inv = self.__findNonreflectiveSimilarity(uv, xyR, options)

        # manually reflect the tform to undo the reflection done on xyR
        TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
        trans2 = np.dot(trans2r, TreflectY)

        # Figure out if trans1 or trans2 is better
        xy1 = self.__tformfwd(trans1, uv)
        norm1 = np.linalg.norm(xy1 - xy)
        xy2 = self.__tformfwd(trans2, uv)
        norm2 = np.linalg.norm(xy2 - xy)

        if norm1 <= norm2:
            return trans1, trans1_inv
        else:
            trans2_inv = np.linalg.inv(trans2)
            return trans2, trans2_inv

    def __get_similarity_transform(self, src_pts, dst_pts):
        if self.reflective:
            trans, trans_inv = self.__findSimilarity(src_pts, dst_pts)
        else:
            trans, trans_inv = self.__findNonreflectiveSimilarity(src_pts, dst_pts)
        return trans, trans_inv

    def __cvt_tform_mat_for_cv2(self, trans):
        cv2_trans = trans[:, 0:2].T
        return cv2_trans

    def get_similarity_transform_for_cv2(self, src_pts, dst_pts):
        trans, trans_inv = self.__get_similarity_transform(src_pts, dst_pts)
        cv2_trans = self.__cvt_tform_mat_for_cv2(trans)
        return cv2_trans, trans

    def get_align_image(self, image, lm5_points):
        assert lm5_points is not None
        tfm, trans = self.get_similarity_transform_for_cv2(lm5_points, self._std_points)
        return cv.warpAffine(image, tfm, (112, 112))
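
A minimal sketch of using `FacialExpressionRecog` on its own (illustrative only; `aligned_face.jpg` is a hypothetical, already-aligned 112x112 BGR face crop, since with `bbox=None` the class skips landmark alignment and feeds the image in as-is):

```python
import cv2 as cv
from facial_fer_model import FacialExpressionRecog

fer = FacialExpressionRecog('./facial_expression_recognition_mobilefacenet_2022july.onnx')

# aligned_face.jpg is assumed to already be an aligned face crop
crop = cv.imread('aligned_face.jpg')
crop = cv.resize(crop, (112, 112))

label_id = fer.infer(crop)[0]                   # infer() returns an array of class indices
print(FacialExpressionRecog.getDesc(label_id))  # e.g. "happy"
```

For full-frame images, use demo.py instead: it first runs YuNet to get the 5 landmarks and then passes each detection to `infer(frame, face_points[:-1])`, so the face is aligned before classification.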