diff --git a/CRNN_part/best_model.pth b/CRNN_part/best_model.pth index 4054755..e9fa96d 100644 Binary files a/CRNN_part/best_model.pth and b/CRNN_part/best_model.pth differ diff --git a/LPRNET_part/1.jpg b/LPRNET_part/1.jpg deleted file mode 100644 index 32ddef3..0000000 Binary files a/LPRNET_part/1.jpg and /dev/null differ diff --git a/LPRNET_part/2.jpg b/LPRNET_part/2.jpg deleted file mode 100644 index 29d94e6..0000000 Binary files a/LPRNET_part/2.jpg and /dev/null differ diff --git a/LPRNET_part/6ce2ec7dbed6cf3c8403abe2683c57e5.jpg b/LPRNET_part/6ce2ec7dbed6cf3c8403abe2683c57e5.jpg deleted file mode 100644 index 3f8a8f8..0000000 Binary files a/LPRNET_part/6ce2ec7dbed6cf3c8403abe2683c57e5.jpg and /dev/null differ diff --git a/LPRNET_part/LPRNet__iteration_74000.pth b/LPRNET_part/LPRNet__iteration_74000.pth deleted file mode 100644 index 037122c..0000000 Binary files a/LPRNET_part/LPRNet__iteration_74000.pth and /dev/null differ diff --git a/LPRNET_part/c11304d10bcd47911e458398d1ea445d.jpg b/LPRNET_part/c11304d10bcd47911e458398d1ea445d.jpg deleted file mode 100644 index 570ae94..0000000 Binary files a/LPRNET_part/c11304d10bcd47911e458398d1ea445d.jpg and /dev/null differ diff --git a/LPRNET_part/c6ab0fbcfb2b6fbe15c5b3eb9806a28b.jpg b/LPRNET_part/c6ab0fbcfb2b6fbe15c5b3eb9806a28b.jpg deleted file mode 100644 index 843a03d..0000000 Binary files a/LPRNET_part/c6ab0fbcfb2b6fbe15c5b3eb9806a28b.jpg and /dev/null differ diff --git a/LPRNET_part/lpr_interface.py b/LPRNET_part/lpr_interface.py deleted file mode 100644 index 0f87201..0000000 --- a/LPRNET_part/lpr_interface.py +++ /dev/null @@ -1,372 +0,0 @@ -# 导入必要的库 -import torch -import torch.nn as nn -import cv2 -import numpy as np -import os -import sys -from torch.autograd import Variable -from PIL import Image - -# 添加父目录到路径,以便导入模型和数据加载器 -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -# LPRNet字符集定义(与训练时保持一致) -# 包含中国省份简称、数字、字母和特殊字符 -CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑', - '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤', - '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁', '新', - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', - 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', - 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', - 'W', 'X', 'Y', 'Z', 'I', 'O', '-'] - -# 创建字符到索引的映射字典 -CHARS_DICT = {char: i for i, char in enumerate(CHARS)} - -# 简化的LPRNet模型定义 - 基础卷积块 -class small_basic_block(nn.Module): - def __init__(self, ch_in, ch_out): - super(small_basic_block, self).__init__() - # 定义一个小的基本卷积块,包含四个卷积层 - self.block = nn.Sequential( - # 1x1卷积,降低通道数 - nn.Conv2d(ch_in, ch_out // 4, kernel_size=1), - nn.ReLU(), - # 3x1卷积,处理水平特征 - nn.Conv2d(ch_out // 4, ch_out // 4, kernel_size=(3, 1), padding=(1, 0)), - nn.ReLU(), - # 1x3卷积,处理垂直特征 - nn.Conv2d(ch_out // 4, ch_out // 4, kernel_size=(1, 3), padding=(0, 1)), - nn.ReLU(), - # 1x1卷积,恢复通道数 - nn.Conv2d(ch_out // 4, ch_out, kernel_size=1), - ) - - def forward(self, x): - return self.block(x) - -# LPRNet模型定义 - 车牌识别网络 -class LPRNet(nn.Module): - def __init__(self, lpr_max_len, phase, class_num, dropout_rate): - super(LPRNet, self).__init__() - self.phase = phase - self.lpr_max_len = lpr_max_len - self.class_num = class_num - - # 定义主干网络 - self.backbone = nn.Sequential( - # 初始卷积层 - nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1), # 0 - nn.BatchNorm2d(num_features=64), - nn.ReLU(), # 2 - # 最大池化层 - nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 1, 1)), - # 第一个基本块 - small_basic_block(ch_in=64, ch_out=128), # *** 4 *** - nn.BatchNorm2d(num_features=128), - nn.ReLU(), # 6 - # 第二个池化层 - nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(2, 1, 2)), - # 第二个基本块 - small_basic_block(ch_in=64, ch_out=256), # 8 - nn.BatchNorm2d(num_features=256), - nn.ReLU(), # 10 - # 第三个基本块 - small_basic_block(ch_in=256, ch_out=256), # *** 11 *** - nn.BatchNorm2d(num_features=256), - nn.ReLU(), # 13 - # 第三个池化层 - nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 1, 2)), # 14 - # Dropout层,防止过拟合 - nn.Dropout(dropout_rate), - # 特征提取卷积层 - nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1), # 16 - nn.BatchNorm2d(num_features=256), - nn.ReLU(), # 18 - # 第二个Dropout层 - nn.Dropout(dropout_rate), - # 分类卷积层 - nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(13, 1), stride=1), # 20 - nn.BatchNorm2d(num_features=class_num), - nn.ReLU(), # 22 - ) - - # 定义容器层,用于融合全局上下文信息 - self.container = nn.Sequential( - nn.Conv2d(in_channels=448+self.class_num, out_channels=self.class_num, kernel_size=(1,1), stride=(1,1)), - ) - - def forward(self, x): - # 保存中间特征 - keep_features = list() - for i, layer in enumerate(self.backbone.children()): - x = layer(x) - # 保存特定层的输出特征 - if i in [2, 6, 13, 22]: # [2, 4, 8, 11, 22] - keep_features.append(x) - - # 处理全局上下文信息 - global_context = list() - for i, f in enumerate(keep_features): - # 对不同层的特征进行不同尺度的平均池化 - if i in [0, 1]: - f = nn.AvgPool2d(kernel_size=5, stride=5)(f) - if i in [2]: - f = nn.AvgPool2d(kernel_size=(4, 10), stride=(4, 2))(f) - # 对特征进行归一化处理 - f_pow = torch.pow(f, 2) - f_mean = torch.mean(f_pow) - f = torch.div(f, f_mean) - global_context.append(f) - - # 拼接全局上下文特征 - x = torch.cat(global_context, 1) - # 通过容器层处理 - x = self.container(x) - # 对序列维度进行平均,得到最终输出 - logits = torch.mean(x, dim=2) - - return logits - -# LPRNet推理类 -class LPRNetInference: - def __init__(self, model_path=None, img_size=[94, 24], lpr_max_len=8, dropout_rate=0.5): - """ - 初始化LPRNet推理类 - Args: - model_path: 训练好的模型权重文件路径 - img_size: 输入图像尺寸 [width, height] - lpr_max_len: 车牌最大长度 - dropout_rate: dropout率 - """ - self.img_size = img_size - self.lpr_max_len = lpr_max_len - # 检测是否有可用的CUDA设备 - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - # 设置默认模型路径 - if model_path is None: - current_dir = os.path.dirname(os.path.abspath(__file__)) - model_path = os.path.join(current_dir, 'LPRNet__iteration_74000.pth') - - # 初始化模型 - self.model = LPRNet(lpr_max_len=lpr_max_len, phase=False, class_num=len(CHARS), dropout_rate=dropout_rate) - - # 加载模型权重 - if model_path and os.path.exists(model_path): - print(f"Loading LPRNet model from {model_path}") - try: - self.model.load_state_dict(torch.load(model_path, map_location=self.device)) - print("LPRNet模型权重加载成功") - except Exception as e: - print(f"Warning: 加载模型权重失败: {e}. 使用随机权重.") - else: - print(f"Warning: 模型文件不存在或未指定: {model_path}. 使用随机权重.") - - # 将模型移动到指定设备并设置为评估模式 - self.model.to(self.device) - self.model.eval() - - print(f"LPRNet模型加载完成,设备: {self.device}") - print(f"模型参数数量: {sum(p.numel() for p in self.model.parameters()):,}") - - def preprocess_image(self, image_array): - """ - 预处理图像数组 - 使用与训练时相同的预处理方式 - Args: - image_array: numpy数组格式的图像 (H, W, C) - Returns: - preprocessed_image: 预处理后的图像tensor - """ - if image_array is None: - raise ValueError("Input image is None") - - # 确保图像是numpy数组 - if not isinstance(image_array, np.ndarray): - raise ValueError("Input must be numpy array") - - # 检查图像维度 - if len(image_array.shape) != 3: - raise ValueError(f"Expected 3D image array, got {len(image_array.shape)}D") - - height, width, channels = image_array.shape - if channels != 3: - raise ValueError(f"Expected 3 channels, got {channels}") - - # 调整图像尺寸到模型要求的尺寸 - if height != self.img_size[1] or width != self.img_size[0]: - image_array = cv2.resize(image_array, tuple(self.img_size)) - - # 使用与训练时相同的预处理方式 - # 归一化处理:减去127.5并乘以0.0078125,将像素值从[0,255]映射到[-1,1] - image_array = image_array.astype('float32') - image_array -= 127.5 - image_array *= 0.0078125 - # 调整维度顺序从HWC到CHW - image_array = np.transpose(image_array, (2, 0, 1)) # HWC -> CHW - - # 转换为tensor并添加batch维度 - image_tensor = torch.from_numpy(image_array).unsqueeze(0) - - return image_tensor - - def decode_prediction(self, logits): - """ - 解码模型预测结果 - 使用正确的CTC贪婪解码 - Args: - logits: 模型输出的logits [batch_size, num_classes, sequence_length] - Returns: - predicted_text: 预测的车牌号码 - """ - # 转换为numpy进行处理 - prebs = logits.cpu().detach().numpy() - preb = prebs[0, :, :] # 取第一个batch [num_classes, sequence_length] - - # 贪婪解码: 对每个时间步选择最大概率的字符 - preb_label = [] - for j in range(preb.shape[1]): # 遍历每个时间步 - preb_label.append(np.argmax(preb[:, j], axis=0)) - - # CTC解码:去除重复字符和空白字符 - no_repeat_blank_label = [] - pre_c = preb_label[0] - - # 处理第一个字符 - if pre_c != len(CHARS) - 1: # 不是空白字符 - no_repeat_blank_label.append(pre_c) - - # 处理后续字符 - for c in preb_label: - if (pre_c == c) or (c == len(CHARS) - 1): # 重复字符或空白字符 - if c == len(CHARS) - 1: - pre_c = c - continue - no_repeat_blank_label.append(c) - pre_c = c - - # 转换为字符 - decoded_chars = [CHARS[idx] for idx in no_repeat_blank_label] - return ''.join(decoded_chars) - - def predict(self, image_array): - """ - 预测单张图像的车牌号码 - Args: - image_array: numpy数组格式的图像 - Returns: - prediction: 预测的车牌号码 - confidence: 预测置信度 - """ - try: - # 预处理图像 - image = self.preprocess_image(image_array) - if image is None: - return None, 0.0 - - image = image.to(self.device) - - # 模型推理 - with torch.no_grad(): - logits = self.model(image) - # logits shape: [batch_size, class_num, sequence_length] - - # 计算置信度(使用softmax后的最大概率平均值) - probs = torch.softmax(logits, dim=1) - max_probs = torch.max(probs, dim=1)[0] - confidence = torch.mean(max_probs).item() - - # 解码预测结果 - prediction = self.decode_prediction(logits) - - return prediction, confidence - - except Exception as e: - print(f"预测图像失败: {e}") - return None, 0.0 - -# 全局变量,用于存储模型实例 -lpr_model = None - -def LPRNinitialize_model(): - """ - 初始化LPRNet模型 - - 返回: - bool: 初始化是否成功 - """ - global lpr_model - - try: - # 模型权重文件路径 - model_path = os.path.join(os.path.dirname(__file__), 'LPRNet__iteration_74000.pth') - - # 创建推理对象 - lpr_model = LPRNetInference(model_path) - - print("LPRNet模型初始化完成") - return True - - except Exception as e: - print(f"LPRNet模型初始化失败: {e}") - import traceback - traceback.print_exc() - return False - -def LPRNmodel_predict(image_array): - """ - LPRNet车牌号识别接口函数 - - 参数: - image_array: numpy数组格式的车牌图像,已经过矫正处理 - - 返回: - list: 包含最多8个字符的列表,代表车牌号的每个字符 - 例如: ['京', 'A', '1', '2', '3', '4', '5'] (蓝牌7位) - ['京', 'A', 'D', '1', '2', '3', '4', '5'] (绿牌8位) - """ - global lpr_model - - if lpr_model is None: - print("LPRNet模型未初始化,请先调用LPRNinitialize_model()") - return ['待', '识', '别', '0', '0', '0', '0', '0'] - - try: - # 使用OpenCV调整图像大小到模型要求的尺寸 - image_array = cv2.resize(image_array, (94, 24)) - print(f"666999图片尺寸: {image_array.shape}") - - # 显示修正后的图像 - cv2.imshow('Resized License Plate Image (94x24)', image_array) - cv2.waitKey(1) # 非阻塞显示,允许程序继续执行 - # 预测车牌号 - predicted_text, confidence = lpr_model.predict(image_array) - - if predicted_text is None: - print("LPRNet识别失败") - return ['识', '别', '失', '败', '0', '0', '0', '0'] - - print(f"LPRNet识别结果: {predicted_text}, 置信度: {confidence:.3f}") - - # 将字符串转换为字符列表 - char_list = list(predicted_text) - - # 确保返回至少7个字符,最多8个字符 - if len(char_list) < 7: - # 如果识别结果少于7个字符,用'0'补齐到7位 - char_list.extend(['0'] * (7 - len(char_list))) - elif len(char_list) > 8: - # 如果识别结果多于8个字符,截取前8个 - char_list = char_list[:8] - - # 如果是7位,补齐到8位以保持接口一致性(第8位用空字符或占位符) - if len(char_list) == 7: - char_list.append('') # 添加空字符作为第8位占位符 - - return char_list - - except Exception as e: - print(f"LPRNet识别失败: {e}") - import traceback - traceback.print_exc() - return ['识', '别', '失', '败', '0', '0', '0', '0'] \ No newline at end of file diff --git a/lightCRNN_part/best_model.pth b/lightCRNN_part/best_model.pth index 122cf25..8a575dc 100644 Binary files a/lightCRNN_part/best_model.pth and b/lightCRNN_part/best_model.pth differ diff --git a/main.py b/main.py index 132ad5a..1058d9d 100644 --- a/main.py +++ b/main.py @@ -308,7 +308,7 @@ class MainWindow(QMainWindow): method_label.setFont(QFont("Arial", 10)) self.method_combo = QComboBox() - self.method_combo.addItems(["CRNN", "LPRNET", "OCR"]) + self.method_combo.addItems(["CRNN", "LightCRNN", "OCR"]) self.method_combo.setCurrentText("CRNN") # 默认选择CRNN self.method_combo.currentTextChanged.connect(self.change_recognition_method) @@ -578,7 +578,19 @@ class MainWindow(QMainWindow): def draw_detections(self, frame): """在图像上绘制检测结果""" - return self.detector.draw_detections(frame, self.detections) + # 获取车牌号列表 + plate_numbers = [] + for detection in self.detections: + # 矫正车牌图像 + corrected_image = self.correct_license_plate(detection) + # 获取车牌号 + if corrected_image is not None: + plate_number = self.recognize_plate_number(corrected_image, detection['class_name']) + plate_numbers.append(plate_number) + else: + plate_numbers.append("识别失败") + + return self.detector.draw_detections(frame, self.detections, plate_numbers) def display_frame(self, frame): """显示帧到界面""" @@ -760,7 +772,7 @@ class MainWindow(QMainWindow): # 根据当前选择的识别方法调用相应的函数 if self.current_recognition_method == "CRNN": from CRNN_part.crnn_interface import LPRNmodel_predict - elif self.current_recognition_method == "LPRNET": + elif self.current_recognition_method == "LightCRNN": from lightCRNN_part.lightcrnn_interface import LPRNmodel_predict elif self.current_recognition_method == "OCR": from OCR_part.ocr_interface import LPRNmodel_predict @@ -798,7 +810,7 @@ class MainWindow(QMainWindow): if method == "CRNN": from CRNN_part.crnn_interface import LPRNinitialize_model LPRNinitialize_model() - elif method == "LPRNET": + elif method == "LightCRNN": from lightCRNN_part.lightcrnn_interface import LPRNinitialize_model LPRNinitialize_model() elif method == "OCR": diff --git a/test_lpr_real_images.py b/test_lpr_real_images.py deleted file mode 100644 index b3f859b..0000000 --- a/test_lpr_real_images.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -LPRNet接口真实图片测试脚本 -测试LPRNET_part目录下的真实车牌图片 -""" - -import cv2 -import numpy as np -import os -from LPRNET_part.lpr_interface import LPRNinitialize_model, LPRNmodel_predict - -def test_real_images(): - """ - 测试LPRNET_part目录下的真实车牌图片 - """ - print("=== LPRNet真实图片测试 ===") - - # 初始化模型 - print("1. 初始化LPRNet模型...") - success = LPRNinitialize_model() - if not success: - print("模型初始化失败!") - return - - # 获取LPRNET_part目录下的图片文件 - lprnet_dir = "LPRNET_part" - image_files = [] - - if os.path.exists(lprnet_dir): - for file in os.listdir(lprnet_dir): - if file.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')): - image_files.append(os.path.join(lprnet_dir, file)) - - if not image_files: - print("未找到图片文件!") - return - - print(f"2. 找到 {len(image_files)} 个图片文件") - - # 测试每个图片 - for i, image_path in enumerate(image_files, 1): - print(f"\n--- 测试图片 {i}: {os.path.basename(image_path)} ---") - - try: - # 使用支持中文路径的方式读取图片 - image = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR) - - if image is None: - print(f"无法读取图片: {image_path}") - continue - - print(f"图片尺寸: {image.shape}") - - # 进行预测 - result = LPRNmodel_predict(image) - print(f"识别结果: {result}") - print(f"识别车牌号: {''.join(result)}") - - except Exception as e: - print(f"处理图片 {image_path} 时出错: {e}") - import traceback - traceback.print_exc() - - print("\n=== 测试完成 ===") - -def test_image_loading(): - """ - 测试图片加载方式 - """ - print("\n=== 图片加载测试 ===") - - lprnet_dir = "LPRNET_part" - - if os.path.exists(lprnet_dir): - for file in os.listdir(lprnet_dir): - if file.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')): - image_path = os.path.join(lprnet_dir, file) - print(f"\n测试文件: {file}") - - # 方法1: 普通cv2.imread - img1 = cv2.imread(image_path) - print(f"cv2.imread结果: {img1 is not None}") - - # 方法2: 支持中文路径的方式 - try: - img2 = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), cv2.IMREAD_COLOR) - # img2 = cv2.resize(img2,(128,48)) - print(f"cv2.imdecode结果: {img2 is not None}") - if img2 is not None: - print(f"图片尺寸: {img2.shape}") - except Exception as e: - print(f"cv2.imdecode失败: {e}") - -if __name__ == "__main__": - # 首先测试图片加载 - test_image_loading() - - # 然后测试完整的识别流程 - test_real_images() \ No newline at end of file diff --git a/yolopart/detector.py b/yolopart/detector.py index b95fd80..435c4fb 100644 --- a/yolopart/detector.py +++ b/yolopart/detector.py @@ -2,6 +2,7 @@ import cv2 import numpy as np from ultralytics import YOLO import os +from PIL import Image, ImageDraw, ImageFont class LicensePlateYOLO: """ @@ -113,19 +114,38 @@ class LicensePlateYOLO: print(f"检测过程中出错: {e}") return [] - def draw_detections(self, image, detections): + def draw_detections(self, image, detections, plate_numbers=None): """ 在图像上绘制检测结果 参数: image: 输入图像 detections: 检测结果列表 + plate_numbers: 车牌号列表,与detections对应 返回: numpy.ndarray: 绘制了检测结果的图像 """ draw_image = image.copy() + # 转换为PIL图像以支持中文字符 + pil_image = Image.fromarray(cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB)) + draw = ImageDraw.Draw(pil_image) + + # 尝试加载中文字体 + try: + # Windows系统常见的中文字体 + font_path = "C:/Windows/Fonts/simhei.ttf" # 黑体 + if not os.path.exists(font_path): + font_path = "C:/Windows/Fonts/msyh.ttc" # 微软雅黑 + if not os.path.exists(font_path): + font_path = "C:/Windows/Fonts/simsun.ttc" # 宋体 + + font = ImageFont.truetype(font_path, 20) + except: + # 如果无法加载字体,使用默认字体 + font = ImageFont.load_default() + for i, detection in enumerate(detections): box = detection['box'] keypoints = detection['keypoints'] @@ -133,6 +153,11 @@ class LicensePlateYOLO: confidence = detection['confidence'] incomplete = detection.get('incomplete', False) + # 获取对应的车牌号 + plate_number = "" + if plate_numbers and i < len(plate_numbers): + plate_number = plate_numbers[i] + # 绘制边界框 x1, y1, x2, y2 = map(int, box) @@ -140,30 +165,53 @@ class LicensePlateYOLO: if class_name == '绿牌': box_color = (0, 255, 0) # 绿色 elif class_name == '蓝牌': - box_color = (255, 0, 0) # 蓝色 + box_color = (0, 0, 255) # 蓝色 else: box_color = (128, 128, 128) # 灰色 - cv2.rectangle(draw_image, (x1, y1), (x2, y2), box_color, 2) + # 在PIL图像上绘制边界框 + draw.rectangle([(x1, y1), (x2, y2)], outline=box_color, width=2) + + # 构建标签文本 + if plate_number: + label = f"{class_name} {plate_number} {confidence:.2f}" + else: + label = f"{class_name} {confidence:.2f}" - # 绘制标签 - label = f"{class_name} {confidence:.2f}" if incomplete: label += " (不完整)" - # 计算文本大小和位置 - font = cv2.FONT_HERSHEY_SIMPLEX - font_scale = 0.6 - thickness = 2 - (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness) + # 计算文本大小 + bbox = draw.textbbox((0, 0), label, font=font) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] # 绘制文本背景 - cv2.rectangle(draw_image, (x1, y1 - text_height - 10), - (x1 + text_width, y1), box_color, -1) + draw.rectangle([(x1, y1 - text_height - 10), (x1 + text_width, y1)], + fill=box_color) # 绘制文本 - cv2.putText(draw_image, label, (x1, y1 - 5), - font, font_scale, (255, 255, 255), thickness) + draw.text((x1, y1 - text_height - 5), label, fill=(255, 255, 255), font=font) + + # 转换回OpenCV格式 + draw_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR) + + # 绘制关键点和连线(使用OpenCV) + for i, detection in enumerate(detections): + box = detection['box'] + keypoints = detection['keypoints'] + incomplete = detection.get('incomplete', False) + + x1, y1, x2, y2 = map(int, box) + + # 根据车牌类型选择颜色 + class_name = detection['class_name'] + if class_name == '绿牌': + box_color = (0, 255, 0) # 绿色 + elif class_name == '蓝牌': + box_color = (0, 0, 255) # 蓝色 + else: + box_color = (128, 128, 128) # 灰色 # 绘制关键点和连线 if len(keypoints) >= 4 and not incomplete: