Compare commits

...

7 Commits

7 changed files with 872 additions and 439 deletions

View File

@ -5,8 +5,4 @@
<orderEntry type="jdk" jdkName="cnm" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

Binary file not shown.

View File

@ -1,328 +0,0 @@
import torch
import torch.nn as nn
import cv2
import numpy as np
import os
import sys
from torch.autograd import Variable
from PIL import Image
# 添加父目录到路径,以便导入模型和数据加载器
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# LPRNet字符集定义与训练时保持一致
# Character table for LPRNet. The class index produced by the network is an
# index into this list, so the order must match the table used for training.
# Layout: 31 Chinese province abbreviations, the digits 0-9, the letters A-Z
# without I and O, then 'I', 'O' and finally '-' (the last entry is treated as
# the CTC blank by the decoder in this file).
# The province entries had been reduced to empty strings, which made every
# province decode to '' and collapsed CHARS_DICT onto a single '' key.
# NOTE(review): province order restored from the reference LPRNet character
# table -- confirm it matches the table the checkpoint was trained with.
CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
         '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
         '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁', '新',
         '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
         'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
         'W', 'X', 'Y', 'Z', 'I', 'O', '-']
# Reverse lookup: character -> class index.
CHARS_DICT = {char: i for i, char in enumerate(CHARS)}
# 简化的LPRNet模型定义
class small_basic_block(nn.Module):
    """Four-convolution bottleneck used inside the LPRNet backbone.

    The block squeezes the input to ``ch_out // 4`` channels with a 1x1
    convolution, applies a 3x1 and then a 1x3 convolution at that width
    (each padded so the spatial size is unchanged), and expands to ``ch_out``
    channels with a final 1x1 convolution. A ReLU follows every convolution
    except the last one.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
    """

    def __init__(self, ch_in, ch_out):
        super(small_basic_block, self).__init__()
        squeezed = ch_out // 4
        stages = [
            nn.Conv2d(ch_in, squeezed, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(squeezed, squeezed, kernel_size=(3, 1), padding=(1, 0)),
            nn.ReLU(),
            nn.Conv2d(squeezed, squeezed, kernel_size=(1, 3), padding=(0, 1)),
            nn.ReLU(),
            nn.Conv2d(squeezed, ch_out, kernel_size=1),
        ]
        # The positions inside the Sequential (block.0 ... block.6) end up in
        # the state-dict keys, so the stage order must not change.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the four convolutions and return the result."""
        features = self.block(x)
        return features
class LPRNet(nn.Module):
    """Simplified LPRNet: a convolutional backbone plus a 1x1 fusion layer.

    Args:
        lpr_max_len: Maximum plate length; stored on the instance only.
        phase: Stored on the instance only; not read by this class.
        class_num: Number of output classes (size of the character table).
        dropout_rate: Probability used by both dropout layers.
    """

    def __init__(self, lpr_max_len, phase, class_num, dropout_rate):
        super(LPRNet, self).__init__()
        self.phase = phase
        self.lpr_max_len = lpr_max_len
        self.class_num = class_num

        # The index of each layer is significant: forward() taps the outputs
        # of layers 2, 6, 13 and 22, and the indices are part of the
        # state-dict keys.
        # The two strided MaxPool3d layers are followed by layers that take 64
        # input channels although the preceding BatchNorm has 128 / 256
        # features, i.e. the stride on the first pooled axis is relied on to
        # reduce the channel count.
        layers = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1),   # 0
            nn.BatchNorm2d(num_features=64),                                      # 1
            nn.ReLU(),                                                            # 2 (tapped)
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 1, 1)),                # 3
            small_basic_block(ch_in=64, ch_out=128),                              # 4
            nn.BatchNorm2d(num_features=128),                                     # 5
            nn.ReLU(),                                                            # 6 (tapped)
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(2, 1, 2)),                # 7
            small_basic_block(ch_in=64, ch_out=256),                              # 8
            nn.BatchNorm2d(num_features=256),                                     # 9
            nn.ReLU(),                                                            # 10
            small_basic_block(ch_in=256, ch_out=256),                             # 11
            nn.BatchNorm2d(num_features=256),                                     # 12
            nn.ReLU(),                                                            # 13 (tapped)
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 1, 2)),                # 14
            nn.Dropout(dropout_rate),                                             # 15
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1),           # 16
            nn.BatchNorm2d(num_features=256),                                     # 17
            nn.ReLU(),                                                            # 18
            nn.Dropout(dropout_rate),                                             # 19
            nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(13, 1), stride=1),   # 20
            nn.BatchNorm2d(num_features=class_num),                               # 21
            nn.ReLU(),                                                            # 22 (tapped)
        ]
        self.backbone = nn.Sequential(*layers)

        # 448 = 64 + 128 + 256: channels of the three intermediate taps that
        # are concatenated with the class_num-channel backbone output.
        fused_channels = 448 + self.class_num
        self.container = nn.Sequential(
            nn.Conv2d(in_channels=fused_channels, out_channels=self.class_num,
                      kernel_size=(1, 1), stride=(1, 1)),
        )

    def forward(self, x):
        """Return per-position class scores of shape (batch, class_num, width)."""
        tapped_layers = (2, 6, 13, 22)
        taps = []
        for index, layer in enumerate(self.backbone.children()):
            x = layer(x)
            if index in tapped_layers:
                taps.append(x)

        context = []
        for position, feature in enumerate(taps):
            # Bring the three earlier taps to the spatial size of the last one.
            if position in (0, 1):
                feature = nn.AvgPool2d(kernel_size=5, stride=5)(feature)
            if position == 2:
                feature = nn.AvgPool2d(kernel_size=(4, 10), stride=(4, 2))(feature)
            # Scale by the mean of the squared activations (one scalar taken
            # over the whole tensor).
            mean_square = torch.mean(torch.pow(feature, 2))
            context.append(torch.div(feature, mean_square))

        fused = self.container(torch.cat(context, 1))
        # Average over the height axis so one score vector per column remains.
        return torch.mean(fused, dim=2)
class LPRNetInference:
    """Single-image licence-plate recognition on top of ``LPRNet``.

    The constructor builds the network, tries to load trained weights and
    puts the model into eval mode. ``predict`` then turns one ``(H, W, 3)``
    numpy image into a plate string plus a confidence value.
    """

    def __init__(self, model_path=None, img_size=None, lpr_max_len=8, dropout_rate=0.5):
        """Build the model and load its weights.

        Args:
            model_path: Path of the trained weight file. ``None`` selects
                ``LPRNet__iteration_74000.pth`` next to this module.
            img_size: Network input size as ``[width, height]``. ``None``
                means ``[94, 24]``.
            lpr_max_len: Maximum plate length, forwarded to ``LPRNet``.
            dropout_rate: Dropout rate, forwarded to ``LPRNet``.

        A missing or unreadable weight file is not fatal: a warning is
        printed and the model keeps its random initial weights.
        """
        # A fresh list per instance; the former ``[94, 24]`` default was one
        # list object shared by every instance created without ``img_size``.
        self.img_size = [94, 24] if img_size is None else img_size
        self.lpr_max_len = lpr_max_len
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Default weight file lives next to this module.
        if model_path is None:
            current_dir = os.path.dirname(os.path.abspath(__file__))
            model_path = os.path.join(current_dir, 'LPRNet__iteration_74000.pth')

        self.model = LPRNet(lpr_max_len=lpr_max_len, phase=False, class_num=len(CHARS), dropout_rate=dropout_rate)

        if model_path and os.path.exists(model_path):
            print(f"Loading LPRNet model from {model_path}")
            try:
                # NOTE(review): torch.load unpickles the file, which can
                # execute arbitrary code -- only load trusted checkpoints.
                self.model.load_state_dict(torch.load(model_path, map_location=self.device))
                print("LPRNet模型权重加载成功")
            except Exception as e:
                # Deliberate best effort: keep running with random weights.
                print(f"Warning: 加载模型权重失败: {e}. 使用随机权重.")
        else:
            print(f"Warning: 模型文件不存在或未指定: {model_path}. 使用随机权重.")

        self.model.to(self.device)
        self.model.eval()
        print(f"LPRNet模型加载完成设备: {self.device}")
        print(f"模型参数数量: {sum(p.numel() for p in self.model.parameters()):,}")

    def preprocess_image(self, image_array):
        """Convert an image array into a network input tensor.

        Args:
            image_array: numpy array of shape ``(H, W, 3)``.

        Returns:
            A float32 tensor of shape ``(1, 3, height, width)`` holding
            ``(pixel - 127.5) * 0.0078125``.

        Raises:
            ValueError: If the input is ``None``, not a numpy array, not
                three-dimensional, or does not have three channels.
        """
        if image_array is None:
            raise ValueError("Input image is None")
        if not isinstance(image_array, np.ndarray):
            raise ValueError("Input must be numpy array")
        if len(image_array.shape) != 3:
            raise ValueError(f"Expected 3D image array, got {len(image_array.shape)}D")

        height, width, channels = image_array.shape
        if channels != 3:
            raise ValueError(f"Expected 3 channels, got {channels}")

        # Resize only when needed; img_size is [width, height].
        if height != self.img_size[1] or width != self.img_size[0]:
            image_array = cv2.resize(image_array, tuple(self.img_size))

        # astype() copies, so the in-place arithmetic never touches the
        # caller's array.
        image_array = image_array.astype('float32')
        image_array -= 127.5
        image_array *= 0.0078125
        image_array = np.transpose(image_array, (2, 0, 1))  # HWC -> CHW

        return torch.from_numpy(image_array).unsqueeze(0)

    def decode_prediction(self, logits):
        """Greedy CTC decoding of the first batch element.

        Args:
            logits: Model output of shape
                ``(batch_size, num_classes, sequence_length)``.

        Returns:
            The decoded string: per position the highest-scoring class is
            taken, consecutive repeats are merged and blanks (the last entry
            of ``CHARS``) are dropped. An empty sequence decodes to ``''``.
        """
        scores = logits.cpu().detach().numpy()[0]  # (num_classes, sequence_length)
        if scores.shape[1] == 0:
            # Previously this raised IndexError on the first label lookup.
            return ''

        blank = len(CHARS) - 1
        best_per_step = np.argmax(scores, axis=0)

        decoded = []
        previous = blank
        for label in best_per_step:
            # Emit a label when it is not blank and differs from the label of
            # the preceding step; a blank in between lets a character repeat.
            if label != blank and label != previous:
                decoded.append(CHARS[label])
            previous = label
        return ''.join(decoded)

    def predict(self, image_array):
        """Recognise the plate shown in one image.

        Args:
            image_array: numpy array of shape ``(H, W, 3)``.

        Returns:
            ``(prediction, confidence)`` where ``confidence`` is the mean,
            over all sequence positions, of the highest softmax probability.
            On any failure the error is printed and ``(None, 0.0)`` is
            returned.
        """
        try:
            # preprocess_image raises on bad input, so no None check is needed.
            image = self.preprocess_image(image_array).to(self.device)

            with torch.no_grad():
                logits = self.model(image)  # (batch, class_num, sequence_length)
                probs = torch.softmax(logits, dim=1)
                max_probs = torch.max(probs, dim=1)[0]
                confidence = torch.mean(max_probs).item()

            prediction = self.decode_prediction(logits)
            return prediction, confidence
        except Exception as e:
            print(f"预测图像失败: {e}")
            return None, 0.0
# 全局变量
lpr_model = None
def LPRNinitialize_model():
    """Create the module-level LPRNet inference object.

    The weight file ``LPRNet__iteration_74000.pth`` is looked up in the
    directory of this module and the resulting ``LPRNetInference`` instance
    is stored in the global ``lpr_model``.

    Returns:
        bool: ``True`` when the object was created, ``False`` when creating
        it raised (the traceback is printed).
    """
    global lpr_model
    try:
        weights_dir = os.path.dirname(__file__)
        weights_file = os.path.join(weights_dir, 'LPRNet__iteration_74000.pth')
        lpr_model = LPRNetInference(weights_file)
        print("LPRNet模型初始化完成")
    except Exception as e:
        print(f"LPRNet模型初始化失败: {e}")
        import traceback
        traceback.print_exc()
        return False
    return True
def LPRNmodel_predict(image_array):
    """Recognise a plate image and return its characters as a list.

    Args:
        image_array: numpy image of the (already rectified) plate.

    Returns:
        list: Eight entries, one per plate position. A result shorter than
        seven characters is padded with ``'0'`` up to seven, a result longer
        than eight is cut to eight, and a seven-character result gets ``''``
        appended as the eighth entry. A fixed placeholder list is returned
        when the model has not been initialised or recognition fails.
    """
    global lpr_model

    if lpr_model is None:
        print("LPRNet模型未初始化请先调用LPRNinitialize_model()")
        return ['', '', '', '0', '0', '0', '0', '0']

    try:
        predicted_text, confidence = lpr_model.predict(image_array)
        if predicted_text is None:
            print("LPRNet识别失败")
            return ['', '', '', '', '0', '0', '0', '0']

        print(f"LPRNet识别结果: {predicted_text}, 置信度: {confidence:.3f}")

        chars = list(predicted_text)
        shortfall = 7 - len(chars)
        if shortfall > 0:
            # Too short: pad with '0' up to seven characters.
            chars += ['0'] * shortfall
        elif len(chars) > 8:
            # Too long: keep the first eight characters.
            chars = chars[:8]

        # Seven-character plates get an empty eighth slot so that every
        # caller sees the same list length.
        if len(chars) == 7:
            chars.append('')
        return chars
    except Exception as e:
        print(f"LPRNet识别失败: {e}")
        import traceback
        traceback.print_exc()
        return ['', '', '', '', '0', '0', '0', '0']

View File

