"""License-plate detection helpers built on an Ultralytics YOLO pose model."""

import os

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO


class LicensePlateYOLO:
    """License-plate detector backed by a YOLO pose model.

    Loads the model, detects plates together with their four corner
    keypoints, draws the detections, and rectifies a plate with a
    perspective transform.
    """

    # Font files tried in order when rendering labels. The labels contain
    # Chinese text, so a CJK-capable font is needed. The Windows entries come
    # first (same order as before); the remaining entries are common install
    # locations on Linux and macOS and are simply skipped when absent.
    _FONT_CANDIDATES = (
        "C:/Windows/Fonts/simhei.ttf",
        "C:/Windows/Fonts/msyh.ttc",
        "C:/Windows/Fonts/simsun.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/System/Library/Fonts/PingFang.ttc",
        "/System/Library/Fonts/STHeiti Medium.ttc",
    )
    _LABEL_FONT_SIZE = 20

    # Box/label colours per class name, as RGB (they are drawn with PIL).
    _BOX_COLORS = {
        '绿牌': (0, 255, 0),
        '蓝牌': (0, 0, 255),
    }
    _DEFAULT_BOX_COLOR = (128, 128, 128)

    # Keypoint colours, as BGR (they are drawn with OpenCV).
    _COMPLETE_KEYPOINT_COLOR = (0, 255, 255)  # yellow
    _PARTIAL_KEYPOINT_COLOR = (0, 0, 255)  # red

    def __init__(self, model_path=None):
        """Create the detector and try to load the model immediately.

        Args:
            model_path: Path to the model weights. When falsy, the default
                path next to this source file is used.
        """
        self.model = None
        self.model_path = model_path or self._get_default_model_path()
        self.class_names = {0: '蓝牌', 1: '绿牌'}
        self._label_font = None  # loaded lazily on first draw
        self.load_model()

    def _get_default_model_path(self):
        """Return the default weights path, located beside this file."""
        # abspath keeps the result valid even if the working directory
        # changes after import and __file__ happens to be relative.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(current_dir, "yolo11s-pose42.pt")

    def load_model(self):
        """Load the YOLO pose model from ``self.model_path``.

        Failures are reported on stdout rather than raised.

        Returns:
            bool: True when the model was loaded, False when the file is
            missing or loading raised.
        """
        try:
            if not os.path.exists(self.model_path):
                print(f"模型文件不存在: {self.model_path}")
                return False
            self.model = YOLO(self.model_path)
            print(f"YOLO模型加载成功: {self.model_path}")
            return True
        except Exception as e:
            # Deliberately broad: any loader failure is reported, not raised.
            print(f"YOLO模型加载失败: {e}")
            return False

    def detect_license_plates(self, image, conf_threshold=0.5):
        """Detect license plates in an image.

        Args:
            image: Input image (numpy array).
            conf_threshold: Confidence threshold passed to the model.

        Returns:
            list: One dict per detection with the keys
                - ``box``: bounding box ``[x1, y1, x2, y2]``
                - ``keypoints``: corner points ``[[x, y], ...]``
                - ``confidence``: detection confidence
                - ``class_id``: class index (keys of ``self.class_names``)
                - ``class_name``: class name, '未知' for unknown ids
                - ``incomplete``: present and True only when the detection
                  does not carry exactly four keypoints
            An empty list is returned when the model is not loaded or when
            inference fails.
        """
        if self.model is None:
            print("模型未加载")
            return []

        try:
            results = self.model(image, conf=conf_threshold, verbose=False)
            detections = []

            for result in results:
                if result.boxes is None or result.keypoints is None:
                    continue

                boxes = result.boxes.xyxy.cpu().numpy()
                keypoints = result.keypoints.xy.cpu().numpy()
                confidences = result.boxes.conf.cpu().numpy()
                classes = result.boxes.cls.cpu().numpy()

                for box, points, confidence, cls in zip(
                    boxes, keypoints, confidences, classes
                ):
                    class_id = int(cls)
                    detection = {
                        'box': box,
                        'keypoints': points,
                        'confidence': confidence,
                        'class_id': class_id,
                        'class_name': self.class_names.get(class_id, '未知'),
                    }
                    if len(points) != 4:
                        # Keep the detection but flag it, so callers can
                        # skip perspective correction for it.
                        detection['keypoints'] = (
                            points if len(points) > 0 else []
                        )
                        detection['incomplete'] = True
                    detections.append(detection)

            return detections
        except Exception as e:
            # Deliberately broad: inference errors are reported, not raised.
            print(f"检测过程中出错: {e}")
            return []

    def _get_label_font(self):
        """Return the label font, loading and caching it on first use."""
        cached = getattr(self, "_label_font", None)
        if cached is not None:
            return cached

        font = None
        for font_path in self._FONT_CANDIDATES:
            if not os.path.exists(font_path):
                continue
            try:
                font = ImageFont.truetype(font_path, self._LABEL_FONT_SIZE)
                break
            except OSError:
                # Unreadable or unsupported font file: try the next one.
                continue
        if font is None:
            font = ImageFont.load_default()

        self._label_font = font
        return font

    @staticmethod
    def _draw_label(draw, font, label, x1, y1, color):
        """Draw ``label`` on a filled background just above ``(x1, y1)``."""
        try:
            left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
        except (ValueError, UnicodeEncodeError):
            # Best effort: when no CJK-capable font was found, Pillow's
            # built-in font may be unable to measure this text. Skip the
            # label instead of failing the whole drawing call.
            return

        text_width = right - left
        text_height = bottom - top

        label_top = y1 - text_height - 10
        if label_top < 0:
            # Not enough room above the box: put the label inside it.
            label_top = max(y1, 0)

        draw.rectangle(
            [(x1, label_top), (x1 + text_width, label_top + text_height + 10)],
            fill=color,
        )
        draw.text((x1, label_top + 5), label, fill=(255, 255, 255), font=font)

    def draw_detections(self, image, detections, plate_numbers=None):
        """Draw boxes, labels and corner keypoints for the detections.

        The input image is not modified.

        Args:
            image: Input image (BGR, or single-channel grayscale).
            detections: Detection dicts as returned by
                ``detect_license_plates``.
            plate_numbers: Optional plate strings, index-aligned with
                ``detections``.

        Returns:
            numpy.ndarray: A BGR image with the detections drawn on it.
        """
        if image.ndim == 2:
            # Grayscale input: promote to BGR so colours can be drawn.
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

        # Boxes and labels are drawn with PIL because OpenCV's putText cannot
        # render Chinese characters. cvtColor returns a new array, so the
        # caller's image is left untouched.
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        font = self._get_label_font()

        for index, detection in enumerate(detections):
            class_name = detection['class_name']
            confidence = detection['confidence']

            plate_number = ""
            if plate_numbers and index < len(plate_numbers):
                plate_number = plate_numbers[index]

            x1, y1, x2, y2 = map(int, detection['box'])
            box_color = self._BOX_COLORS.get(
                class_name, self._DEFAULT_BOX_COLOR
            )
            draw.rectangle([(x1, y1), (x2, y2)], outline=box_color, width=2)

            if plate_number:
                label = f"{class_name} {plate_number} {confidence:.2f}"
            else:
                label = f"{class_name} {confidence:.2f}"
            if detection.get('incomplete', False):
                label += " (不完整)"

            self._draw_label(draw, font, label, x1, y1, box_color)

        # Back to OpenCV (BGR) for the keypoint overlay.
        draw_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

        for detection in detections:
            keypoints = detection['keypoints']
            incomplete = detection.get('incomplete', False)

            if len(keypoints) >= 4 and not incomplete:
                # All four corners present: mark them and join them, in
                # order, into a closed quadrilateral.
                corners = [(int(kp[0]), int(kp[1])) for kp in keypoints[:4]]
                for corner in corners:
                    cv2.circle(
                        draw_image, corner, 5,
                        self._COMPLETE_KEYPOINT_COLOR, -1,
                    )
                for j, corner in enumerate(corners):
                    cv2.line(
                        draw_image, corner, corners[(j + 1) % 4],
                        self._COMPLETE_KEYPOINT_COLOR, 2,
                    )
            elif len(keypoints) > 0:
                # Incomplete set: only mark the points that exist.
                for kp in keypoints:
                    cv2.circle(
                        draw_image, (int(kp[0]), int(kp[1])), 5,
                        self._PARTIAL_KEYPOINT_COLOR, -1,
                    )

        return draw_image

    def correct_license_plate(self, image, keypoints, target_size=(240, 80)):
        """Rectify a plate with a perspective transform of its four corners.

        Args:
            image: Source image.
            keypoints: Four corner points. The order is assumed to be
                right_bottom, left_bottom, left_top, right_top.
            target_size: Output size as ``(width, height)``.

        Returns:
            numpy.ndarray: The rectified plate image, or None when
            ``keypoints`` is None, does not hold exactly four points, or the
            transform fails.
        """
        if keypoints is None or len(keypoints) != 4:
            return None

        try:
            width, height = target_size
            width, height = int(width), int(height)

            # Reorder the source corners from the assumed model order
            # (right_bottom, left_bottom, left_top, right_top) to
            # left_top, right_top, right_bottom, left_bottom so they line up
            # with the destination rectangle below.
            src_points = np.asarray(keypoints, dtype=np.float32)
            reordered_src = src_points[[2, 3, 0, 1]]

            dst_points = np.array(
                [
                    [0, 0],           # left_top
                    [width, 0],       # right_top
                    [width, height],  # right_bottom
                    [0, height],      # left_bottom
                ],
                dtype=np.float32,
            )

            matrix = cv2.getPerspectiveTransform(reordered_src, dst_points)
            return cv2.warpPerspective(image, matrix, (width, height))
        except Exception as e:
            # Deliberately broad: a failed correction is reported as None.
            print(f"车牌矫正失败: {e}")
            return None

    def get_model_info(self):
        """Return a dict describing the model state.

        Returns:
            dict: ``status`` and ``path`` always; ``model_type`` and
            ``classes`` as well once the model is loaded.
        """
        if self.model is None:
            return {"status": "未加载", "path": self.model_path}

        return {
            "status": "已加载",
            "path": self.model_path,
            "model_type": "YOLO11 Pose",
            "classes": self.class_names,
        }


def initialize_yolo_detector(model_path=None):
    """Convenience factory for a ``LicensePlateYOLO`` detector.

    Args:
        model_path: Path to the model weights, or None for the default.

    Returns:
        LicensePlateYOLO: The constructed detector.
    """
    return LicensePlateYOLO(model_path)


if __name__ == "__main__":
    # Smoke test: build a detector and print its status.
    detector = initialize_yolo_detector()
    print("检测器信息:", detector.get_model_info())