This commit is contained in:
spdis 2025-09-02 11:40:41 +08:00
parent 739cd1d914
commit 95aa6b6bba
3 changed files with 305 additions and 9 deletions

View File

@ -0,0 +1,284 @@
import torch
import torch.nn as nn
import numpy as np
import cv2
from torch.autograd import Variable
import os
# 字符集定义
CHARS = ['', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '',
'',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
'W', 'X', 'Y', 'Z', 'I', 'O', '-'
]
CHARS_DICT = {char: i for i, char in enumerate(CHARS)}
# 全局变量
lprnet_model = None
device = None
class small_basic_block(nn.Module):
def __init__(self, ch_in, ch_out):
super(small_basic_block, self).__init__()
self.block = nn.Sequential(
nn.Conv2d(ch_in, ch_out // 4, kernel_size=1),
nn.ReLU(),
nn.Conv2d(ch_out // 4, ch_out // 4, kernel_size=(3, 1), padding=(1, 0)),
nn.ReLU(),
nn.Conv2d(ch_out // 4, ch_out // 4, kernel_size=(1, 3), padding=(0, 1)),
nn.ReLU(),
nn.Conv2d(ch_out // 4, ch_out, kernel_size=1),
)
def forward(self, x):
return self.block(x)
class LPRNet(nn.Module):
def __init__(self, lpr_max_len, phase, class_num, dropout_rate):
super(LPRNet, self).__init__()
self.phase = phase
self.lpr_max_len = lpr_max_len
self.class_num = class_num
self.backbone = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1), # 0
nn.BatchNorm2d(num_features=64),
nn.ReLU(), # 2
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 1, 1)),
small_basic_block(ch_in=64, ch_out=128), # *** 4 ***
nn.BatchNorm2d(num_features=128),
nn.ReLU(), # 6
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(2, 1, 2)),
small_basic_block(ch_in=64, ch_out=256), # 8
nn.BatchNorm2d(num_features=256),
nn.ReLU(), # 10
small_basic_block(ch_in=256, ch_out=256), # *** 11 ***
nn.BatchNorm2d(num_features=256), # 12
nn.ReLU(),
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 1, 2)), # 14
nn.Dropout(dropout_rate),
nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1), # 16
nn.BatchNorm2d(num_features=256),
nn.ReLU(), # 18
nn.Dropout(dropout_rate),
nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(13, 1), stride=1), # 20
nn.BatchNorm2d(num_features=class_num),
nn.ReLU(), # *** 22 ***
)
self.container = nn.Sequential(
nn.Conv2d(in_channels=448+self.class_num, out_channels=self.class_num, kernel_size=(1, 1), stride=(1, 1)),
)
def forward(self, x):
keep_features = list()
for i, layer in enumerate(self.backbone.children()):
x = layer(x)
if i in [2, 6, 13, 22]: # [2, 4, 8, 11, 22]
keep_features.append(x)
global_context = list()
for i, f in enumerate(keep_features):
if i in [0, 1]:
f = nn.AvgPool2d(kernel_size=5, stride=5)(f)
if i in [2]:
f = nn.AvgPool2d(kernel_size=(4, 10), stride=(4, 2))(f)
f_pow = torch.pow(f, 2)
f_mean = torch.mean(f_pow)
f = torch.div(f, f_mean)
global_context.append(f)
x = torch.cat(global_context, 1)
x = self.container(x)
logits = torch.mean(x, dim=2)
return logits
def build_lprnet(lpr_max_len=8, phase=False, class_num=66, dropout_rate=0.5):
"""构建LPRNet模型"""
Net = LPRNet(lpr_max_len, phase, class_num, dropout_rate)
if phase == "train":
return Net.train()
else:
return Net.eval()
def preprocess_image(image_array, img_size=(94, 24)):
"""图像预处理"""
# 确保输入是numpy数组
if not isinstance(image_array, np.ndarray):
raise ValueError("输入必须是numpy数组")
# 调整图像尺寸
height, width = image_array.shape[:2]
if height != img_size[1] or width != img_size[0]:
image_array = cv2.resize(image_array, img_size)
# 归一化到[0,1]
image_array = image_array.astype(np.float32) / 255.0
# 转换为CHW格式
if len(image_array.shape) == 3:
image_array = np.transpose(image_array, (2, 0, 1))
# 添加batch维度
image_array = np.expand_dims(image_array, axis=0)
return image_array
def greedy_decode(prebs):
"""贪婪解码"""
preb_labels = list()
for i in range(prebs.shape[0]):
preb = prebs[i, :, :]
preb_label = list()
for j in range(preb.shape[1]):
preb_label.append(np.argmax(preb[:, j], axis=0))
no_repeat_blank_label = list()
pre_c = preb_label[0]
if pre_c != len(CHARS) - 1:
no_repeat_blank_label.append(pre_c)
for c in preb_label: # 去除重复标签和空白标签
if (pre_c == c) or (c == len(CHARS) - 1):
if c == len(CHARS) - 1:
pre_c = c
continue
no_repeat_blank_label.append(c)
pre_c = c
preb_labels.append(no_repeat_blank_label)
return preb_labels
def LPRNinitialize_model(model_path=None):
"""初始化LPRNet模型"""
global lprnet_model, device
try:
# 设置设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 构建模型
lprnet_model = build_lprnet(
lpr_max_len=8,
phase=False,
class_num=len(CHARS),
dropout_rate=0.5
)
# 加载预训练权重
if model_path is None:
model_path = os.path.join(os.path.dirname(__file__), "Final_LPRNet_model.pth")
if os.path.exists(model_path):
checkpoint = torch.load(model_path, map_location=device)
lprnet_model.load_state_dict(checkpoint)
print(f"成功加载预训练模型: {model_path}")
else:
print(f"警告: 未找到预训练模型文件 {model_path},使用随机初始化权重")
lprnet_model.to(device)
lprnet_model.eval()
print("LPRNet模型初始化完成")
# 统计模型参数
total_params = sum(p.numel() for p in lprnet_model.parameters())
print(f"LPRNet模型参数数量: {total_params:,}")
return True
except Exception as e:
print(f"LPRNet模型初始化失败: {e}")
import traceback
traceback.print_exc()
return False
def LPRNmodel_predict(image_array):
"""
LPRNet车牌号识别接口函数
参数:
image_array: numpy数组格式的车牌图像已经过矫正处理
返回:
list: 包含7个字符的列表代表车牌号的每个字符
例如: ['', 'A', '1', '2', '3', '4', '5']
"""
global lprnet_model, device
if lprnet_model is None:
print("LPRNet模型未初始化请先调用LPRNinitialize_model()")
return ['', '', '', '0', '0', '0', '0']
try:
# 预处理图像
processed_image = preprocess_image(image_array)
# 转换为tensor
input_tensor = torch.from_numpy(processed_image).float()
input_tensor = input_tensor.to(device)
# 模型推理
with torch.no_grad():
prebs = lprnet_model(input_tensor)
prebs = prebs.cpu().detach().numpy()
# 贪婪解码
preb_labels = greedy_decode(prebs)
if len(preb_labels) > 0 and len(preb_labels[0]) > 0:
# 将索引转换为字符
predicted_chars = [CHARS[idx] for idx in preb_labels[0] if idx < len(CHARS)]
print(f"LPRNet识别结果: {''.join(predicted_chars)}")
# 确保返回7个字符车牌标准长度
if len(predicted_chars) < 7:
# 如果识别结果少于7个字符用'0'补齐
predicted_chars.extend(['0'] * (7 - len(predicted_chars)))
elif len(predicted_chars) > 7:
# 如果识别结果多于7个字符截取前7个
predicted_chars = predicted_chars[:7]
return predicted_chars
else:
print("LPRNet识别结果为空")
return ['', '', '', '', '0', '0', '0']
except Exception as e:
print(f"LPRNet识别失败: {e}")
import traceback
traceback.print_exc()
return ['', '', '', '', '0', '0', '0']
# 为了保持与其他模块的一致性,提供一个处理器类
class LPRProcessor:
def __init__(self):
self.initialized = False
def initialize(self, model_path=None):
"""初始化模型"""
self.initialized = LPRNinitialize_model(model_path)
return self.initialized
def predict(self, image_array):
"""预测接口"""
if not self.initialized:
print("模型未初始化")
return ['', '', '', '', '0', '0', '0']
return LPRNmodel_predict(image_array)
# 创建全局处理器实例
_processor = LPRProcessor()
def get_lpr_processor():
"""获取LPR处理器实例"""
return _processor