@ -5,6 +5,18 @@ import cv2
class OCRProcessor:
def __init__(self):
self.model = TextRecognition(model_name="PP-OCRv5_server_rec")
# 定义允许的字符集合(不包含空白字符)
self.allowed_chars = [
# 中文省份简称
'', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '',
# 字母 A-Z
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
# 数字 0-9
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'
]
print("OCR模型初始化完成占位")
def predict(self, image_array):
@ -15,6 +27,14 @@ class OCRProcessor:
placeholder_result = results.split(',')
return placeholder_result
def filter_allowed_chars(self, text):
"""只保留允许的字符"""
filtered_text = ""
for char in text:
if char in self.allowed_chars:
filtered_text += char
return filtered_text
# 保留原有函数接口
_processor = OCRProcessor()
@ -42,8 +62,12 @@ def LPRNmodel_predict(image_array):
else:
result_str = str(raw_result)
# 过滤掉'·'字符
# 过滤掉'·'和'-'字符
filtered_str = result_str.replace('·', '')
filtered_str = filtered_str.replace('-', '')
# 只保留允许的字符
filtered_str = _processor.filter_allowed_chars(filtered_str)
# 转换为字符列表
char_list = list(filtered_str)

Binary file not shown.

773
main.py
View File

@ -1,19 +1,17 @@
import sys
import os
import cv2
import numpy as np
from PyQt5.QtWidgets import (
QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
QLabel, QPushButton, QScrollArea, QFrame, QSizePolicy
)
from collections import defaultdict, deque
from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, \
QFileDialog, QFrame, QScrollArea, QComboBox
from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread
from PyQt5.QtGui import QImage, QPixmap, QFont, QPainter, QPen, QColor
import os
from yolopart.detector import LicensePlateYOLO
#选择使用哪个模块
from LPRNET_part.lpr_interface import LPRNmodel_predict
from LPRNET_part.lpr_interface import LPRNinitialize_model
# from LPRNET_part.lpr_interface import LPRNmodel_predict
# from LPRNET_part.lpr_interface import LPRNinitialize_model
#使用OCR
# from OCR_part.ocr_interface import LPRNmodel_predict
# from OCR_part.ocr_interface import LPRNinitialize_model
@ -21,6 +19,190 @@ from LPRNET_part.lpr_interface import LPRNinitialize_model
# from CRNN_part.crnn_interface import LPRNmodel_predict
# from CRNN_part.crnn_interface import LPRNinitialize_model
class PlateStabilizer:
    """Stabilises per-plate recognition results across consecutive frames.

    Detections are assigned to plate IDs by the distance between bounding-box
    centres; each ID keeps a bounded history of recognised texts, and the
    reported text is chosen by a confidence-weighted vote over that history.
    """

    def __init__(self, history_size=10, confidence_threshold=0.6, stability_frames=5):
        """
        Args:
            history_size: Maximum number of history entries kept per plate.
            confidence_threshold: Minimum share of the weighted votes the
                best text needs to be reported as stable.
            stability_frames: Stored on the instance; not read by this class.
        """
        self.history_size = history_size
        self.confidence_threshold = confidence_threshold
        self.stability_frames = stability_frames
        # plate_id -> deque of history entries (dicts), newest last.
        self.plate_histories = defaultdict(lambda: deque(maxlen=history_size))
        # Never written by this class; only cleaned in clear_old_plates().
        self.stable_results = {}
        # Source of the numeric suffix of newly created plate IDs.
        self.plate_id_counter = 0
        # plate_id -> last bounding box assigned to that plate.
        self.plate_positions = {}

    def calculate_plate_distance(self, pos1, pos2):
        """Return the distance between the centres of two boxes.

        Boxes are indexed as ``[x1, y1, x2, y2]``. ``inf`` is returned when
        either box is ``None``.
        """
        if pos1 is None or pos2 is None:
            return float('inf')
        center1 = ((pos1[0] + pos1[2]) / 2, (pos1[1] + pos1[3]) / 2)
        center2 = ((pos2[0] + pos2[2]) / 2, (pos2[1] + pos2[3]) / 2)
        return np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)

    def match_plates_to_history(self, current_detections):
        """Assign every detection of the current frame to a plate ID.

        Each detection takes the closest known plate that is nearer than 100
        (in bounding-box units) and has not been taken in this frame;
        otherwise a new ID ``plate_<n>`` is created.

        Returns:
            dict: plate_id -> detection, in the order of ``current_detections``.
        """
        matched_plates = {}
        used_ids = set()

        for detection in current_detections:
            bbox = detection.get('bbox', [0, 0, 0, 0])
            best_match_id = None
            min_distance = float('inf')

            for plate_id, last_pos in self.plate_positions.items():
                if plate_id in used_ids:
                    continue
                distance = self.calculate_plate_distance(bbox, last_pos)
                if distance < min_distance and distance < 100:
                    min_distance = distance
                    best_match_id = plate_id

            if best_match_id is None:
                best_match_id = f"plate_{self.plate_id_counter}"
                self.plate_id_counter += 1

            matched_plates[best_match_id] = detection
            # Reserve the ID for this frame. Newly created IDs must be
            # reserved as well, otherwise a second nearby detection in the
            # same frame would claim the ID and overwrite this detection.
            used_ids.add(best_match_id)
            self.plate_positions[best_match_id] = bbox

        return matched_plates

    def calculate_confidence(self, plate_text, detection_quality=1.0):
        """Return a heuristic confidence in ``[0, 1]`` for a recognised text.

        Starts at 0.5, adds 0.2 for a length of 7 or 8 and 0.2 when the text
        contains a Chinese character, an ASCII letter and a digit; the sum is
        multiplied by ``detection_quality`` and capped at 1.0. Empty text and
        the failure marker score 0.0.
        """
        if not plate_text or plate_text == "识别失败":
            return 0.0

        base_confidence = 0.5
        if 7 <= len(plate_text) <= 8:
            base_confidence += 0.2

        has_chinese = any('\u4e00' <= char <= '\u9fff' for char in plate_text)
        # str.isalpha() is also true for Chinese characters, so the letter
        # test has to be restricted to ASCII letters to mean anything.
        has_letter = any('A' <= char <= 'Z' or 'a' <= char <= 'z' for char in plate_text)
        has_digit = any(char.isdigit() for char in plate_text)
        if has_chinese and has_letter and has_digit:
            base_confidence += 0.2

        confidence = base_confidence * detection_quality
        return min(confidence, 1.0)

    def update_and_get_stable_result(self, current_detections, corrected_images, plate_texts):
        """Record the current frame and return the stabilised results.

        Args:
            current_detections: Detection dicts of the current frame.
            corrected_images: Rectified plate images, parallel to the detections.
            plate_texts: Recognised texts, parallel to the detections.

        Returns:
            list of dicts with the keys ``id``, ``class_name``,
            ``corrected_image``, ``plate_number`` and ``detection``; plates
            whose stable text is empty or the failure marker are left out.
        """
        if not current_detections:
            return []

        matched_plates = self.match_plates_to_history(current_detections)

        # Look detections up by identity. list.index() compares dicts by
        # value, which raises for numpy-valued entries and picks the wrong
        # position when two detections compare equal.
        index_of = {}
        for position, detection in enumerate(current_detections):
            index_of.setdefault(id(detection), position)

        stable_results = []
        for plate_id, detection in matched_plates.items():
            detection_idx = index_of[id(detection)]
            corrected_image = corrected_images[detection_idx] if detection_idx < len(corrected_images) else None
            plate_text = plate_texts[detection_idx] if detection_idx < len(plate_texts) else "识别失败"

            self.plate_histories[plate_id].append({
                'text': plate_text,
                'confidence': self.calculate_confidence(plate_text),
                'detection': detection,
                'corrected_image': corrected_image,
            })

            stable_text = self.get_stable_text(plate_id)
            if stable_text and stable_text != "识别失败":
                stable_results.append({
                    'id': plate_id,
                    'class_name': detection['class_name'],
                    'corrected_image': corrected_image,
                    'plate_number': stable_text,
                    'detection': detection,
                })

        return stable_results

    def get_stable_text(self, plate_id):
        """Return the stabilised text for ``plate_id``.

        With fewer than three history entries the newest text is returned.
        Otherwise every entry with confidence above 0.3 votes for its text
        with its confidence; the winner is returned when its share of the
        votes reaches ``confidence_threshold``. Failing that, the newest text
        among the last five entries with confidence above 0.5 is used, and
        finally the newest text of the history.
        """
        # .get() so that querying an unknown ID does not create a history.
        history = self.plate_histories.get(plate_id)
        if not history:
            return "识别失败"
        if len(history) < 3:
            return history[-1]['text']

        text_votes = defaultdict(float)
        total_confidence = 0
        for entry in history:
            text = entry['text']
            confidence = entry['confidence']
            if text != "识别失败" and confidence > 0.3:
                text_votes[text] += confidence
                total_confidence += confidence

        if not text_votes:
            return "识别失败"

        best_text, best_votes = max(text_votes.items(), key=lambda item: item[1])
        vote_ratio = best_votes / total_confidence if total_confidence > 0 else 0
        if vote_ratio >= self.confidence_threshold:
            return best_text

        recent_high_conf = [entry for entry in list(history)[-5:]
                            if entry['confidence'] > 0.5 and entry['text'] != "识别失败"]
        if recent_high_conf:
            return recent_high_conf[-1]['text']
        return history[-1]['text']

    def clear_old_plates(self, current_plate_ids):
        """Forget every plate whose ID is not in ``current_plate_ids``."""
        known_ids = set(self.plate_histories) | set(self.plate_positions)
        for plate_id in known_ids:
            if plate_id in current_plate_ids:
                continue
            self.plate_histories.pop(plate_id, None)
            self.plate_positions.pop(plate_id, None)
            self.stable_results.pop(plate_id, None)
