更新接口

This commit is contained in:
spdis 2025-09-04 00:07:52 +08:00
parent 95aa6b6bba
commit 6c7f013a0c
5 changed files with 290 additions and 187 deletions

View File

@ -282,14 +282,15 @@ def LPRNmodel_predict(image_array):
image_array: numpy数组格式的车牌图像已经过矫正处理 image_array: numpy数组格式的车牌图像已经过矫正处理
返回: 返回:
list: 包含7个字符的列表代表车牌号的每个字符 list: 包含最多8个字符的列表代表车牌号的每个字符
例如: ['', 'A', '1', '2', '3', '4', '5'] 例如: ['', 'A', '1', '2', '3', '4', '5', ''] (蓝牌7位+占位符)
['', 'A', 'D', '1', '2', '3', '4', '5'] (绿牌8位)
""" """
global crnn_model, crnn_decoder, crnn_preprocessor, device global crnn_model, crnn_decoder, crnn_preprocessor, device
if crnn_model is None or crnn_decoder is None or crnn_preprocessor is None: if crnn_model is None or crnn_decoder is None or crnn_preprocessor is None:
print("CRNN模型未初始化请先调用initialize_crnn_model()") print("CRNN模型未初始化请先调用initialize_crnn_model()")
return ['', '', '', '0', '0', '0', '0'] return ['', '', '', '0', '0', '0', '0', '0']
try: try:
# 预处理图像 # 预处理图像
@ -314,13 +315,17 @@ def LPRNmodel_predict(image_array):
# 将字符串转换为字符列表 # 将字符串转换为字符列表
char_list = list(predicted_text) char_list = list(predicted_text)
# 确保返回7个字符车牌标准长度 # 确保返回至少7个字符最多8个字符
if len(char_list) < 7: if len(char_list) < 7:
# 如果识别结果少于7个字符用'0'补齐 # 如果识别结果少于7个字符用'0'补齐到7位
char_list.extend(['0'] * (7 - len(char_list))) char_list.extend(['0'] * (7 - len(char_list)))
elif len(char_list) > 7: elif len(char_list) > 8:
# 如果识别结果多于7个字符截取前7个 # 如果识别结果多于8个字符截取前8个
char_list = char_list[:7] char_list = char_list[:8]
# 如果是7位补齐到8位以保持接口一致性第8位用空字符或占位符
if len(char_list) == 7:
char_list.append('') # 添加空字符作为第8位占位符
return char_list return char_list
@ -328,4 +333,4 @@ def LPRNmodel_predict(image_array):
print(f"CRNN识别失败: {e}") print(f"CRNN识别失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return ['', '', '', '', '0', '0', '0'] return ['', '', '', '', '0', '0', '0', '0']

View File

@ -1,28 +1,27 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
import cv2 import cv2
from torch.autograd import Variable import numpy as np
import os import os
import sys
from torch.autograd import Variable
from PIL import Image
# 添加父目录到路径,以便导入模型和数据加载器
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# 字符集定义 # LPRNet字符集定义(与训练时保持一致)
CHARS = ['', '', '', '', '', '', '', '', '', '', CHARS = ['', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '', '',
'',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
'W', 'X', 'Y', 'Z', 'I', 'O', '-' 'W', 'X', 'Y', 'Z', 'I', 'O', '-']
]
CHARS_DICT = {char: i for i, char in enumerate(CHARS)} CHARS_DICT = {char: i for i, char in enumerate(CHARS)}
# 全局变量 # 简化的LPRNet模型定义
lprnet_model = None
device = None
class small_basic_block(nn.Module): class small_basic_block(nn.Module):
def __init__(self, ch_in, ch_out): def __init__(self, ch_in, ch_out):
super(small_basic_block, self).__init__() super(small_basic_block, self).__init__()
@ -35,7 +34,7 @@ class small_basic_block(nn.Module):
nn.ReLU(), nn.ReLU(),
nn.Conv2d(ch_out // 4, ch_out, kernel_size=1), nn.Conv2d(ch_out // 4, ch_out, kernel_size=1),
) )
def forward(self, x): def forward(self, x):
return self.block(x) return self.block(x)
@ -58,20 +57,20 @@ class LPRNet(nn.Module):
nn.BatchNorm2d(num_features=256), nn.BatchNorm2d(num_features=256),
nn.ReLU(), # 10 nn.ReLU(), # 10
small_basic_block(ch_in=256, ch_out=256), # *** 11 *** small_basic_block(ch_in=256, ch_out=256), # *** 11 ***
nn.BatchNorm2d(num_features=256), # 12 nn.BatchNorm2d(num_features=256),
nn.ReLU(), nn.ReLU(), # 13
nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 1, 2)), # 14 nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 1, 2)), # 14
nn.Dropout(dropout_rate), nn.Dropout(dropout_rate),
nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1), # 16 nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1), # 16
nn.BatchNorm2d(num_features=256), nn.BatchNorm2d(num_features=256),
nn.ReLU(), # 18 nn.ReLU(), # 18
nn.Dropout(dropout_rate), nn.Dropout(dropout_rate),
nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(13, 1), stride=1), # 20 nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(13, 1), stride=1), # 20
nn.BatchNorm2d(num_features=class_num), nn.BatchNorm2d(num_features=class_num),
nn.ReLU(), # *** 22 *** nn.ReLU(), # 22
) )
self.container = nn.Sequential( self.container = nn.Sequential(
nn.Conv2d(in_channels=448+self.class_num, out_channels=self.class_num, kernel_size=(1, 1), stride=(1, 1)), nn.Conv2d(in_channels=448+self.class_num, out_channels=self.class_num, kernel_size=(1,1), stride=(1,1)),
) )
def forward(self, x): def forward(self, x):
@ -98,101 +97,177 @@ class LPRNet(nn.Module):
return logits return logits
def build_lprnet(lpr_max_len=8, phase=False, class_num=66, dropout_rate=0.5): class LPRNetInference:
"""构建LPRNet模型""" def __init__(self, model_path=None, img_size=[94, 24], lpr_max_len=8, dropout_rate=0.5):
Net = LPRNet(lpr_max_len, phase, class_num, dropout_rate) """
初始化LPRNet推理类
Args:
model_path: 训练好的模型权重文件路径
img_size: 输入图像尺寸 [width, height]
lpr_max_len: 车牌最大长度
dropout_rate: dropout率
"""
self.img_size = img_size
self.lpr_max_len = lpr_max_len
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 设置默认模型路径
if model_path is None:
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, 'LPRNet__iteration_74000.pth')
# 初始化模型
self.model = LPRNet(lpr_max_len=lpr_max_len, phase=False, class_num=len(CHARS), dropout_rate=dropout_rate)
# 加载模型权重
if model_path and os.path.exists(model_path):
print(f"Loading LPRNet model from {model_path}")
try:
self.model.load_state_dict(torch.load(model_path, map_location=self.device))
print("LPRNet模型权重加载成功")
except Exception as e:
print(f"Warning: 加载模型权重失败: {e}. 使用随机权重.")
else:
print(f"Warning: 模型文件不存在或未指定: {model_path}. 使用随机权重.")
self.model.to(self.device)
self.model.eval()
print(f"LPRNet模型加载完成设备: {self.device}")
print(f"模型参数数量: {sum(p.numel() for p in self.model.parameters()):,}")
if phase == "train": def preprocess_image(self, image_array):
return Net.train() """
else: 预处理图像数组 - 使用与训练时相同的预处理方式
return Net.eval() Args:
image_array: numpy数组格式的图像 (H, W, C)
def preprocess_image(image_array, img_size=(94, 24)): Returns:
"""图像预处理""" preprocessed_image: 预处理后的图像tensor
# 确保输入是numpy数组 """
if not isinstance(image_array, np.ndarray): if image_array is None:
raise ValueError("输入必须是numpy数组") raise ValueError("Input image is None")
# 确保图像是numpy数组
if not isinstance(image_array, np.ndarray):
raise ValueError("Input must be numpy array")
# 检查图像维度
if len(image_array.shape) != 3:
raise ValueError(f"Expected 3D image array, got {len(image_array.shape)}D")
height, width, channels = image_array.shape
if channels != 3:
raise ValueError(f"Expected 3 channels, got {channels}")
# 调整图像尺寸到模型要求的尺寸
if height != self.img_size[1] or width != self.img_size[0]:
image_array = cv2.resize(image_array, tuple(self.img_size))
# 使用与训练时相同的预处理方式
image_array = image_array.astype('float32')
image_array -= 127.5
image_array *= 0.0078125
image_array = np.transpose(image_array, (2, 0, 1)) # HWC -> CHW
# 转换为tensor并添加batch维度
image_tensor = torch.from_numpy(image_array).unsqueeze(0)
return image_tensor
# 调整图像尺寸 def decode_prediction(self, logits):
height, width = image_array.shape[:2] """
if height != img_size[1] or width != img_size[0]: 解码模型预测结果 - 使用正确的CTC贪婪解码
image_array = cv2.resize(image_array, img_size) Args:
logits: 模型输出的logits [batch_size, num_classes, sequence_length]
# 归一化到[0,1] Returns:
image_array = image_array.astype(np.float32) / 255.0 predicted_text: 预测的车牌号码
"""
# 转换为CHW格式 # 转换为numpy进行处理
if len(image_array.shape) == 3: prebs = logits.cpu().detach().numpy()
image_array = np.transpose(image_array, (2, 0, 1)) preb = prebs[0, :, :] # 取第一个batch [num_classes, sequence_length]
# 添加batch维度 # 贪婪解码:对每个时间步选择最大概率的字符
image_array = np.expand_dims(image_array, axis=0) preb_label = []
for j in range(preb.shape[1]): # 遍历每个时间步
return image_array
def greedy_decode(prebs):
"""贪婪解码"""
preb_labels = list()
for i in range(prebs.shape[0]):
preb = prebs[i, :, :]
preb_label = list()
for j in range(preb.shape[1]):
preb_label.append(np.argmax(preb[:, j], axis=0)) preb_label.append(np.argmax(preb[:, j], axis=0))
no_repeat_blank_label = list() # CTC解码去除重复字符和空白字符
no_repeat_blank_label = []
pre_c = preb_label[0] pre_c = preb_label[0]
if pre_c != len(CHARS) - 1:
# 处理第一个字符
if pre_c != len(CHARS) - 1: # 不是空白字符
no_repeat_blank_label.append(pre_c) no_repeat_blank_label.append(pre_c)
for c in preb_label: # 去除重复标签和空白标签 # 处理后续字符
if (pre_c == c) or (c == len(CHARS) - 1): for c in preb_label:
if (pre_c == c) or (c == len(CHARS) - 1): # 重复字符或空白字符
if c == len(CHARS) - 1: if c == len(CHARS) - 1:
pre_c = c pre_c = c
continue continue
no_repeat_blank_label.append(c) no_repeat_blank_label.append(c)
pre_c = c pre_c = c
preb_labels.append(no_repeat_blank_label) # 转换为字符
decoded_chars = [CHARS[idx] for idx in no_repeat_blank_label]
return ''.join(decoded_chars)
return preb_labels def predict(self, image_array):
"""
预测单张图像的车牌号码
Args:
image_array: numpy数组格式的图像
Returns:
prediction: 预测的车牌号码
confidence: 预测置信度
"""
try:
# 预处理图像
image = self.preprocess_image(image_array)
if image is None:
return None, 0.0
image = image.to(self.device)
# 模型推理
with torch.no_grad():
logits = self.model(image)
# logits shape: [batch_size, class_num, sequence_length]
# 计算置信度使用softmax后的最大概率平均值
probs = torch.softmax(logits, dim=1)
max_probs = torch.max(probs, dim=1)[0]
confidence = torch.mean(max_probs).item()
# 解码预测结果
prediction = self.decode_prediction(logits)
return prediction, confidence
except Exception as e:
print(f"预测图像失败: {e}")
return None, 0.0
def LPRNinitialize_model(model_path=None): # 全局变量
"""初始化LPRNet模型""" lpr_model = None
global lprnet_model, device
def LPRNinitialize_model():
"""
初始化LPRNet模型
返回:
bool: 初始化是否成功
"""
global lpr_model
try: try:
# 设置设备 # 模型权重文件路径
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model_path = os.path.join(os.path.dirname(__file__), 'LPRNet__iteration_74000.pth')
print(f"使用设备: {device}")
# 构建模型 # 创建推理对象
lprnet_model = build_lprnet( lpr_model = LPRNetInference(model_path)
lpr_max_len=8,
phase=False,
class_num=len(CHARS),
dropout_rate=0.5
)
# 加载预训练权重
if model_path is None:
model_path = os.path.join(os.path.dirname(__file__), "Final_LPRNet_model.pth")
if os.path.exists(model_path):
checkpoint = torch.load(model_path, map_location=device)
lprnet_model.load_state_dict(checkpoint)
print(f"成功加载预训练模型: {model_path}")
else:
print(f"警告: 未找到预训练模型文件 {model_path},使用随机初始化权重")
lprnet_model.to(device)
lprnet_model.eval()
print("LPRNet模型初始化完成") print("LPRNet模型初始化完成")
# 统计模型参数
total_params = sum(p.numel() for p in lprnet_model.parameters())
print(f"LPRNet模型参数数量: {total_params:,}")
return True return True
except Exception as e: except Exception as e:
@ -209,76 +284,45 @@ def LPRNmodel_predict(image_array):
image_array: numpy数组格式的车牌图像已经过矫正处理 image_array: numpy数组格式的车牌图像已经过矫正处理
返回: 返回:
list: 包含7个字符的列表代表车牌号的每个字符 list: 包含最多8个字符的列表代表车牌号的每个字符
例如: ['', 'A', '1', '2', '3', '4', '5'] 例如: ['', 'A', '1', '2', '3', '4', '5'] (蓝牌7位)
['', 'A', 'D', '1', '2', '3', '4', '5'] (绿牌8位)
""" """
global lprnet_model, device global lpr_model
if lprnet_model is None: if lpr_model is None:
print("LPRNet模型未初始化请先调用LPRNinitialize_model()") print("LPRNet模型未初始化请先调用LPRNinitialize_model()")
return ['', '', '', '0', '0', '0', '0'] return ['', '', '', '0', '0', '0', '0', '0']
try: try:
# 预处理图像 # 预测车牌号
processed_image = preprocess_image(image_array) predicted_text, confidence = lpr_model.predict(image_array)
# 转换为tensor if predicted_text is None:
input_tensor = torch.from_numpy(processed_image).float() print("LPRNet识别失败")
input_tensor = input_tensor.to(device) return ['', '', '', '', '0', '0', '0', '0']
# 模型推理 print(f"LPRNet识别结果: {predicted_text}, 置信度: {confidence:.3f}")
with torch.no_grad():
prebs = lprnet_model(input_tensor)
prebs = prebs.cpu().detach().numpy()
# 贪婪解码 # 将字符串转换为字符列表
preb_labels = greedy_decode(prebs) char_list = list(predicted_text)
# 确保返回至少7个字符最多8个字符
if len(char_list) < 7:
# 如果识别结果少于7个字符用'0'补齐到7位
char_list.extend(['0'] * (7 - len(char_list)))
elif len(char_list) > 8:
# 如果识别结果多于8个字符截取前8个
char_list = char_list[:8]
# 如果是7位补齐到8位以保持接口一致性第8位用空字符或占位符
if len(char_list) == 7:
char_list.append('') # 添加空字符作为第8位占位符
return char_list
if len(preb_labels) > 0 and len(preb_labels[0]) > 0:
# 将索引转换为字符
predicted_chars = [CHARS[idx] for idx in preb_labels[0] if idx < len(CHARS)]
print(f"LPRNet识别结果: {''.join(predicted_chars)}")
# 确保返回7个字符车牌标准长度
if len(predicted_chars) < 7:
# 如果识别结果少于7个字符用'0'补齐
predicted_chars.extend(['0'] * (7 - len(predicted_chars)))
elif len(predicted_chars) > 7:
# 如果识别结果多于7个字符截取前7个
predicted_chars = predicted_chars[:7]
return predicted_chars
else:
print("LPRNet识别结果为空")
return ['', '', '', '', '0', '0', '0']
except Exception as e: except Exception as e:
print(f"LPRNet识别失败: {e}") print(f"LPRNet识别失败: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
return ['', '', '', '', '0', '0', '0'] return ['', '', '', '', '0', '0', '0', '0']
# 为了保持与其他模块的一致性,提供一个处理器类
class LPRProcessor:
def __init__(self):
self.initialized = False
def initialize(self, model_path=None):
"""初始化模型"""
self.initialized = LPRNinitialize_model(model_path)
return self.initialized
def predict(self, image_array):
"""预测接口"""
if not self.initialized:
print("模型未初始化")
return ['', '', '', '', '0', '0', '0']
return LPRNmodel_predict(image_array)
# 创建全局处理器实例
_processor = LPRProcessor()
def get_lpr_processor():
"""获取LPR处理器实例"""
return _processor

View File

@ -22,6 +22,17 @@ def LPRNinitialize_model():
return _processor return _processor
def LPRNmodel_predict(image_array): def LPRNmodel_predict(image_array):
"""
OCR车牌号识别接口函数
参数:
image_array: numpy数组格式的车牌图像已经过矫正处理
返回:
list: 包含最多8个字符的列表代表车牌号的每个字符
例如: ['', 'A', '1', '2', '3', '4', '5', ''] (蓝牌7位+占位符)
['', 'A', 'D', '1', '2', '3', '4', '5'] (绿牌8位)
"""
# 获取原始预测结果 # 获取原始预测结果
raw_result = _processor.predict(image_array) raw_result = _processor.predict(image_array)
@ -37,13 +48,19 @@ def LPRNmodel_predict(image_array):
# 转换为字符列表 # 转换为字符列表
char_list = list(filtered_str) char_list = list(filtered_str)
# 确保返回长度为7的列表 # 确保返回至少7个字符最多8个字符
if len(char_list) >= 7: if len(char_list) < 7:
# 如果长度大于等于7取前7个字符 # 如果识别结果少于7个字符用'0'补齐到7位
return char_list[:7] char_list.extend(['0'] * (7 - len(char_list)))
else: elif len(char_list) > 8:
# 如果长度小于7用空字符串补齐到7位 # 如果识别结果多于8个字符截取前8个
return char_list + [''] * (7 - len(char_list)) char_list = char_list[:8]
# 如果是7位补齐到8位以保持接口一致性第8位用空字符或占位符
if len(char_list) == 7:
char_list.append('') # 添加空字符作为第8位占位符
return char_list

View File

@ -15,11 +15,10 @@ License_plate_recognition/
├── OCR_part/ # OCR识别模块 ├── OCR_part/ # OCR识别模块
│ └── ocr_interface.py # OCR接口占位 │ └── ocr_interface.py # OCR接口占位
├── CRNN_part/ # CRNN识别模块 ├── CRNN_part/ # CRNN识别模块
│ └── crnn_interface.py # CRNN接口 │ └── crnn_interface.py # CRNN接口(占位)
└── LPRNET_part/ # LPRNet识别模块 └── LPRNET_part/ # LPRNet识别模块
├── lpr_interface.py # LPRNet接口 ├── lpr_interface.py # LPRNet接口已完成
├── Final_LPRNet_model.pth # 预训练模型文件 └── LPRNet__iteration_74000.pth # LPRNet模型权重文件
└── will_delete/ # 参考资料(可删除)
``` ```
## 功能特性 ## 功能特性
@ -39,19 +38,22 @@ License_plate_recognition/
- 将倾斜的车牌矫正为标准矩形 - 将倾斜的车牌矫正为标准矩形
- 输出标准尺寸的车牌图像供识别使用 - 输出标准尺寸的车牌图像供识别使用
### 4. PyQt界面 ### 4. 多种识别方案
- 支持OCR、CRNN和LPRNet三种车牌识别方法
- LPRNet模型准确率高达98%
- 模块化接口设计,便于切换不同识别算法
### 5. PyQt界面
- 左侧:实时摄像头画面显示 - 左侧:实时摄像头画面显示
- 右侧:检测结果展示区域 - 右侧:检测结果展示区域
- 顶部显示识别到的车牌数量 - 顶部显示识别到的车牌数量
- 每行显示:车牌类型、矫正后图像、车牌号 - 每行显示:车牌类型、矫正后图像、车牌号
- 美观的现代化界面设计 - 美观的现代化界面设计
### 5. 模块化设计 ### 6. 模块化设计
- yolopart负责车牌定位和矫正 - yolopart负责车牌定位和矫正
- OCR_part基于PaddleOCR的车牌号识别模块 - OCR_part/CRNN_part/LPRNET_part负责车牌号识别
- CRNN_part基于CRNN网络的车牌号识别模块 - 各模块独立,便于维护和扩展
- LPRNET_part基于LPRNet网络的车牌号识别模块
- 各模块独立便于维护和扩展可通过修改main.py中的导入语句切换识别模块
## 安装和使用 ## 安装和使用
@ -73,7 +75,21 @@ pip install -r requirements.txt
python main.py python main.py
``` ```
### 5. 使用说明 ### 5. 选择识别模块
`main.py` 中修改导入语句来选择不同的识别方案:
```python
# 使用LPRNet推荐准确率98%
from LPRNET_part.lpr_interface import LPRNmodel_predict, LPRNinitialize_model
# 使用OCR
from OCR_part.ocr_interface import LPRNmodel_predict, LPRNinitialize_model
# 使用CRNN
from CRNN_part.crnn_interface import LPRNmodel_predict, LPRNinitialize_model
```
### 6. 使用说明
1. 点击"启动摄像头"按钮开始检测 1. 点击"启动摄像头"按钮开始检测
2. 将车牌对准摄像头 2. 将车牌对准摄像头
3. 系统会自动检测车牌并显示: 3. 系统会自动检测车牌并显示:
@ -95,8 +111,9 @@ YOLO Pose模型输出包含
## 接口说明 ## 接口说明
### OCR/CRNN接口 ### 车牌识别接口
车牌号识别部分使用统一接口:
项目为OCR、CRNN和LPRNet识别模块提供了标准接口
```python ```python
# 接口函数名(导入所需模块,每个模块统一函数名) # 接口函数名(导入所需模块,每个模块统一函数名)
@ -108,7 +125,7 @@ LPRNinitialize_model()
# 预测主函数 # 预测主函数
from 对应模块 import LPRNmodel_predict from 对应模块 import LPRNmodel_predict
result = LPRNmodel_predict(corrected_image) # 返回7个字符的列表 result = LPRNmodel_predict(corrected_image) # 返回7个字符的列表
```
### 输入参数 ### 输入参数
- `corrected_image`numpy数组格式的矫正后车牌图像 - `corrected_image`numpy数组格式的矫正后车牌图像
@ -117,6 +134,14 @@ result = LPRNmodel_predict(corrected_image) # 返回7个字符的列表
- 长度为7的字符列表包含车牌号的每个字符 - 长度为7的字符列表包含车牌号的每个字符
- 例如:`['京', 'A', '1', '2', '3', '4', '5']` - 例如:`['京', 'A', '1', '2', '3', '4', '5']`
### LPRNet模块特性
- **高准确率**: 模型准确率高达98%
- **快速推理**: 基于深度学习的端到端识别
- **CTC解码**: 使用CTCConnectionist Temporal Classification解码算法
- **支持中文**: 完整支持中文省份简称和字母数字组合
- **模型权重**: 使用预训练的LPRNet__iteration_74000.pth权重文件
## 开发说明 ## 开发说明
### 添加新的识别算法 ### 添加新的识别算法

22
main.py
View File

@ -361,8 +361,8 @@ class MainWindow(QMainWindow):
# 矫正车牌图像 # 矫正车牌图像
corrected_image = self.correct_license_plate(detection) corrected_image = self.correct_license_plate(detection)
# 获取车牌号(占位) # 获取车牌号,传入车牌类型信息
plate_number = self.recognize_plate_number(corrected_image) plate_number = self.recognize_plate_number(corrected_image, detection['class_name'])
# 创建车牌显示组件 # 创建车牌显示组件
plate_widget = LicensePlateWidget( plate_widget = LicensePlateWidget(
@ -389,7 +389,7 @@ class MainWindow(QMainWindow):
detection['keypoints'] detection['keypoints']
) )
def recognize_plate_number(self, corrected_image): def recognize_plate_number(self, corrected_image, class_name):
"""识别车牌号""" """识别车牌号"""
if corrected_image is None: if corrected_image is None:
return "识别失败" return "识别失败"
@ -399,9 +399,21 @@ class MainWindow(QMainWindow):
# 函数名改成一样的了,所以不要修改这里了,想用哪个模块直接导入 # 函数名改成一样的了,所以不要修改这里了,想用哪个模块直接导入
result = LPRNmodel_predict(corrected_image) result = LPRNmodel_predict(corrected_image)
# 将字符列表转换为字符串 # 将字符列表转换为字符串支持8位车牌号
if isinstance(result, list) and len(result) >= 7: if isinstance(result, list) and len(result) >= 7:
return ''.join(result[:7]) # 根据车牌类型决定显示位数
if class_name == '绿牌' and len(result) >= 8:
# 绿牌显示8位过滤掉空字符占位符
plate_chars = [char for char in result[:8] if char != '']
# 如果过滤后确实有8位显示8位否则显示7位
if len(plate_chars) == 8:
return ''.join(plate_chars)
else:
return ''.join(plate_chars[:7])
else:
# 蓝牌或其他类型显示前7位过滤掉空字符
plate_chars = [char for char in result[:7] if char != '']
return ''.join(plate_chars)
else: else:
return "识别失败" return "识别失败"
except Exception as e: except Exception as e: