frame2face.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. from flask import Flask, request, jsonify
  2. import cv2
  3. import numpy as np
  4. from sklearn.metrics.pairwise import cosine_similarity
  5. import os
  6. import mysql.connector
  7. from flask_cors import CORS
  8. import dlib
  9. app = Flask(__name__)
  10. # 修改 CORS 配置,明确指定允许的来源和方法
  11. # 允许所有源的访问(仅在开发环境使用)
  12. CORS(app, resources={
  13. r"/*": {
  14. "origins": ["http://localhost:8080"],
  15. "methods": ["GET", "POST", "OPTIONS"],
  16. "allow_headers": ["Content-Type"]
  17. }
  18. })
  19. CORS(app, supports_credentials=True)
  20. # 预先保存的标准图片路径
  21. STANDARD_IMAGES = {
  22. 'oval': 'oval.jpg',
  23. 'square': 'square.jpg',
  24. 'heart': 'heart.jpg',
  25. 'round': 'round.jpg',
  26. 'long': 'long.jpg',
  27. 'diamond': 'diamond.jpg',
  28. 'pear': 'pear.jpg'
  29. }
  30. # MySQL数据库配置
  31. db_config = {
  32. 'user': 'root',
  33. 'password': '123456',
  34. 'host': '192.168.3.80',
  35. 'database': 'citu_new'
  36. }
  37. # 获取当前文件所在目录的绝对路径
  38. BASE_DIR = os.path.dirname(os.path.abspath(__file__))
  39. # 构建数据文件的完整路径
  40. SHAPE_PREDICTOR_PATH = os.path.join(BASE_DIR, 'data', 'shape_predictor_68_face_landmarks.dat')
  41. def compare_faces(image1, image2):
  42. # 使用OpenCV进行脸部轮廓对比
  43. face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
  44. gray1 = cv2.cvtColor(image1, cv2.COLOR_BGR2GRAY)
  45. gray2 = cv2.cvtColor(image2, cv2.COLOR_BGR2GRAY)
  46. faces1 = face_cascade.detectMultiScale(gray1, 1.1, 4)
  47. faces2 = face_cascade.detectMultiScale(gray2, 1.1, 4)
  48. if len(faces1) == 0 or len(faces2) == 0:
  49. return 0.0
  50. # 计算相似度
  51. face1 = gray1[faces1[0][1]:faces1[0][1]+faces1[0][3], faces1[0][0]:faces1[0][0]+faces1[0][2]]
  52. face2 = gray2[faces2[0][1]:faces2[0][1]+faces2[0][3], faces2[0][0]:faces2[0][0]+faces2[0][2]]
  53. face1 = cv2.resize(face1, (100, 100))
  54. face2 = cv2.resize(face2, (100, 100))
  55. face1 = face1.flatten()
  56. face2 = face2.flatten()
  57. similarity = cosine_similarity([face1], [face2])[0][0]
  58. return similarity * 100
  59. @app.route('/upload', methods=['POST'])
  60. def upload():
  61. try:
  62. if not os.path.exists(SHAPE_PREDICTOR_PATH):
  63. return jsonify({
  64. 'error': True,
  65. 'message': f'找不到面部特征点数据文件: {SHAPE_PREDICTOR_PATH}',
  66. 'data': None
  67. }), 500
  68. if 'file' not in request.files:
  69. return jsonify({
  70. 'error': True,
  71. 'message': 'No file part',
  72. 'data': None
  73. }), 400
  74. file = request.files['file']
  75. image = cv2.imdecode(np.frombuffer(file.read(), np.uint8), cv2.IMREAD_COLOR)
  76. # 使用绝对路径加载模型
  77. detector = dlib.get_frontal_face_detector()
  78. predictor = dlib.shape_predictor(SHAPE_PREDICTOR_PATH)
  79. # 检测人脸
  80. faces = detector(image)
  81. if len(faces) == 0:
  82. return jsonify({
  83. 'error': True,
  84. 'message': '未检测到人脸',
  85. 'data': None
  86. }), 400
  87. # 获取第一个检测到的人脸的关键点
  88. landmarks = predictor(image, faces[0])
  89. # 获取特定点的坐标
  90. # dlib的68个关键点中:
  91. # 左脸颊点约为点1-2
  92. # 右脸颊点约为点15-16
  93. # 鼻梁点约为点27
  94. facial_points = {
  95. 'cheek_left_3': {'x': landmarks.part(2).x, 'y': landmarks.part(2).y},
  96. 'cheek_right_3': {'x': landmarks.part(14).x, 'y': landmarks.part(14).y},
  97. 'nose_bridge_1': {'x': landmarks.part(27).x, 'y': landmarks.part(27).y}
  98. }
  99. # 进行人脸对比
  100. results = []
  101. for name, path in STANDARD_IMAGES.items():
  102. standard_image = cv2.imread(path)
  103. similarity = compare_faces(image, standard_image)
  104. results.append({
  105. 'name': name,
  106. 'similarity': similarity,
  107. 'facial_points': facial_points # 添加面部关键点数据
  108. })
  109. # 挑选相似度最高的两个
  110. results.sort(key=lambda x: x['similarity'], reverse=True)
  111. top_results = results[:2]
  112. return jsonify({
  113. 'error': False,
  114. 'message': 'success',
  115. 'data': top_results
  116. })
  117. except Exception as e:
  118. error_message = str(e)
  119. print(f"Error in upload: {error_message}")
  120. return jsonify({
  121. 'error': True,
  122. 'message': f'处理图片时出错: {error_message}',
  123. 'data': None
  124. }), 500
  125. @app.route('/select', methods=['POST'])
  126. def select():
  127. try:
  128. data = request.get_json()
  129. if not data or 'name' not in data:
  130. return jsonify({
  131. 'error': True,
  132. 'message': '请求数据格式错误',
  133. 'data': None
  134. }), 400
  135. face_type = data['name'] # 从前端传来的name就是脸型
  136. # 从MySQL数据库读取对应的记录
  137. try:
  138. conn = mysql.connector.connect(**db_config)
  139. cursor = conn.cursor(dictionary=True) # 使用dictionary=True返回字典格式的结果
  140. # 修改查询语句匹配新表结构
  141. query = """
  142. SELECT id, facetype, brand, frametype, material,
  143. pic_url, detail_info
  144. FROM zeiss_glass_shelf
  145. WHERE facetype = %s
  146. """
  147. cursor.execute(query, (face_type,))
  148. rows = cursor.fetchall()
  149. # 格式化返回数据
  150. formatted_rows = []
  151. for row in rows:
  152. formatted_rows.append({
  153. 'id': row['id'],
  154. 'facetype': row['facetype'],
  155. 'brand': row['brand'],
  156. 'frametype': row['frametype'],
  157. 'material': row['material'],
  158. 'pic_url': row['pic_url'],
  159. 'detail_info': row['detail_info']
  160. })
  161. return jsonify({
  162. 'error': False,
  163. 'message': 'success',
  164. 'data': formatted_rows
  165. })
  166. except mysql.connector.Error as db_err:
  167. error_message = str(db_err)
  168. print(f"Database error: {error_message}")
  169. return jsonify({
  170. 'error': True,
  171. 'message': f'数据库操作失败: {error_message}',
  172. 'data': None
  173. }), 500
  174. finally:
  175. if 'conn' in locals() and conn.is_connected():
  176. cursor.close()
  177. conn.close()
  178. except Exception as e:
  179. error_message = str(e)
  180. print(f"General error: {error_message}")
  181. return jsonify({
  182. 'error': True,
  183. 'message': f'服务器内部错误: {error_message}',
  184. 'data': None
  185. }), 500
  186. if __name__ == '__main__':
  187. app.run(host='0.0.0.0', port=8080, debug=True)