class CameraThread(QThread):
"""摄像头线程类"""
frame_ready = pyqtSignal(np.ndarray)
@ -56,6 +238,60 @@ class CameraThread(QThread):
self.frame_ready.emit(frame)
self.msleep(30) # 约30fps
class VideoThread(QThread):
    """Worker thread that reads frames from a video file.

    Emits ``frame_ready`` with every decoded frame and ``video_finished``
    once when a read fails (normally the end of the file).
    """
    frame_ready = pyqtSignal(np.ndarray)
    video_finished = pyqtSignal()

    def __init__(self):
        super().__init__()
        self.video_path = None
        self.cap = None
        self.running = False   # loop flag read by run()
        self.paused = False    # while True, run() skips reading frames

    def load_video(self, video_path):
        """Open ``video_path``; return ``True`` when the capture is open."""
        # Stop a previous playback and free its capture first; the old
        # capture object used to be dropped without being released.
        if self.cap is not None:
            self.stop_video()
        self.video_path = video_path
        self.cap = cv2.VideoCapture(video_path)
        return self.cap.isOpened()

    def start_video(self):
        """Start the reader thread; return ``False`` when no video is open."""
        if self.cap and self.cap.isOpened():
            self.running = True
            self.paused = False
            self.start()
            return True
        return False

    def pause_video(self):
        """Toggle pause and return the new paused state."""
        self.paused = not self.paused
        return self.paused

    def stop_video(self):
        """Stop the reader thread and release the capture.

        Must be called from another thread than the reader thread itself.
        """
        self.running = False
        self.quit()
        # Wait for run() to leave its loop before the capture is released;
        # releasing first could pull it away from under a read() in progress.
        self.wait()
        if self.cap:
            self.cap.release()
            self.cap = None

    def run(self):
        """Thread body: read and emit frames until stopped or out of frames."""
        while self.running:
            if not self.paused and self.cap and self.cap.isOpened():
                ret, frame = self.cap.read()
                if ret:
                    self.frame_ready.emit(frame)
                else:
                    # No more frames: report the end of the video and leave.
                    self.video_finished.emit()
                    self.running = False
                    break
            self.msleep(30)  # fixed delay of 30 ms per iteration
