# Source listing metadata: 323 lines, 12 KiB, Python.
import cv2
import numpy as np
from ultralytics import YOLO
import os
from PIL import Image, ImageDraw, ImageFont

class LicensePlateYOLO:
    """License-plate detector built on a YOLO pose model.

    Loads the model weights at construction time and offers helpers to detect
    plates together with their corner keypoints, draw the detections on an
    image, and rectify a plate with a perspective warp.
    """

    # Fonts able to render the Chinese class names, tried in order. The three
    # Windows entries keep their original priority; the remaining ones are
    # common install locations on Linux/macOS so labels also render there.
    _FONT_CANDIDATES = (
        "C:/Windows/Fonts/simhei.ttf",  # SimHei
        "C:/Windows/Fonts/msyh.ttc",    # Microsoft YaHei
        "C:/Windows/Fonts/simsun.ttc",  # SimSun
        "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/System/Library/Fonts/PingFang.ttc",
    )

    # Box/label colours, expressed in RGB because they are drawn with PIL.
    _PLATE_COLORS = {
        '绿牌': (0, 255, 0),   # green plate
        '蓝牌': (0, 0, 255),   # blue plate
    }
    _UNKNOWN_COLOR = (128, 128, 128)

    # Vertical room reserved for the label strip, and the text inset inside it.
    _LABEL_PADDING = 10
    _LABEL_TEXT_OFFSET = 5

    def __init__(self, model_path=None):
        """Create the detector and try to load the model immediately.

        Args:
            model_path: Path to the weights file. When falsy, the default
                path next to this module is used.
        """
        self.model = None
        self.model_path = model_path or self._get_default_model_path()
        self.class_names = {0: '蓝牌', 1: '绿牌'}
        # A failed load is reported on stdout and leaves ``self.model`` None.
        self.load_model()

    def _get_default_model_path(self):
        """Return the path of the bundled weights file beside this module."""
        current_dir = os.path.dirname(__file__)
        return os.path.join(current_dir, "yolo11s-pose42.pt")

    def load_model(self):
        """Load the YOLO pose model from ``self.model_path``.

        Returns:
            bool: True when the model was loaded, False when the file is
            missing or loading raised an error.
        """
        if not os.path.exists(self.model_path):
            print(f"模型文件不存在: {self.model_path}")
            return False

        try:
            self.model = YOLO(self.model_path)
        except Exception as e:
            # Best-effort boundary: report the failure instead of crashing.
            print(f"YOLO模型加载失败: {e}")
            return False

        print(f"YOLO模型加载成功: {self.model_path}")
        return True

    def detect_license_plates(self, image, conf_threshold=0.6):
        """Detect license plates in an image.

        Args:
            image: Input image passed straight to the model.
            conf_threshold: Confidence threshold forwarded as ``conf``.

        Returns:
            list: One dict per detection with the keys
                - ``box``: bounding box ``[x1, y1, x2, y2]``
                - ``keypoints``: corner points ``[[x, y], ...]``
                - ``confidence``: detection confidence
                - ``class_id``: integer class id
                - ``class_name``: name from ``self.class_names`` or '未知'
                - ``incomplete``: present (True) only when the detection
                  does not carry exactly four keypoints.
            An empty list is returned when the model is not loaded or when
            inference fails.
        """
        if self.model is None:
            print("模型未加载")
            return []

        try:
            results = self.model(image, conf=conf_threshold, verbose=False)
            detections = []

            for result in results:
                if result.boxes is None or result.keypoints is None:
                    continue

                boxes = result.boxes.xyxy.cpu().numpy()
                keypoints = result.keypoints.xy.cpu().numpy()
                confidences = result.boxes.conf.cpu().numpy()
                classes = result.boxes.cls.cpu().numpy()

                # NOTE(review): presumably the model emits a fixed number of
                # keypoints per detection, so the "incomplete" branch may only
                # trigger for a differently-shaped model -- confirm.
                for box, kps, conf, cls in zip(boxes, keypoints,
                                               confidences, classes):
                    class_id = int(cls)
                    detection = {
                        'box': box,
                        'keypoints': kps if len(kps) > 0 else [],
                        'confidence': conf,
                        'class_id': class_id,
                        'class_name': self.class_names.get(class_id, '未知'),
                    }
                    if len(kps) != 4:
                        # Keep the detection but flag the missing corners.
                        detection['incomplete'] = True
                    detections.append(detection)

            return detections

        except Exception as e:
            print(f"检测过程中出错: {e}")
            return []

    @classmethod
    def _plate_color(cls, class_name):
        """Return the RGB colour used for a plate class (grey if unknown)."""
        return cls._PLATE_COLORS.get(class_name, cls._UNKNOWN_COLOR)

    @staticmethod
    def _build_label(class_name, confidence, plate_number="", incomplete=False):
        """Compose the text shown above a detection box."""
        if plate_number:
            label = f"{class_name} {plate_number} {confidence:.2f}"
        else:
            label = f"{class_name} {confidence:.2f}"
        if incomplete:
            label += " (不完整)"
        return label

    @classmethod
    def _label_top(cls, y1, text_height):
        """Return the y coordinate of the label strip's top edge.

        The strip normally sits directly above the box; when that would fall
        outside the image it is moved to start at the box's top edge instead.
        """
        top = y1 - text_height - cls._LABEL_PADDING
        if top >= 0:
            return top
        return max(y1, 0)

    def _load_label_font(self, size=20):
        """Return the first loadable candidate font, else PIL's default."""
        for font_path in self._FONT_CANDIDATES:
            if not os.path.exists(font_path):
                continue
            try:
                return ImageFont.truetype(font_path, size)
            except (OSError, ImportError):
                # Unreadable font file or FreeType support missing: try next.
                continue
        return ImageFont.load_default()

    def draw_detections(self, image, detections, plate_numbers=None):
        """Draw boxes, labels and corner keypoints for the detections.

        Args:
            image: Input BGR image (left unmodified).
            detections: Detection dicts as returned by
                ``detect_license_plates``.
            plate_numbers: Optional plate strings aligned with ``detections``.

        Returns:
            numpy.ndarray: A new BGR image with the annotations drawn.
        """
        # Text is rendered through PIL because the labels contain Chinese
        # characters. cvtColor allocates a new array, so the caller's image
        # is never written to.
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        font = self._load_label_font()

        for i, detection in enumerate(detections):
            plate_number = ""
            if plate_numbers and i < len(plate_numbers):
                plate_number = plate_numbers[i]

            x1, y1, x2, y2 = map(int, detection['box'])
            box_color = self._plate_color(detection['class_name'])
            draw.rectangle([(x1, y1), (x2, y2)], outline=box_color, width=2)

            label = self._build_label(
                detection['class_name'],
                detection['confidence'],
                plate_number,
                detection.get('incomplete', False),
            )

            bbox = draw.textbbox((0, 0), label, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]

            # Keep the label inside the image even for boxes at the top edge.
            label_top = self._label_top(y1, text_height)
            label_bottom = label_top + text_height + self._LABEL_PADDING
            draw.rectangle([(x1, label_top), (x1 + text_width, label_bottom)],
                           fill=box_color)
            draw.text((x1, label_top + self._LABEL_TEXT_OFFSET), label,
                      fill=(255, 255, 255), font=font)

        # Back to OpenCV (BGR) for the keypoint overlay.
        draw_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

        for detection in detections:
            keypoints = detection['keypoints']
            incomplete = detection.get('incomplete', False)

            if len(keypoints) >= 4 and not incomplete:
                # Complete quadrilateral: yellow corners and outline.
                points = [(int(kp[0]), int(kp[1])) for kp in keypoints[:4]]
                for point in points:
                    cv2.circle(draw_image, point, 5, (0, 255, 255), -1)
                # Corners are joined in their stored order, assumed to be
                # right_bottom, left_bottom, left_top, right_top.
                for start, end in zip(points, points[1:] + points[:1]):
                    cv2.line(draw_image, start, end, (0, 255, 255), 2)
            elif len(keypoints) > 0:
                # Incomplete corners: mark whatever points exist in red.
                for kp in keypoints:
                    cv2.circle(draw_image, (int(kp[0]), int(kp[1])), 5,
                               (0, 0, 255), -1)

        return draw_image

    def correct_license_plate(self, image, keypoints, target_size=(240, 80)):
        """Rectify a plate with a perspective transform of its four corners.

        Args:
            image: Source image.
            keypoints: Four ``(x, y)`` corners, assumed to be ordered
                right_bottom, left_bottom, left_top, right_top.
            target_size: Output size as ``(width, height)``.

        Returns:
            numpy.ndarray or None: The rectified plate image, or None when
            the keypoints are missing/malformed or the transform fails.
        """
        if keypoints is None or len(keypoints) != 4:
            return None

        try:
            src_points = np.asarray(keypoints, dtype=np.float32)
            if src_points.shape != (4, 2):
                # Not four 2-D points; nothing sensible to warp.
                return None

            width, height = target_size
            dst_points = np.array([
                [0, 0],           # left_top
                [width, 0],       # right_top
                [width, height],  # right_bottom
                [0, height],      # left_bottom
            ], dtype=np.float32)

            # Reorder the source corners to match dst_points:
            # left_top, right_top, right_bottom, left_bottom.
            reordered_src = src_points[[2, 3, 0, 1]]

            matrix = cv2.getPerspectiveTransform(reordered_src, dst_points)
            return cv2.warpPerspective(image, matrix, target_size)

        except Exception as e:
            print(f"车牌矫正失败: {e}")
            return None

    def get_model_info(self):
        """Describe the detector's current state.

        Returns:
            dict: ``status`` and ``path`` always; ``model_type`` and
            ``classes`` as well once a model is loaded.
        """
        if self.model is None:
            return {"status": "未加载", "path": self.model_path}

        return {
            "status": "已加载",
            "path": self.model_path,
            "model_type": "YOLO11 Pose",
            "classes": self.class_names,
        }


def initialize_yolo_detector(model_path=None):
    """Convenience factory for a license-plate YOLO detector.

    Args:
        model_path: Path to the model weights file; passed through unchanged
            to ``LicensePlateYOLO``.

    Returns:
        LicensePlateYOLO: The newly constructed detector instance.
    """
    return LicensePlateYOLO(model_path)


if __name__ == "__main__":
    # Smoke test: build a detector with the default weights and print its state.
    detector = initialize_yolo_detector()
    model_info = detector.get_model_info()
    print("检测器信息:", model_info)