View File

@ -14,8 +14,12 @@ License_plate_recognition/
│ └── yolo11s-pose42.pt # YOLO pose模型文件 │ └── yolo11s-pose42.pt # YOLO pose模型文件
├── OCR_part/ # OCR识别模块 ├── OCR_part/ # OCR识别模块
│ └── ocr_interface.py # OCR接口占位 │ └── ocr_interface.py # OCR接口占位
└── CRNN_part/ # CRNN识别模块 ├── CRNN_part/ # CRNN识别模块
└── crnn_interface.py # CRNN │ └── crnn_interface.py # CRNN接口
└── LPRNET_part/ # LPRNet识别模块
├── lpr_interface.py # LPRNet接口
├── Final_LPRNet_model.pth # 预训练模型文件
└── will_delete/ # 参考资料(可删除)
``` ```
## 功能特性 ## 功能特性
@ -44,8 +48,10 @@ License_plate_recognition/
### 5. 模块化设计 ### 5. 模块化设计
- yolopart负责车牌定位和矫正 - yolopart负责车牌定位和矫正
- OCR_part/CRNN_part负责车牌号识别接口已预留 - OCR_part基于PaddleOCR的车牌号识别模块
- 各模块独立,便于维护和扩展 - CRNN_part基于CRNN网络的车牌号识别模块
- LPRNET_part基于LPRNet网络的车牌号识别模块
- 各模块独立便于维护和扩展可通过修改main.py中的导入语句切换识别模块
## 安装和使用 ## 安装和使用

16
main.py
View File

@ -9,11 +9,17 @@ from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread
from PyQt5.QtGui import QImage, QPixmap, QFont, QPainter, QPen, QColor from PyQt5.QtGui import QImage, QPixmap, QFont, QPainter, QPen, QColor
import os import os
from yolopart.detector import LicensePlateYOLO from yolopart.detector import LicensePlateYOLO
from OCR_part.ocr_interface import LPRNmodel_predict
from OCR_part.ocr_interface import LPRNinitialize_model #选择使用哪个模块
# 使用CRNN进行车牌字符识别可选同时也要修改第395396行 from LPRNET_part.lpr_interface import LPRNmodel_predict
# from CRNN_part.crnn_interface import LPRNmodel_predict from LPRNET_part.lpr_interface import LPRNinitialize_model
# from CRNN_part.crnn_interface import LPRNinitialize_model
#使用OCR
#from OCR_part.ocr_interface import LPRNmodel_predict
#from OCR_part.ocr_interface import LPRNinitialize_model
# 使用CRNN
#from CRNN_part.crnn_interface import LPRNmodel_predict
#from CRNN_part.crnn_interface import LPRNinitialize_model
class CameraThread(QThread): class CameraThread(QThread):
"""摄像头线程类""" """摄像头线程类"""