class LicensePlateWidget(QWidget):
"""单个车牌结果显示组件"""
@ -67,6 +303,7 @@ class LicensePlateWidget(QWidget):
def init_ui(self, class_name, corrected_image, plate_number):
layout = QHBoxLayout()
layout.setContentsMargins(10, 5, 10, 5)
layout.setSpacing(8) # 设置组件间距
# 车牌类型标签
type_label = QLabel(class_name)
@ -104,7 +341,6 @@ class LicensePlateWidget(QWidget):
# 矫正后的车牌图像
image_label = QLabel()
image_label.setFixedSize(120, 40)
image_label.setStyleSheet("border: 1px solid #ddd; background-color: white;")
if corrected_image is not None:
@ -118,16 +354,44 @@ class LicensePlateWidget(QWidget):
q_image = QImage(corrected_image.data, w, h, bytes_per_line, QImage.Format_Grayscale8)
pixmap = QPixmap.fromImage(q_image)
scaled_pixmap = pixmap.scaled(120, 40, Qt.KeepAspectRatio, Qt.SmoothTransformation)
# 动态计算显示尺寸,保持车牌的宽高比
original_width = pixmap.width()
original_height = pixmap.height()
# 设置最大显示尺寸限制
max_width = 150
max_height = 60
# 计算缩放比例,确保图像完整显示
width_ratio = max_width / original_width if original_width > 0 else 1
height_ratio = max_height / original_height if original_height > 0 else 1
scale_ratio = min(width_ratio, height_ratio, 1.0) # 不放大,只缩小
# 计算实际显示尺寸
display_width = int(original_width * scale_ratio)
display_height = int(original_height * scale_ratio)
# 确保最小显示尺寸
display_width = max(display_width, 80)
display_height = max(display_height, 25)
# 设置标签尺寸并缩放图像
image_label.setFixedSize(display_width, display_height)
scaled_pixmap = pixmap.scaled(display_width, display_height, Qt.KeepAspectRatio, Qt.SmoothTransformation)
image_label.setPixmap(scaled_pixmap)
image_label.setAlignment(Qt.AlignCenter)
else:
# 当没有图像时,设置固定尺寸显示提示信息
image_label.setFixedSize(120, 40)
image_label.setText("车牌未完全\n进入摄像头")
image_label.setAlignment(Qt.AlignCenter)
image_label.setStyleSheet("border: 1px solid #ddd; background-color: #f5f5f5; color: #666;")
# 车牌号标签
# 车牌号标签 - 使用自适应宽度
number_label = QLabel(plate_number)
number_label.setFixedWidth(150)
number_label.setMinimumWidth(120) # 设置最小宽度
number_label.setMaximumWidth(200) # 设置最大宽度
number_label.setAlignment(Qt.AlignCenter)
number_label.setStyleSheet(
"QLabel { "
@ -139,6 +403,11 @@ class LicensePlateWidget(QWidget):
"font-weight: bold; "
"}"
)
# 根据文本长度调整宽度
font_metrics = number_label.fontMetrics()
text_width = font_metrics.boundingRect(plate_number).width()
optimal_width = max(120, min(200, text_width + 20)) # 加20像素的边距
number_label.setFixedWidth(optimal_width)
layout.addWidget(type_label)
layout.addWidget(image_label)
@ -146,6 +415,9 @@ class LicensePlateWidget(QWidget):
layout.addStretch()
self.setLayout(layout)
# 调整整体组件的最小高度以适应动态图像尺寸
min_height = max(60, image_label.height() + 20) # 至少60像素高度
self.setMinimumHeight(min_height)
self.setStyleSheet(
"QWidget { "
"background-color: white; "
@ -162,15 +434,28 @@ class MainWindow(QMainWindow):
super().__init__()
self.detector = None
self.camera_thread = None
self.video_thread = None
self.current_frame = None
self.detections = []
self.current_mode = "camera" # 当前模式camera, video, image
self.is_processing = False # 标志位,表示是否正在处理识别任务
self.last_plate_results = [] # 存储上一次的车牌识别结果
self.current_recognition_method = "CRNN" # 当前识别方法
# 添加车牌稳定器
self.plate_stabilizer = PlateStabilizer(
history_size=15, # 保存15帧历史
confidence_threshold=0.7, # 70%置信度阈值
stability_frames=5 # 需要5帧稳定
)
self.init_ui()
self.init_detector()
self.init_camera()
self.init_video()
# 初始化OCR/CRNN模型函数名改成一样的了所以不要修改这里了想用哪个模块直接导入
LPRNinitialize_model()
# 初始化默认识别方法CRNN的模型
self.change_recognition_method(self.current_recognition_method)
def init_ui(self):
@ -197,7 +482,7 @@ class MainWindow(QMainWindow):
self.camera_label.setStyleSheet("QLabel { background-color: black; border: 1px solid #ccc; }")
self.camera_label.setAlignment(Qt.AlignCenter)
self.camera_label.setText("摄像头未启动")
self.camera_label.setScaledContents(True)
self.camera_label.setScaledContents(False)
# 控制按钮
button_layout = QHBoxLayout()
@ -207,8 +492,26 @@ class MainWindow(QMainWindow):
self.stop_button.clicked.connect(self.stop_camera)
self.stop_button.setEnabled(False)
# 视频控制按钮
self.open_video_button = QPushButton("打开视频")
self.stop_video_button = QPushButton("停止视频")
self.pause_video_button = QPushButton("暂停视频")
self.open_video_button.clicked.connect(self.open_video_file)
self.stop_video_button.clicked.connect(self.stop_video)
self.pause_video_button.clicked.connect(self.pause_video)
self.stop_video_button.setEnabled(False)
self.pause_video_button.setEnabled(False)
# 图片控制按钮
self.open_image_button = QPushButton("打开图片")
self.open_image_button.clicked.connect(self.open_image_file)
button_layout.addWidget(self.start_button)
button_layout.addWidget(self.stop_button)
button_layout.addWidget(self.open_video_button)
button_layout.addWidget(self.stop_video_button)
button_layout.addWidget(self.pause_video_button)
button_layout.addWidget(self.open_image_button)
button_layout.addStretch()
left_layout.addWidget(self.camera_label)
@ -217,7 +520,7 @@ class MainWindow(QMainWindow):
# 右侧结果显示区域
right_frame = QFrame()
right_frame.setFrameStyle(QFrame.StyledPanel)
right_frame.setFixedWidth(400)
right_frame.setFixedWidth(460)
right_frame.setStyleSheet("QFrame { background-color: #fafafa; border: 2px solid #ddd; }")
right_layout = QVBoxLayout(right_frame)
@ -227,6 +530,20 @@ class MainWindow(QMainWindow):
title_label.setFont(QFont("Arial", 16, QFont.Bold))
title_label.setStyleSheet("QLabel { color: #333; padding: 10px; }")
# 识别方法选择
method_layout = QHBoxLayout()
method_label = QLabel("识别方法:")
method_label.setFont(QFont("Arial", 10))
self.method_combo = QComboBox()
self.method_combo.addItems(["CRNN", "LightCRNN", "OCR"])
self.method_combo.setCurrentText("CRNN") # 默认选择CRNN
self.method_combo.currentTextChanged.connect(self.change_recognition_method)
method_layout.addWidget(method_label)
method_layout.addWidget(self.method_combo)
method_layout.addStretch()
# 车牌数量显示
self.count_label = QLabel("识别到的车牌数量: 0")
self.count_label.setAlignment(Qt.AlignCenter)
@ -253,9 +570,17 @@ class MainWindow(QMainWindow):
scroll_area.setWidget(self.results_widget)
# 当前识别任务显示
self.current_method_label = QLabel("当前识别方法: CRNN")
self.current_method_label.setAlignment(Qt.AlignRight)
self.current_method_label.setFont(QFont("Arial", 9))
self.current_method_label.setStyleSheet("QLabel { color: #666; padding: 5px; }")
right_layout.addWidget(title_label)
right_layout.addLayout(method_layout)
right_layout.addWidget(self.count_label)
right_layout.addWidget(scroll_area)
right_layout.addWidget(self.current_method_label)
# 添加到主布局
main_layout.addWidget(left_frame, 2)
@ -291,14 +616,50 @@ class MainWindow(QMainWindow):
model_path = os.path.join(os.path.dirname(__file__), "yolopart", "yolo11s-pose42.pt")
self.detector = LicensePlateYOLO(model_path)
def reset_processing_state(self):
    """Reset recognition state and empty the result panel.

    Clears the busy flag, the current frame and the detections, replaces the
    plate stabiliser with a fresh one, sets the counter label back to zero
    and removes every widget from the results layout.
    """
    self.is_processing = False
    self.current_frame = None
    self.detections = []
    self.last_plate_results = []

    # A new stabiliser drops all history gathered for the previous source.
    self.plate_stabilizer = PlateStabilizer(
        history_size=15,
        confidence_threshold=0.7,
        stability_frames=5
    )

    self.count_label.setText("识别到的车牌数量: 0")
    # Walk the layout backwards so removals do not shift unvisited indices.
    for index in range(self.results_layout.count() - 1, -1, -1):
        widget = self.results_layout.itemAt(index).widget()
        if widget:
            widget.setParent(None)

    print("处理状态已重置,界面已清理")
def init_camera(self):
    """Create the camera thread and route its frames to process_frame."""
    thread = CameraThread()
    thread.frame_ready.connect(self.process_frame)
    self.camera_thread = thread
def init_video(self):
    """Create the video thread and connect its two signals.

    Frames go to process_frame; the end-of-video notification goes to
    on_video_finished.
    """
    thread = VideoThread()
    thread.frame_ready.connect(self.process_frame)
    thread.video_finished.connect(self.on_video_finished)
    self.video_thread = thread
def start_camera(self):
"""启动摄像头"""
# 重置处理状态和清理界面
self.reset_processing_state()
if self.camera_thread.start_camera():
self.current_mode = "camera"
self.start_button.setEnabled(False)
self.stop_button.setEnabled(True)
self.camera_label.setText("摄像头启动中...")
@ -311,68 +672,371 @@ class MainWindow(QMainWindow):
self.start_button.setEnabled(True)
self.stop_button.setEnabled(False)
self.camera_label.setText("摄像头已停止")
# 只在摄像头模式下清除标签内容
if self.current_mode == "camera":
self.camera_label.clear()
def on_video_finished(self):
    """Handle the end of a video: stop the thread and reset the controls."""
    self.video_thread.stop_video()

    button_states = (
        (self.open_video_button, True),
        (self.stop_video_button, False),
        (self.pause_video_button, False),
    )
    for button, enabled in button_states:
        button.setEnabled(enabled)

    self.camera_label.setText("视频播放结束")
    # Fall back to camera mode once the video is over.
    self.current_mode = "camera"
def open_video_file(self):
    """Let the user pick a video file and start playing it.

    A running camera or video is stopped and the recognition state is reset
    before the file dialog opens. Nothing else happens when the dialog is
    cancelled.
    """
    # Stop whatever source is currently active.
    if self.current_mode == "camera" and self.camera_thread and self.camera_thread.running:
        self.stop_camera()
    elif self.current_mode == "video" and self.video_thread and self.video_thread.running:
        self.stop_video()

    self.reset_processing_state()

    video_path, _ = QFileDialog.getOpenFileName(self, "选择视频文件", "", "视频文件 (*.mp4 *.avi *.mov *.mkv)")
    if not video_path:
        return

    if not self.video_thread.load_video(video_path):
        self.camera_label.setText("视频加载失败")
        return

    self.current_mode = "video"
    # Show the status text before starting: start_video() writes its own
    # failure message to the same label, and that message used to be
    # overwritten by the "now playing" text set after the call.
    self.camera_label.setText(f"正在播放视频: {os.path.basename(video_path)}")
    self.start_video()
def start_video(self):
    """Start playback of the loaded video and update the video buttons.

    Shows a failure message on the display label when the video thread
    refuses to start.
    """
    if not self.video_thread.start_video():
        self.camera_label.setText("视频播放失败")
        return

    self.open_video_button.setEnabled(False)
    for button in (self.stop_video_button, self.pause_video_button):
        button.setEnabled(True)
    self.pause_video_button.setText("暂停")
def pause_video(self):
    """Toggle pause on the video thread and relabel the pause button."""
    now_paused = self.video_thread.pause_video()
    # The button always names the action the next click will perform.
    caption = "继续" if now_paused else "暂停"
    self.pause_video_button.setText(caption)
def stop_video(self):
    """Stop video playback, reset the video buttons and return to camera mode."""
    self.video_thread.stop_video()

    button_states = (
        (self.open_video_button, True),
        (self.stop_video_button, False),
        (self.pause_video_button, False),
    )
    for button, enabled in button_states:
        button.setEnabled(enabled)

    self.camera_label.setText("视频已停止")
    # NOTE(review): in video mode the clear() below also removes the text set
    # on the previous line, so the label ends up empty -- confirm whether the
    # message or the empty label is the intended result.
    was_video = self.current_mode == "video"
    if was_video:
        self.camera_label.clear()
    self.current_mode = "camera"
def open_image_file(self):
    """Let the user pick an image file and run recognition on it.

    A running camera or video is stopped and the recognition state is reset
    before the file dialog opens. Load and processing errors are printed and
    shown on the display label.
    """
    camera_active = (self.current_mode == "camera" and self.camera_thread
                     and self.camera_thread.running)
    video_active = (self.current_mode == "video" and self.video_thread
                    and self.video_thread.running)
    if camera_active:
        self.stop_camera()
    elif video_active:
        self.stop_video()

    self.reset_processing_state()

    image_path, _ = QFileDialog.getOpenFileName(self, "选择图片文件", "", "图片文件 (*.jpg *.jpeg *.png *.bmp)")
    if not image_path:
        return

    self.current_mode = "image"
    try:
        # First attempt: read the raw bytes and decode them in memory (the
        # original comment gives non-ASCII paths as the reason).
        raw_bytes = np.fromfile(image_path, dtype=np.uint8)
        image = cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR)
        # Second attempt: plain cv2.imread.
        if image is None:
            image = cv2.imread(image_path)

        if image is None:
            print(f"图片加载失败: {image_path}")
            self.camera_label.setText("图片加载失败")
        else:
            print(f"成功加载图片: {image_path}, 尺寸: {image.shape}")
            # No label text is set here; it would replace the displayed image.
            self.process_image(image)
    except Exception as e:
        print(f"图片处理异常: {str(e)}")
        self.camera_label.setText(f"图片处理错误: {str(e)}")
