From 95aa6b6bba8fd405db3e59f513bc7f87e5fbb4d1 Mon Sep 17 00:00:00 2001 From: spdis Date: Tue, 2 Sep 2025 11:40:41 +0800 Subject: [PATCH] LPR --- LPRNET_part/lpr_interface.py | 284 +++++++++++++++++++++++++++++++++++ README.md | 14 +- main.py | 16 +- 3 files changed, 305 insertions(+), 9 deletions(-) create mode 100644 LPRNET_part/lpr_interface.py diff --git a/LPRNET_part/lpr_interface.py b/LPRNET_part/lpr_interface.py new file mode 100644 index 0000000..8d7e674 --- /dev/null +++ b/LPRNET_part/lpr_interface.py @@ -0,0 +1,284 @@ +import torch +import torch.nn as nn +import numpy as np +import cv2 +from torch.autograd import Variable +import os + + +# 字符集定义 +CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑', + '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤', + '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁', + '新', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', + 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', + 'W', 'X', 'Y', 'Z', 'I', 'O', '-' + ] + +CHARS_DICT = {char: i for i, char in enumerate(CHARS)} + +# 全局变量 +lprnet_model = None +device = None + +class small_basic_block(nn.Module): + def __init__(self, ch_in, ch_out): + super(small_basic_block, self).__init__() + self.block = nn.Sequential( + nn.Conv2d(ch_in, ch_out // 4, kernel_size=1), + nn.ReLU(), + nn.Conv2d(ch_out // 4, ch_out // 4, kernel_size=(3, 1), padding=(1, 0)), + nn.ReLU(), + nn.Conv2d(ch_out // 4, ch_out // 4, kernel_size=(1, 3), padding=(0, 1)), + nn.ReLU(), + nn.Conv2d(ch_out // 4, ch_out, kernel_size=1), + ) + + def forward(self, x): + return self.block(x) + +class LPRNet(nn.Module): + def __init__(self, lpr_max_len, phase, class_num, dropout_rate): + super(LPRNet, self).__init__() + self.phase = phase + self.lpr_max_len = lpr_max_len + self.class_num = class_num + self.backbone = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1), # 0 + nn.BatchNorm2d(num_features=64), + nn.ReLU(), # 2 + nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 1, 1)), + small_basic_block(ch_in=64, ch_out=128), # *** 4 *** + nn.BatchNorm2d(num_features=128), + nn.ReLU(), # 6 + nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(2, 1, 2)), + small_basic_block(ch_in=64, ch_out=256), # 8 + nn.BatchNorm2d(num_features=256), + nn.ReLU(), # 10 + small_basic_block(ch_in=256, ch_out=256), # *** 11 *** + nn.BatchNorm2d(num_features=256), # 12 + nn.ReLU(), + nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 1, 2)), # 14 + nn.Dropout(dropout_rate), + nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1), # 16 + nn.BatchNorm2d(num_features=256), + nn.ReLU(), # 18 + nn.Dropout(dropout_rate), + nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(13, 1), stride=1), # 20 + nn.BatchNorm2d(num_features=class_num), + nn.ReLU(), # *** 22 *** + ) + self.container = nn.Sequential( + nn.Conv2d(in_channels=448+self.class_num, out_channels=self.class_num, kernel_size=(1, 1), stride=(1, 1)), + ) + + def forward(self, x): + keep_features = list() + for i, layer in enumerate(self.backbone.children()): + x = layer(x) + if i in [2, 6, 13, 22]: # [2, 4, 8, 11, 22] + keep_features.append(x) + + global_context = list() + for i, f in enumerate(keep_features): + if i in [0, 1]: + f = nn.AvgPool2d(kernel_size=5, stride=5)(f) + if i in [2]: + f = nn.AvgPool2d(kernel_size=(4, 10), stride=(4, 2))(f) + f_pow = torch.pow(f, 2) + f_mean = torch.mean(f_pow) + f = torch.div(f, f_mean) + global_context.append(f) + + x = torch.cat(global_context, 1) + x = self.container(x) + logits = torch.mean(x, dim=2) + + return logits + +def build_lprnet(lpr_max_len=8, phase=False, class_num=66, dropout_rate=0.5): + """构建LPRNet模型""" + Net = LPRNet(lpr_max_len, phase, class_num, dropout_rate) + + if phase == "train": + return Net.train() + else: + return Net.eval() + +def preprocess_image(image_array, img_size=(94, 24)): + """图像预处理""" + # 确保输入是numpy数组 + if not isinstance(image_array, np.ndarray): + raise ValueError("输入必须是numpy数组") + + # 调整图像尺寸 + height, width = image_array.shape[:2] + if height != img_size[1] or width != img_size[0]: + image_array = cv2.resize(image_array, img_size) + + # 归一化到[0,1] + image_array = image_array.astype(np.float32) / 255.0 + + # 转换为CHW格式 + if len(image_array.shape) == 3: + image_array = np.transpose(image_array, (2, 0, 1)) + + # 添加batch维度 + image_array = np.expand_dims(image_array, axis=0) + + return image_array + +def greedy_decode(prebs): + """贪婪解码""" + preb_labels = list() + for i in range(prebs.shape[0]): + preb = prebs[i, :, :] + preb_label = list() + for j in range(preb.shape[1]): + preb_label.append(np.argmax(preb[:, j], axis=0)) + + no_repeat_blank_label = list() + pre_c = preb_label[0] + if pre_c != len(CHARS) - 1: + no_repeat_blank_label.append(pre_c) + + for c in preb_label: # 去除重复标签和空白标签 + if (pre_c == c) or (c == len(CHARS) - 1): + if c == len(CHARS) - 1: + pre_c = c + continue + no_repeat_blank_label.append(c) + pre_c = c + + preb_labels.append(no_repeat_blank_label) + + return preb_labels + +def LPRNinitialize_model(model_path=None): + """初始化LPRNet模型""" + global lprnet_model, device + + try: + # 设置设备 + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + print(f"使用设备: {device}") + + # 构建模型 + lprnet_model = build_lprnet( + lpr_max_len=8, + phase=False, + class_num=len(CHARS), + dropout_rate=0.5 + ) + + # 加载预训练权重 + if model_path is None: + model_path = os.path.join(os.path.dirname(__file__), "Final_LPRNet_model.pth") + + if os.path.exists(model_path): + checkpoint = torch.load(model_path, map_location=device) + lprnet_model.load_state_dict(checkpoint) + print(f"成功加载预训练模型: {model_path}") + else: + print(f"警告: 未找到预训练模型文件 {model_path},使用随机初始化权重") + + lprnet_model.to(device) + lprnet_model.eval() + + print("LPRNet模型初始化完成") + + # 统计模型参数 + total_params = sum(p.numel() for p in lprnet_model.parameters()) + print(f"LPRNet模型参数数量: {total_params:,}") + + return True + + except Exception as e: + print(f"LPRNet模型初始化失败: {e}") + import traceback + traceback.print_exc() + return False + +def LPRNmodel_predict(image_array): + """ + LPRNet车牌号识别接口函数 + + 参数: + image_array: numpy数组格式的车牌图像,已经过矫正处理 + + 返回: + list: 包含7个字符的列表,代表车牌号的每个字符 + 例如: ['京', 'A', '1', '2', '3', '4', '5'] + """ + global lprnet_model, device + + if lprnet_model is None: + print("LPRNet模型未初始化,请先调用LPRNinitialize_model()") + return ['待', '识', '别', '0', '0', '0', '0'] + + try: + # 预处理图像 + processed_image = preprocess_image(image_array) + + # 转换为tensor + input_tensor = torch.from_numpy(processed_image).float() + input_tensor = input_tensor.to(device) + + # 模型推理 + with torch.no_grad(): + prebs = lprnet_model(input_tensor) + prebs = prebs.cpu().detach().numpy() + + # 贪婪解码 + preb_labels = greedy_decode(prebs) + + if len(preb_labels) > 0 and len(preb_labels[0]) > 0: + # 将索引转换为字符 + predicted_chars = [CHARS[idx] for idx in preb_labels[0] if idx < len(CHARS)] + + print(f"LPRNet识别结果: {''.join(predicted_chars)}") + + # 确保返回7个字符(车牌标准长度) + if len(predicted_chars) < 7: + # 如果识别结果少于7个字符,用'0'补齐 + predicted_chars.extend(['0'] * (7 - len(predicted_chars))) + elif len(predicted_chars) > 7: + # 如果识别结果多于7个字符,截取前7个 + predicted_chars = predicted_chars[:7] + + return predicted_chars + else: + print("LPRNet识别结果为空") + return ['识', '别', '为', '空', '0', '0', '0'] + + except Exception as e: + print(f"LPRNet识别失败: {e}") + import traceback + traceback.print_exc() + return ['识', '别', '失', '败', '0', '0', '0'] + +# 为了保持与其他模块的一致性,提供一个处理器类 +class LPRProcessor: + def __init__(self): + self.initialized = False + + def initialize(self, model_path=None): + """初始化模型""" + self.initialized = LPRNinitialize_model(model_path) + return self.initialized + + def predict(self, image_array): + """预测接口""" + if not self.initialized: + print("模型未初始化") + return ['未', '初', '始', '化', '0', '0', '0'] + return LPRNmodel_predict(image_array) + +# 创建全局处理器实例 +_processor = LPRProcessor() + +def get_lpr_processor(): + """获取LPR处理器实例""" + return _processor \ No newline at end of file diff --git a/README.md b/README.md index 7508b79..ca34982 100644 --- a/README.md +++ b/README.md @@ -14,8 +14,12 @@ License_plate_recognition/ │ └── yolo11s-pose42.pt # YOLO pose模型文件 ├── OCR_part/ # OCR识别模块 │ └── ocr_interface.py # OCR接口(占位) -└── CRNN_part/ # CRNN识别模块 - └── crnn_interface.py # CRNN +├── CRNN_part/ # CRNN识别模块 +│ └── crnn_interface.py # CRNN接口 +└── LPRNET_part/ # LPRNet识别模块 + ├── lpr_interface.py # LPRNet接口 + ├── Final_LPRNet_model.pth # 预训练模型文件 + └── will_delete/ # 参考资料(可删除) ``` ## 功能特性 @@ -44,8 +48,10 @@ License_plate_recognition/ ### 5. 模块化设计 - yolopart:负责车牌定位和矫正 -- OCR_part/CRNN_part:负责车牌号识别(接口已预留) -- 各模块独立,便于维护和扩展 +- OCR_part:基于PaddleOCR的车牌号识别模块 +- CRNN_part:基于CRNN网络的车牌号识别模块 +- LPRNET_part:基于LPRNet网络的车牌号识别模块 +- 各模块独立,便于维护和扩展,可通过修改main.py中的导入语句切换识别模块 ## 安装和使用 diff --git a/main.py b/main.py index 2eee603..766c2b3 100644 --- a/main.py +++ b/main.py @@ -9,11 +9,17 @@ from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread from PyQt5.QtGui import QImage, QPixmap, QFont, QPainter, QPen, QColor import os from yolopart.detector import LicensePlateYOLO -from OCR_part.ocr_interface import LPRNmodel_predict -from OCR_part.ocr_interface import LPRNinitialize_model -# 使用CRNN进行车牌字符识别(可选)同时也要修改第395,396行 -# from CRNN_part.crnn_interface import LPRNmodel_predict -# from CRNN_part.crnn_interface import LPRNinitialize_model + +#选择使用哪个模块 +from LPRNET_part.lpr_interface import LPRNmodel_predict +from LPRNET_part.lpr_interface import LPRNinitialize_model + +#使用OCR +#from OCR_part.ocr_interface import LPRNmodel_predict +#from OCR_part.ocr_interface import LPRNinitialize_model +# 使用CRNN +#from CRNN_part.crnn_interface import LPRNmodel_predict +#from CRNN_part.crnn_interface import LPRNinitialize_model class CameraThread(QThread): """摄像头线程类"""