323 lines
12 KiB
Python
323 lines
12 KiB
Python
import cv2
|
||
import numpy as np
|
||
from ultralytics import YOLO
|
||
import os
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
|
||
class LicensePlateYOLO:
    """
    License-plate detector built on a YOLO pose model.

    Loads the pose model, detects plates and their corner keypoints, draws the
    results, and rectifies a plate with a perspective warp.
    """

    # Box/label colour per plate class, as RGB tuples (drawing happens on an RGB PIL image).
    _BOX_COLORS_RGB = {'绿牌': (0, 255, 0), '蓝牌': (0, 0, 255)}
    # Colour for any class name not listed above (grey).
    _DEFAULT_BOX_COLOR_RGB = (128, 128, 128)

    # CJK-capable fonts tried in order for the labels. The Windows entries come first,
    # in the same order as before (SimHei, Microsoft YaHei, SimSun); the remaining
    # entries are common macOS / Linux locations and are simply skipped if absent.
    _FONT_CANDIDATES = (
        "C:/Windows/Fonts/simhei.ttf",
        "C:/Windows/Fonts/msyh.ttc",
        "C:/Windows/Fonts/simsun.ttc",
        "/System/Library/Fonts/STHeiti Medium.ttc",
        "/System/Library/Fonts/PingFang.ttc",
        "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc",
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",
    )

    # Loaded label fonts keyed by size. Deliberately shared at class level so a font
    # file is opened once per process rather than on every draw_detections() call.
    _font_cache = {}

    # Permutation taking the model's corner keypoints, assumed to be ordered
    # (right_bottom, left_bottom, left_top, right_top), to the destination order
    # (left_top, right_top, right_bottom, left_bottom). Kept as a tuple and converted
    # to a list at use: a tuple index would mean multi-axis indexing to numpy.
    _KEYPOINT_TO_DST_ORDER = (2, 3, 0, 1)

    def __init__(self, model_path=None):
        """
        Initialise the detector and immediately try to load the model.

        Args:
            model_path: path to the model weights; when None (or empty) the
                default path from _get_default_model_path() is used.
        """
        self.model = None
        self.model_path = model_path or self._get_default_model_path()
        self.class_names = {0: '蓝牌', 1: '绿牌'}  # 0 = blue plate, 1 = green plate
        # A failed load is reported by load_model() and leaves self.model as None.
        self.load_model()

    def _get_default_model_path(self):
        """Return the default weights path: yolo11s-pose42.pt next to this module."""
        # abspath so the result does not depend on the current working directory
        # when __file__ happens to be a relative path.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(current_dir, "yolo11s-pose42.pt")

    def load_model(self):
        """
        Load the YOLO pose model from self.model_path.

        Returns:
            bool: True if the model was loaded, False if the file is missing or
            loading raised (the error is printed, not re-raised).
        """
        try:
            if os.path.exists(self.model_path):
                self.model = YOLO(self.model_path)
                print(f"YOLO模型加载成功: {self.model_path}")
                return True
            print(f"模型文件不存在: {self.model_path}")
            return False
        except Exception as e:
            # Best-effort load: report and let the caller inspect the return value.
            print(f"YOLO模型加载失败: {e}")
            return False

    def detect_license_plates(self, image, conf_threshold=0.5):
        """
        Detect license plates in an image.

        Args:
            image: input image (numpy array), passed directly to the YOLO model.
            conf_threshold: confidence threshold, forwarded to the model as ``conf``.

        Returns:
            list[dict]: one dict per detection with keys
                - 'box': bounding box [x1, y1, x2, y2]
                - 'keypoints': corner keypoints, normally four [[x, y], ...]
                - 'confidence': detection confidence
                - 'class_id': class id (0 = blue plate, 1 = green plate)
                - 'class_name': name from self.class_names, or '未知' if unmapped
                - 'incomplete': present (and True) only when there are not exactly
                  four keypoints
            An empty list is returned if the model is not loaded or inference fails.
        """
        if self.model is None:
            print("模型未加载")
            return []

        try:
            results = self.model(image, conf=conf_threshold, verbose=False)
            detections = []

            for result in results:
                # Nothing to report for results lacking boxes or keypoints.
                if result.boxes is None or result.keypoints is None:
                    continue

                boxes = result.boxes.xyxy.cpu().numpy()
                keypoints = result.keypoints.xy.cpu().numpy()
                confidences = result.boxes.conf.cpu().numpy()
                classes = result.boxes.cls.cpu().numpy()

                for box, plate_kps, confidence, class_val in zip(
                        boxes, keypoints, confidences, classes):
                    class_id = int(class_val)
                    detection = {
                        'box': box,
                        'keypoints': plate_kps if len(plate_kps) > 0 else [],
                        'confidence': confidence,
                        'class_id': class_id,
                        'class_name': self.class_names.get(class_id, '未知'),
                    }
                    if len(plate_kps) != 4:
                        # Keep the detection but flag it; draw_detections() then
                        # marks its points in red instead of drawing a quadrilateral.
                        detection['incomplete'] = True
                    detections.append(detection)

            return detections

        except Exception as e:
            print(f"检测过程中出错: {e}")
            return []

    @classmethod
    def _plate_color(cls, class_name):
        """Return the RGB drawing colour for a plate class name (grey if unknown)."""
        return cls._BOX_COLORS_RGB.get(class_name, cls._DEFAULT_BOX_COLOR_RGB)

    @staticmethod
    def _format_label(class_name, plate_number, confidence, incomplete):
        """Build label text: class name, optional plate number, confidence, incomplete marker."""
        if plate_number:
            label = f"{class_name} {plate_number} {confidence:.2f}"
        else:
            label = f"{class_name} {confidence:.2f}"
        if incomplete:
            label += " (不完整)"
        return label

    @staticmethod
    def _label_layout(x1, y1, text_width, text_height, padding=10):
        """
        Position a label at a box's top-left corner.

        Returns:
            tuple: (background_top_left, background_bottom_right, text_origin).
            The label normally sits directly above the box; if that would start
            above the top of the image it is moved just inside the box's top edge.
        """
        top = y1 - text_height - padding
        if top < 0:
            top = max(y1, 0)
        bottom = top + text_height + padding
        return (x1, top), (x1 + text_width, bottom), (x1, top + padding // 2)

    @classmethod
    def _load_label_font(cls, size=20):
        """Return a (cached) label font, preferring a CJK-capable TrueType font."""
        font = cls._font_cache.get(size)
        if font is not None:
            return font

        for font_path in cls._FONT_CANDIDATES:
            if not os.path.exists(font_path):
                continue
            try:
                font = ImageFont.truetype(font_path, size)
                break
            except OSError:
                # Unreadable or unsupported font file: try the next candidate.
                continue
        else:
            # No candidate usable; the default font likely lacks Chinese glyphs.
            font = ImageFont.load_default()

        cls._font_cache[size] = font
        return font

    def draw_detections(self, image, detections, plate_numbers=None):
        """
        Draw detection results onto a copy of the image.

        Args:
            image: input BGR image (numpy array); it is not modified.
            detections: detection dicts as returned by detect_license_plates().
            plate_numbers: optional list of plate strings aligned with detections;
                missing or falsy entries are omitted from the label.

        Returns:
            numpy.ndarray: a new BGR image with boxes, labels and keypoints drawn.
        """
        # Boxes and text go through PIL because cv2.putText cannot render Chinese.
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(pil_image)
        font = self._load_label_font()

        for i, detection in enumerate(detections):
            x1, y1, x2, y2 = map(int, detection['box'])
            class_name = detection['class_name']
            box_color = self._plate_color(class_name)

            draw.rectangle([(x1, y1), (x2, y2)], outline=box_color, width=2)

            plate_number = ""
            if plate_numbers and i < len(plate_numbers):
                plate_number = plate_numbers[i]
            label = self._format_label(class_name, plate_number,
                                       detection['confidence'],
                                       detection.get('incomplete', False))

            left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
            bg_start, bg_end, text_origin = self._label_layout(
                x1, y1, right - left, bottom - top)
            draw.rectangle([bg_start, bg_end], fill=box_color)
            draw.text(text_origin, label, fill=(255, 255, 255), font=font)

        draw_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

        # Keypoints and outlines are drawn with OpenCV, so colours below are BGR.
        for detection in detections:
            keypoints = detection['keypoints']
            if len(keypoints) >= 4 and not detection.get('incomplete', False):
                # Complete corners: yellow dots joined in keypoint order.
                points = [(int(kp[0]), int(kp[1])) for kp in keypoints[:4]]
                for point in points:
                    cv2.circle(draw_image, point, 5, (0, 255, 255), -1)
                for j in range(4):
                    cv2.line(draw_image, points[j], points[(j + 1) % 4], (0, 255, 255), 2)
            elif len(keypoints) > 0:
                # Incomplete corners: mark whatever points exist in red.
                for kp in keypoints:
                    cv2.circle(draw_image, (int(kp[0]), int(kp[1])), 5, (0, 0, 255), -1)

        return draw_image

    def correct_license_plate(self, image, keypoints, target_size=(240, 80)):
        """
        Rectify a plate with a perspective transform defined by its four corners.

        Args:
            image: source image.
            keypoints: four corner points, assumed ordered
                (right_bottom, left_bottom, left_top, right_top).
            target_size: output size as (width, height).

        Returns:
            numpy.ndarray: the warped plate image, or None if there are not exactly
            four keypoints or the transform fails (the error is printed).
        """
        if len(keypoints) != 4:
            return None

        try:
            # Reorder corners to (left_top, right_top, right_bottom, left_bottom).
            src_points = np.asarray(keypoints, dtype=np.float32)[
                list(self._KEYPOINT_TO_DST_ORDER)]

            width, height = target_size
            dst_points = np.array([
                [0, 0],           # left_top
                [width, 0],       # right_top
                [width, height],  # right_bottom
                [0, height],      # left_bottom
            ], dtype=np.float32)

            matrix = cv2.getPerspectiveTransform(src_points, dst_points)
            return cv2.warpPerspective(image, matrix, target_size)

        except Exception as e:
            print(f"车牌矫正失败: {e}")
            return None

    def get_model_info(self):
        """
        Describe the detector's model.

        Returns:
            dict: always 'status' and 'path'; when loaded also 'model_type' and a
            copy of the class-id -> name mapping under 'classes'.
        """
        if self.model is None:
            return {"status": "未加载", "path": self.model_path}

        return {
            "status": "已加载",
            "path": self.model_path,
            "model_type": "YOLO11 Pose",
            # Copy so callers cannot mutate the detector's own mapping.
            "classes": dict(self.class_names),
        }
|
||
|
||
def initialize_yolo_detector(model_path=None):
    """
    Convenience factory that builds a LicensePlateYOLO detector.

    Args:
        model_path: path to the model weights; None selects the detector's default.

    Returns:
        LicensePlateYOLO: the constructed detector. Model loading happens in its
        constructor; check get_model_info() to see whether it succeeded.
    """
    return LicensePlateYOLO(model_path)
|
||
|
||
if __name__ == "__main__":
    # Smoke test: build a detector with the default model path and report its status.
    yolo_detector = initialize_yolo_detector()
    print("检测器信息:", yolo_detector.get_model_info())