def process_image(self, image):
    """Detect plates in a still image, display it and refresh the results.

    Every step is announced on stdout. Any exception is printed together
    with its traceback and otherwise swallowed.
    """
    try:
        print(f"开始处理图片,图片尺寸: {image.shape}")
        self.current_frame = image.copy()

        print("正在进行车牌检测...")
        self.detections = self.detector.detect_license_plates(image)
        print(f"检测到 {len(self.detections)} 个车牌")

        # Draw on a copy so the stored frame stays unannotated.
        print("正在绘制检测结果...")
        annotated = self.draw_detections(image.copy())

        print("正在显示图片...")
        self.display_frame(annotated)

        print("正在更新结果显示...")
        self.update_results_display()

        print("图片处理完成")
    except Exception as e:
        import traceback
        print(f"图片处理过程中出错: {str(e)}")
        traceback.print_exc()
def process_frame(self, frame):
    """Show an incoming frame and schedule recognition when idle.

    The raw frame is displayed immediately. Detection and recognition are
    queued through a zero-delay timer only when no earlier run is still
    marked as in progress.
    """
    if frame is None:
        return

    self.current_frame = frame.copy()
    # Display first so playback does not wait for recognition.
    self.display_frame(frame)

    if self.is_processing:
        return
    self.is_processing = True
    QTimer.singleShot(0, self.async_detect_and_update)
