323 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
from ultralytics import YOLO
import os
from PIL import Image, ImageDraw, ImageFont
class LicensePlateYOLO:
    """
    License-plate detector built on a YOLO pose model.

    Loads the model, detects plates (bounding box + corner keypoints),
    draws detections onto an image, and rectifies a plate region with a
    perspective transform.
    """

    # Outline/label-background colours per plate class name, in RGB order
    # (they are applied on the PIL/RGB copy of the image). Unknown -> grey.
    _BOX_COLORS = {
        '绿牌': (0, 255, 0),   # green plate -> green
        '蓝牌': (0, 0, 255),   # blue plate -> blue
    }
    _DEFAULT_BOX_COLOR = (128, 128, 128)

    # CJK-capable fonts tried in order for label text. The first three are the
    # original Windows choices (SimHei, Microsoft YaHei, SimSun); the rest are
    # common Linux/macOS locations so Chinese labels can render there too.
    # Paths that do not exist are simply skipped.
    _FONT_CANDIDATES = (
        "C:/Windows/Fonts/simhei.ttf",
        "C:/Windows/Fonts/msyh.ttc",
        "C:/Windows/Fonts/simsun.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/noto-cjk/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
        "/System/Library/Fonts/PingFang.ttc",
        "/System/Library/Fonts/STHeiti Medium.ttc",
    )

    def __init__(self, model_path=None):
        """
        Create the detector and immediately try to load the model.

        Args:
            model_path: Path to the model weights; if None, the default path
                next to this file is used.
        """
        self.model = None
        self.model_path = model_path or self._get_default_model_path()
        self.class_names = {0: '蓝牌', 1: '绿牌'}  # 0 = blue plate, 1 = green plate
        self.load_model()

    def _get_default_model_path(self):
        """Return the default weights path: 'yolo11s-pose42.pt' beside this module."""
        current_dir = os.path.dirname(__file__)
        return os.path.join(current_dir, "yolo11s-pose42.pt")

    def load_model(self):
        """
        Load the YOLO pose model from ``self.model_path``.

        Returns:
            bool: True if the model was loaded, False if the file is missing
            or loading raised an exception (the error is printed).
        """
        try:
            if os.path.exists(self.model_path):
                self.model = YOLO(self.model_path)
                print(f"YOLO模型加载成功: {self.model_path}")
                return True
            else:
                print(f"模型文件不存在: {self.model_path}")
                return False
        except Exception as e:
            print(f"YOLO模型加载失败: {e}")
            return False

    @classmethod
    def _get_box_color(cls, class_name):
        """Return the RGB colour for a plate class name (grey for unknown names)."""
        return cls._BOX_COLORS.get(class_name, cls._DEFAULT_BOX_COLOR)

    @classmethod
    def _load_label_font(cls, size=20):
        """Return the first loadable font from ``_FONT_CANDIDATES``, else PIL's default."""
        for font_path in cls._FONT_CANDIDATES:
            if not os.path.exists(font_path):
                continue
            try:
                return ImageFont.truetype(font_path, size)
            except (OSError, ImportError):
                # OSError: file present but unreadable/corrupt.
                # ImportError: Pillow built without FreeType support.
                continue
        # No candidate could be loaded; Chinese glyphs may not render with this font.
        return ImageFont.load_default()

    def _build_detection(self, box, keypoints, confidence, class_value):
        """
        Assemble one detection dict.

        A detection whose keypoint count is not exactly 4 gets
        ``'incomplete': True``; its keypoints are kept, or replaced by ``[]``
        when there are none.
        """
        class_id = int(class_value)
        detection = {
            'box': box,
            'keypoints': keypoints,
            'confidence': confidence,
            'class_id': class_id,
            'class_name': self.class_names.get(class_id, '未知'),
        }
        if len(keypoints) != 4:
            detection['keypoints'] = keypoints if len(keypoints) > 0 else []
            detection['incomplete'] = True
        return detection

    def detect_license_plates(self, image, conf_threshold=0.5):
        """
        Detect license plates in an image.

        Args:
            image: Input image (numpy array), passed straight to the model.
            conf_threshold: Confidence threshold forwarded to the model as ``conf``.

        Returns:
            list[dict]: One dict per detection with keys
                - 'box': bounding box [x1, y1, x2, y2]
                - 'keypoints': corner points [[x, y], ...] (normally 4)
                - 'confidence': detection confidence
                - 'class_id': class index (0 = blue plate, 1 = green plate)
                - 'class_name': class name from ``self.class_names`` ('未知' if unknown)
                - 'incomplete': present and True only when the keypoint count != 4
            Returns [] if the model is not loaded or inference raises (the
            error is printed).
        """
        if self.model is None:
            print("模型未加载")
            return []
        try:
            results = self.model(image, conf=conf_threshold, verbose=False)
            detections = []
            for result in results:
                # Skip results that carry no boxes or no keypoints.
                if result.boxes is None or result.keypoints is None:
                    continue
                boxes = result.boxes.xyxy.cpu().numpy()
                keypoints = result.keypoints.xy.cpu().numpy()
                confidences = result.boxes.conf.cpu().numpy()
                classes = result.boxes.cls.cpu().numpy()
                for box, kps, conf, cls_value in zip(boxes, keypoints, confidences, classes):
                    detections.append(self._build_detection(box, kps, conf, cls_value))
            return detections
        except Exception as e:
            print(f"检测过程中出错: {e}")
            return []

    def draw_detections(self, image, detections, plate_numbers=None):
        """
        Draw detections (boxes, labels, keypoints) on a copy of the image.

        Args:
            image: Input image in BGR order (it is converted with COLOR_BGR2RGB).
            detections: List of detection dicts as returned by
                ``detect_license_plates``.
            plate_numbers: Optional list of plate strings aligned by index with
                ``detections``; missing entries are simply not shown.

        Returns:
            numpy.ndarray: New BGR image with the detections drawn.
        """
        # Labels are drawn with PIL because OpenCV's putText cannot render
        # Chinese characters; cvtColor already returns a new array.
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        font = self._load_label_font(20)

        for i, detection in enumerate(detections):
            class_name = detection['class_name']
            confidence = detection['confidence']
            incomplete = detection.get('incomplete', False)

            plate_number = ""
            if plate_numbers and i < len(plate_numbers):
                plate_number = plate_numbers[i]

            x1, y1, x2, y2 = map(int, detection['box'])
            box_color = self._get_box_color(class_name)
            draw.rectangle([(x1, y1), (x2, y2)], outline=box_color, width=2)

            if plate_number:
                label = f"{class_name} {plate_number} {confidence:.2f}"
            else:
                label = f"{class_name} {confidence:.2f}"
            if incomplete:
                label += " (不完整)"

            bbox = draw.textbbox((0, 0), label, font=font)
            text_width = bbox[2] - bbox[0]
            text_height = bbox[3] - bbox[1]

            # The label normally sits just above the box; if that would run off
            # the top of the image, draw it inside the box instead.
            label_top = y1 - text_height - 10
            if label_top < 0:
                label_top = max(y1, 0)
            draw.rectangle(
                [(x1, label_top), (x1 + text_width, label_top + text_height + 10)],
                fill=box_color,
            )
            draw.text((x1, label_top + 5), label, fill=(255, 255, 255), font=font)

        # Back to OpenCV's BGR layout for the keypoint drawing below.
        draw_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

        for detection in detections:
            keypoints = detection['keypoints']
            incomplete = detection.get('incomplete', False)
            if len(keypoints) >= 4 and not incomplete:
                # All four corners present: yellow (BGR) dots joined in order
                # (assumed order: right_bottom, left_bottom, left_top, right_top).
                points = [(int(kp[0]), int(kp[1])) for kp in keypoints[:4]]
                for point in points:
                    cv2.circle(draw_image, point, 5, (0, 255, 255), -1)
                for j in range(4):
                    cv2.line(draw_image, points[j], points[(j + 1) % 4], (0, 255, 255), 2)
            elif len(keypoints) > 0:
                # Incomplete keypoints: mark whatever exists with red (BGR) dots.
                for kp in keypoints:
                    cv2.circle(draw_image, (int(kp[0]), int(kp[1])), 5, (0, 0, 255), -1)
        return draw_image

    def correct_license_plate(self, image, keypoints, target_size=(240, 80)):
        """
        Rectify a plate region with a perspective transform from its 4 corners.

        Args:
            image: Source image.
            keypoints: Four corner points, assumed in the order
                right_bottom, left_bottom, left_top, right_top.
            target_size: Output size as (width, height).

        Returns:
            numpy.ndarray | None: The warped plate image, or None when there
            are not exactly 4 keypoints or the OpenCV calls fail (the error is
            printed).
        """
        if len(keypoints) != 4:
            return None
        try:
            src_points = np.asarray(keypoints, dtype=np.float32)
            width, height = target_size
            # Destination rectangle in order left_top, right_top, right_bottom, left_bottom.
            dst_points = np.array([
                [0, 0],
                [width, 0],
                [width, height],
                [0, height],
            ], dtype=np.float32)
            # Reorder source points (rb, lb, lt, rt) -> (lt, rt, rb, lb) to match dst_points.
            reordered_src = src_points[[2, 3, 0, 1]]
            matrix = cv2.getPerspectiveTransform(reordered_src, dst_points)
            return cv2.warpPerspective(image, matrix, (int(width), int(height)))
        except Exception as e:
            print(f"车牌矫正失败: {e}")
            return None

    def get_model_info(self):
        """
        Describe the detector's state.

        Returns:
            dict: ``status`` and ``path`` always; when the model is loaded,
            also ``model_type`` and ``classes``.
        """
        if self.model is None:
            return {"status": "未加载", "path": self.model_path}
        return {
            "status": "已加载",
            "path": self.model_path,
            "model_type": "YOLO11 Pose",
            "classes": self.class_names,
        }
def initialize_yolo_detector(model_path=None):
    """
    Convenience factory for a license-plate detector.

    Args:
        model_path: Path to the YOLO pose weights; None selects the class's
            default path.

    Returns:
        LicensePlateYOLO: A freshly constructed detector (its constructor
        attempts to load the model right away).
    """
    return LicensePlateYOLO(model_path)
if __name__ == "__main__":
    # Manual smoke test: build a detector with the default model path and
    # print its status/info dictionary.
    detector = initialize_yolo_detector()
    info = detector.get_model_info()
    print("检测器信息:", info)