def async_detect_and_update(self):
"""异步进行车牌检测和识别"""
if self.current_frame is None:
self.is_processing = False # 重置标志位
return
try:
# 进行车牌检测
self.detections = self.detector.detect_license_plates(frame)
self.detections = self.detector.detect_license_plates(self.current_frame)
# 在图像上绘制检测结果
display_frame = self.draw_detections(frame.copy())
display_frame = self.draw_detections(self.current_frame.copy())
# 转换为Qt格式并显示
# 更新显示帧(显示带检测结果的帧)
# 无论是摄像头模式还是视频模式,都显示检测框
self.display_frame(display_frame)
# 更新右侧结果显示
self.update_results_display()
except Exception as e:
print(f"异步检测和更新失败: {str(e)}")
import traceback
traceback.print_exc()
finally:
# 无论成功或失败,都要重置标志位
self.is_processing = False
def draw_detections(self, frame):
"""在图像上绘制检测结果"""
return self.detector.draw_detections(frame, self.detections)
# 获取车牌号列表
plate_numbers = []
for detection in self.detections:
# 矫正车牌图像
corrected_image = self.correct_license_plate(detection)
# 获取车牌号
if corrected_image is not None:
plate_number = self.recognize_plate_number(corrected_image, detection['class_name'])
plate_numbers.append(plate_number)
else:
plate_numbers.append("识别失败")
return self.detector.draw_detections(frame, self.detections, plate_numbers)
def _show_qimage(self, pixels, image_format, tag):
    """Put a 3-channel pixel array on ``camera_label``.

    Wraps ``pixels`` (indexed as height, width, channel) in a QImage of
    ``image_format``, converts it to a pixmap scaled to the label size with
    the aspect ratio kept, and sets it on the label. ``tag`` prefixes the
    progress messages. Raises when the QImage or QPixmap comes out null.
    """
    # QImage reads the buffer row by row with a fixed stride, so the array
    # has to be contiguous in memory.
    pixels = np.ascontiguousarray(pixels)
    height, width, channels = pixels.shape
    qt_image = QImage(pixels.data, width, height, channels * width, image_format)
    print(f"{tag}: 创建QImage尺寸: {qt_image.width()}x{qt_image.height()}")
    if qt_image.isNull():
        raise Exception("QImage为空")

    pixmap = QPixmap.fromImage(qt_image)
    if pixmap.isNull():
        raise Exception("QPixmap为空")

    scaled_pixmap = pixmap.scaled(self.camera_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation)
    self.camera_label.setPixmap(scaled_pixmap)
    print(f"{tag}: 帧显示完成")

def display_frame(self, frame):
    """Show ``frame`` on the display label.

    Three conversions are tried in order and the first that succeeds wins:
    a direct BGR->RGB conversion, a JPEG encode/decode round trip followed
    by BGR->RGB, and finally handing the raw buffer to QImage as BGR888.
    When all fail, or anything else goes wrong, an error text is shown on
    the label instead.
    """
    try:
        print(f"开始显示帧,帧尺寸: {frame.shape}")

        # Attempt 1: standard colour conversion.
        try:
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            self._show_qimage(rgb_frame, QImage.Format_RGB888, "方法1")
            return
        except Exception as e1:
            print(f"方法1失败: {str(e1)}")

        # Attempt 2: round trip through an in-memory JPEG.
        try:
            print("尝试方法2: 使用imencode和imdecode")
            _, buffer = cv2.imencode('.jpg', frame)
            decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
            rgb_frame = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
            self._show_qimage(rgb_frame, QImage.Format_RGB888, "方法2")
            return
        except Exception as e2:
            print(f"方法2失败: {str(e2)}")

        # Attempt 3: no conversion; the format lookup sits inside the try so
        # a Qt build without Format_BGR888 is reported as a failed attempt.
        try:
            print("尝试方法3: 直接使用QImage的构造函数")
            self._show_qimage(frame, QImage.Format_BGR888, "方法3")
            return
        except Exception as e3:
            print(f"方法3失败: {str(e3)}")

        print("所有显示方法都失败")
        self.camera_label.setText("图片显示失败")
    except Exception as e:
        print(f"显示帧过程中出错: {str(e)}")
        import traceback
        traceback.print_exc()
        self.camera_label.setText(f"显示错误: {str(e)}")
def update_results_display(self):
    """Refresh the right-hand results panel using stabilized recognition results.

    With no detections, the count label is reset to 0, every widget is
    removed from ``self.results_layout``, ``self.last_plate_results`` is
    emptied, and the method returns without consulting the stabilizer.

    Otherwise each detection is rectified and recognized, the per-frame
    results are passed through ``self.plate_stabilizer``, and the panel is
    rebuilt only when ``self.check_results_changed`` reports a difference
    from the results shown previously.
    """
    def clear_results_layout():
        # Walk back-to-front so detaching a widget never shifts the index
        # of an item that has not been visited yet.
        for index in reversed(range(self.results_layout.count())):
            child = self.results_layout.itemAt(index).widget()
            if child:
                child.setParent(None)

    print(f"开始更新结果显示,当前模式: {self.current_mode}, 检测数量: {len(self.detections) if self.detections else 0}")

    if not self.detections:
        self.count_label.setText("识别到的车牌数量: 0")
        clear_results_layout()
        self.last_plate_results = []
        print("无检测结果,已清空界面")
        return

    # Collect the rectified image and recognized text for every detection.
    corrected_images = []
    plate_texts = []
    for detection in self.detections:
        corrected_image = self.correct_license_plate(detection)
        corrected_images.append(corrected_image)
        if corrected_image is not None:
            plate_text = self.recognize_plate_number(corrected_image, detection['class_name'])
        else:
            # Rectification failed, so there is nothing to run recognition on.
            plate_text = "识别失败"
        plate_texts.append(plate_text)

    # Let the stabilizer turn the per-frame results into stable ones.
    stable_results = self.plate_stabilizer.update_and_get_stable_result(
        self.detections, corrected_images, plate_texts
    )

    self.count_label.setText(f"识别到的车牌数量: {len(stable_results)}")
    print(f"稳定结果数量: {len(stable_results)}")

    results_changed = self.check_results_changed(stable_results)
    print(f"结果是否变化: {results_changed}")

    if results_changed:
        # Rebuild the panel only when the visible content actually differs.
        clear_results_layout()
        for position, result in enumerate(stable_results, start=1):
            plate_widget = LicensePlateWidget(
                position,  # 1-based display index
                result['class_name'],
                result['corrected_image'],
                result['plate_number']
            )
            self.results_layout.addWidget(plate_widget)
            print(f"添加车牌widget: {result['plate_number']}")
        # Remember what is on screen for the next change check.
        self.last_plate_results = stable_results

    # NOTE(review): nesting reconstructed from an indentation-stripped diff;
    # the cleanup below is assumed to run on every non-empty update rather
    # than only when the results changed -- confirm against the repository.
    current_plate_ids = [result['id'] for result in stable_results]
    self.plate_stabilizer.clear_old_plates(current_plate_ids)
    print("结果显示更新完成")
def check_results_changed(self, new_results):
    """Report whether the recognition results differ from those last displayed.

    Args:
        new_results: Sequence of result dicts to compare, position by
            position, against ``self.last_plate_results``.

    Returns:
        True when the two sequences differ in length, or when any pair at
        the same position differs in ``'class_name'`` or ``'plate_number'``;
        False otherwise. All other keys are ignored.
    """
    previous_results = self.last_plate_results
    if len(previous_results) != len(new_results):
        return True

    # Lengths are equal from here on, so zip() pairs every entry and no
    # per-index bounds check is needed.
    compared_keys = ('class_name', 'plate_number')
    return any(
        old_result.get(key) != new_result.get(key)
        for old_result, new_result in zip(previous_results, new_results)
        for key in compared_keys
    )
def correct_license_plate(self, detection):
"""矫正车牌图像"""
@ -395,8 +1059,15 @@ class MainWindow(QMainWindow):
return "识别失败"
try:
# 根据当前选择的识别方法调用相应的函数
if self.current_recognition_method == "CRNN":
from CRNN_part.crnn_interface import LPRNmodel_predict
elif self.current_recognition_method == "LightCRNN":
from lightCRNN_part.lightcrnn_interface import LPRNmodel_predict
elif self.current_recognition_method == "OCR":
from OCR_part.ocr_interface import LPRNmodel_predict
# 预测函数(来自模块)
# 函数名改成一样的了,所以不要修改这里了,想用哪个模块直接导入
result = LPRNmodel_predict(corrected_image)
# 将字符列表转换为字符串支持8位车牌号
@ -420,10 +1091,32 @@ class MainWindow(QMainWindow):
print(f"车牌号识别失败: {e}")
return "识别失败"
def change_recognition_method(self, method):
    """Switch the active plate-recognition backend and refresh the display.

    Records ``method`` as the current recognition method, updates the
    method label, initializes the matching model when ``method`` is one of
    the known backends, and re-processes ``self.current_frame`` (if any) so
    the shown results come from the newly selected backend.

    Args:
        method: Backend name. ``"CRNN"``, ``"LightCRNN"`` and ``"OCR"``
            trigger model initialization; any other value only updates the
            stored method and the label.
    """
    # Imported locally, like the backend modules below, so nothing extra is
    # loaded until a method switch actually happens.
    import importlib

    # One table instead of three copy-pasted import-and-call branches.
    interface_modules = {
        "CRNN": "CRNN_part.crnn_interface",
        "LightCRNN": "lightCRNN_part.lightcrnn_interface",
        "OCR": "OCR_part.ocr_interface",
    }

    self.current_recognition_method = method
    self.current_method_label.setText(f"当前识别方法: {method}")

    # Initialize the model that belongs to the selected backend.
    module_name = interface_modules.get(method)
    if module_name is not None:
        importlib.import_module(module_name).LPRNinitialize_model()

    # Re-run the pipeline on the frame currently shown, if there is one.
    if self.current_frame is not None:
        self.process_frame(self.current_frame)
def closeEvent(self, event):
    """Stop any running capture thread, then accept the window close event.

    Args:
        event: The close event delivered to the window; it is always
            accepted.
    """
    # Only stop a thread that exists and reports itself as running.
    if self.camera_thread and self.camera_thread.running:
        self.camera_thread.stop_camera()
    if self.video_thread and self.video_thread.running:
        self.video_thread.stop_video()
    event.accept()
def main():

View File

@ -2,6 +2,7 @@ import cv2
import numpy as np
from ultralytics import YOLO
import os
from PIL import Image, ImageDraw, ImageFont
class LicensePlateYOLO:
"""
@ -113,19 +114,38 @@ class LicensePlateYOLO:
print(f"检测过程中出错: {e}")
return []
def draw_detections(self, image, detections):
def draw_detections(self, image, detections, plate_numbers=None):
"""
在图像上绘制检测结果
参数:
image: 输入图像
detections: 检测结果列表
plate_numbers: 车牌号列表与detections对应
返回:
numpy.ndarray: 绘制了检测结果的图像
"""
draw_image = image.copy()
# 转换为PIL图像以支持中文字符
pil_image = Image.fromarray(cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(pil_image)
# 尝试加载中文字体
try:
# Windows系统常见的中文字体
font_path = "C:/Windows/Fonts/simhei.ttf" # 黑体
if not os.path.exists(font_path):
font_path = "C:/Windows/Fonts/msyh.ttc" # 微软雅黑
if not os.path.exists(font_path):
font_path = "C:/Windows/Fonts/simsun.ttc" # 宋体
font = ImageFont.truetype(font_path, 20)
except:
# 如果无法加载字体,使用默认字体
font = ImageFont.load_default()
for i, detection in enumerate(detections):
box = detection['box']
keypoints = detection['keypoints']
@ -133,6 +153,11 @@ class LicensePlateYOLO:
confidence = detection['confidence']
incomplete = detection.get('incomplete', False)
# 获取对应的车牌号
plate_number = ""
if plate_numbers and i < len(plate_numbers):
plate_number = plate_numbers[i]
# 绘制边界框
x1, y1, x2, y2 = map(int, box)
@ -140,30 +165,53 @@ class LicensePlateYOLO:
if class_name == '绿牌':
box_color = (0, 255, 0) # 绿色
elif class_name == '蓝牌':
box_color = (255, 0, 0) # 蓝色
box_color = (0, 0, 255) # 蓝色
else:
box_color = (128, 128, 128) # 灰色
cv2.rectangle(draw_image, (x1, y1), (x2, y2), box_color, 2)
# 在PIL图像上绘制边界框
draw.rectangle([(x1, y1), (x2, y2)], outline=box_color, width=2)
# 绘制标签
# 构建标签文本
if plate_number:
label = f"{class_name} {plate_number} {confidence:.2f}"
else:
label = f"{class_name} {confidence:.2f}"
if incomplete:
label += " (不完整)"
# 计算文本大小和位置
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.6
thickness = 2
(text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
# 计算文本大小
bbox = draw.textbbox((0, 0), label, font=font)
text_width = bbox[2] - bbox[0]
text_height = bbox[3] - bbox[1]
# 绘制文本背景
cv2.rectangle(draw_image, (x1, y1 - text_height - 10),
(x1 + text_width, y1), box_color, -1)
draw.rectangle([(x1, y1 - text_height - 10), (x1 + text_width, y1)],
fill=box_color)
# 绘制文本
cv2.putText(draw_image, label, (x1, y1 - 5),
font, font_scale, (255, 255, 255), thickness)
draw.text((x1, y1 - text_height - 5), label, fill=(255, 255, 255), font=font)
# 转换回OpenCV格式
draw_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
# 绘制关键点和连线使用OpenCV
for i, detection in enumerate(detections):
box = detection['box']
keypoints = detection['keypoints']
incomplete = detection.get('incomplete', False)
x1, y1, x2, y2 = map(int, box)
# 根据车牌类型选择颜色
class_name = detection['class_name']
if class_name == '绿牌':
box_color = (0, 255, 0) # 绿色
elif class_name == '蓝牌':
box_color = (0, 0, 255) # 蓝色
else:
box_color = (128, 128, 128) # 灰色
# 绘制关键点和连线
if len(keypoints) >= 4 and not incomplete: