def filter_allowed_chars(self, text):
    """Return *text* with every character outside the whitelist removed.

    Args:
        text: Raw recognition string.

    Returns:
        str: The characters of ``text`` that appear in
        ``self.allowed_chars``, in their original order.
    """
    # One set build per call makes each membership test O(1) instead of a
    # scan over the whitelist list, and join avoids the repeated string
    # reallocation of ``+=`` in a loop.
    allowed = set(self.allowed_chars)
    return "".join(char for char in text if char in allowed)
class ChannelAttention(nn.Module):
    """Channel attention gate.

    Global average- and max-pooled descriptors are passed through a shared
    two-layer 1x1-convolution bottleneck, summed, squashed with a sigmoid
    and used to rescale the input channels.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Divisor applied to ``in_channels`` to size the bottleneck.
    """

    def __init__(self, in_channels, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Integer division yields 0 when in_channels < reduction, which would
        # create a zero-width bottleneck; keep at least one hidden channel.
        # For in_channels >= reduction the width is unchanged.
        hidden_channels = max(1, in_channels // reduction)

        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel by the attention weights."""
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return x * self.sigmoid(out)
nn.Conv2d(in_channels, hidden_dim, 1, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True) + ]) + + # 深度卷积 + layers.extend([ + nn.Conv2d(hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim, bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU6(inplace=True), + # 线性瓶颈 + nn.Conv2d(hidden_dim, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels) + ]) + + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_residual: + return x + self.conv(x) + else: + return self.conv(x) + +class LightweightCNN(nn.Module): + """增强版轻量化CNN特征提取器""" + + def __init__(self, num_channels=3): + super(LightweightCNN, self).__init__() + + # 初始卷积层 - 适当增加通道数 + self.conv1 = nn.Sequential( + nn.Conv2d(num_channels, 48, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(48), + nn.ReLU6(inplace=True) + ) + + # 增强版MobileNet风格的特征提取 + self.features = nn.Sequential( + # 第一组:48 -> 32 + InvertedResidual(48, 32, stride=1, expand_ratio=2), + InvertedResidual(32, 32, stride=1, expand_ratio=2), # 增加一层 + nn.MaxPool2d(kernel_size=2, stride=2), # 32x128 -> 16x64 + + # 第二组:32 -> 48 + InvertedResidual(32, 48, stride=1, expand_ratio=4), + InvertedResidual(48, 48, stride=1, expand_ratio=4), + nn.MaxPool2d(kernel_size=2, stride=2), # 16x64 -> 8x32 + + # 第三组:48 -> 64 + InvertedResidual(48, 64, stride=1, expand_ratio=4), + InvertedResidual(64, 64, stride=1, expand_ratio=4), + + # 第四组:64 -> 96 + InvertedResidual(64, 96, stride=1, expand_ratio=4), + InvertedResidual(96, 96, stride=1, expand_ratio=4), + nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)), # 8x32 -> 4x32 + + # 第五组:96 -> 128 + InvertedResidual(96, 128, stride=1, expand_ratio=4), + InvertedResidual(128, 128, stride=1, expand_ratio=4), + nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)), # 4x32 -> 2x32 + + # 最后的卷积层 - 增加通道数 + nn.Conv2d(128, 160, kernel_size=2, stride=1, padding=0, bias=False), # 2x32 -> 1x31 + nn.BatchNorm2d(160), + nn.ReLU6(inplace=True) + ) + + # 通道注意力 + 
class LightweightGRU(nn.Module):
    """Bidirectional GRU followed by a two-layer projection head.

    Args:
        input_size: Feature size of each input time step.
        hidden_size: GRU hidden size per direction; also the output width.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, input_size, hidden_size, num_layers=2):
        super().__init__()
        # Both directions are concatenated, so the GRU emits 2 * hidden_size.
        bidirectional_width = hidden_size * 2
        # The GRU applies dropout between stacked layers only, so it is
        # switched off for a single layer.
        inter_layer_dropout = 0.2 if num_layers > 1 else 0

        self.gru = nn.GRU(
            input_size,
            hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.linear1 = nn.Linear(bidirectional_width, bidirectional_width)
        self.linear2 = nn.Linear(bidirectional_width, hidden_size)
        self.dropout = nn.Dropout(0.2)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        """Run the GRU over ``x`` and project each step to ``hidden_size``."""
        sequence, _ = self.gru(x)
        expanded = self.dropout(F.relu(self.linear1(sequence)))
        normalized = self.norm(self.linear2(expanded))
        return self.dropout(normalized)
class LightCTCDecoder:
    """Greedy CTC decoder over the Chinese licence-plate alphabet.

    Index 0 is the CTC blank; it is followed by the province abbreviations,
    the letters A-Z (I and O included) and the digits 0-9.
    """

    def __init__(self):
        # Province abbreviations.
        provinces = ['京', '津', '沪', '渝', '冀', '豫', '云', '辽', '黑', '湘', '皖', '鲁',
                     '新', '苏', '浙', '赣', '鄂', '桂', '甘', '晋', '蒙', '陕', '吉', '闽',
                     '贵', '粤', '青', '藏', '川', '宁', '琼']

        # Letters (I and O included).
        letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
                   'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']

        digits = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

        # The blank symbol goes first so that it owns index 0.
        self.character = ['[blank]'] + provinces + letters + digits

        # Symbol <-> index lookup tables.
        self.dict = {char: i for i, char in enumerate(self.character)}
        self.dict_reverse = dict(enumerate(self.character))

        self.num_classes = len(self.character)
        self.blank_idx = 0

    def _collapse(self, indices):
        """Apply the CTC collapse rule to a list of per-step class indices.

        Returns:
            list[tuple[int, int]]: ``(time_step, class_index)`` for every
            emitted symbol: repeats of the previous step and blanks are
            dropped, as are indices outside the alphabet.
        """
        emitted = []
        prev_idx = -1
        for step, idx in enumerate(indices):
            if idx != prev_idx and idx != self.blank_idx and idx < len(self.character):
                emitted.append((step, idx))
            # Updated on every step so a blank between two equal symbols
            # lets the second one through.
            prev_idx = idx
        return emitted

    def decode_greedy(self, predictions):
        """Decode ``predictions`` by taking the arg-max class at each step.

        Args:
            predictions: 2-D tensor; the arg-max is taken along dim 1.

        Returns:
            str: The decoded text.
        """
        # tolist() moves all indices to Python in one transfer instead of
        # one .item() call per time step.
        indices = torch.argmax(predictions, dim=1).tolist()
        return ''.join(self.character[idx] for _, idx in self._collapse(indices))

    def decode_with_confidence(self, predictions):
        """Decode ``predictions`` and report per-character confidence.

        Args:
            predictions: 2-D tensor of unnormalised scores; softmax is
                applied along dim 1.

        Returns:
            tuple: ``(text, avg_confidence, char_confidences)`` where
            ``char_confidences`` holds the winning probability of each
            emitted character and ``avg_confidence`` is their mean as a
            Python float (``0.0`` when nothing was emitted).
        """
        probs = torch.softmax(predictions, dim=1)
        max_probs, indices = torch.max(probs, dim=1)

        emitted = self._collapse(indices.tolist())
        step_probs = max_probs.tolist()

        text = ''.join(self.character[idx] for _, idx in emitted)
        char_confidences = [step_probs[step] for step, _ in emitted]
        # float() keeps the return type identical in both branches.
        avg_confidence = float(np.mean(char_confidences)) if char_confidences else 0.0

        return text, avg_confidence, char_confidences
def _load_lightcrnn_checkpoint(model_path, map_location):
    """Read the checkpoint at *model_path*, tolerating a missing ``config`` module."""
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # therefore execute code embedded in the file. Only load checkpoints
    # from a trusted source.
    try:
        return torch.load(model_path, map_location=map_location, weights_only=False)
    except (ModuleNotFoundError, AttributeError) as e:
        # Only the "checkpoint pickled a config object" case is recoverable.
        if 'config' not in str(e) and 'Config' not in str(e):
            raise

    print("检测到模型文件包含config依赖,尝试使用weights_only模式加载...")
    try:
        # Tensors-only loading sidesteps the pickled config reference.
        return torch.load(model_path, map_location=map_location, weights_only=True)
    except Exception:
        # Last resort: expose a stub ``config`` module so unpickling can
        # resolve ``config.Config``.
        import sys
        import types

        class Config:
            def __init__(self):
                pass

        stub_config = types.ModuleType('config')
        stub_config.Config = Config

        # Remember what was registered so a real ``config`` module is put
        # back afterwards instead of being deleted.
        absent = object()
        previous = sys.modules.get('config', absent)
        sys.modules['config'] = stub_config
        try:
            return torch.load(model_path, map_location=map_location, weights_only=False)
        finally:
            if previous is absent:
                sys.modules.pop('config', None)
            else:
                sys.modules['config'] = previous


def LPRNinitialize_model():
    """Initialise the lightweight CRNN recogniser.

    Builds the decoder, preprocessor and network, loads ``best_model.pth``
    from this module's directory and, only once the weights are loaded,
    publishes them through the module-level globals. A failed
    initialisation therefore never leaves an untrained network visible to
    ``LPRNmodel_predict``.

    Returns:
        bool: True when initialisation succeeded, False otherwise.
    """
    global lightcrnn_model, lightcrnn_decoder, lightcrnn_preprocessor, device

    try:
        selected_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"LightCRNN使用设备: {selected_device}")

        # Build everything in locals first; globals are assigned at the end.
        decoder = LightCTCDecoder()
        preprocessor = LightLicensePlatePreprocessor(target_height=32, target_width=128)
        model = LightweightCRNN(
            img_height=32,
            num_classes=decoder.num_classes,
            hidden_size=160
        )

        model_path = os.path.join(os.path.dirname(__file__), 'best_model.pth')
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        print(f"正在加载LightCRNN模型: {model_path}")
        checkpoint = _load_lightcrnn_checkpoint(model_path, selected_device)

        # A full training checkpoint wraps the weights in 'model_state_dict';
        # anything else is treated as the state dict itself.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
            print("检查点信息:")
            print(f" - 训练轮次: {checkpoint.get('epoch', 'N/A')}")
            print(f" - 最佳验证损失: {checkpoint.get('best_val_loss', 'N/A')}")
        else:
            if isinstance(checkpoint, dict):
                print("加载精简模型(仅权重)")
            state_dict = checkpoint

        model.load_state_dict(state_dict)
        model.to(selected_device)
        model.eval()

        # Everything succeeded: publish the ready-to-use components.
        device = selected_device
        lightcrnn_decoder = decoder
        lightcrnn_preprocessor = preprocessor
        lightcrnn_model = model

        print("LightCRNN模型初始化完成")

        total_params = sum(p.numel() for p in model.parameters())
        print(f"LightCRNN模型参数数量: {total_params:,}")

        return True

    except Exception as e:
        # Top-level boundary: report the failure and signal it to the caller.
        print(f"LightCRNN模型初始化失败: {e}")
        import traceback
        traceback.print_exc()
        return False
def create_lightweight_model(model_type='lightweight_crnn', img_height=32, num_classes=66, hidden_size=160):
    """Build a lightweight recognition model by name.

    Args:
        model_type: Model family to build; only ``'lightweight_crnn'`` is
            supported.
        img_height: Input image height forwarded to the model.
        num_classes: Size of the output alphabet.
        hidden_size: Hidden width forwarded to the model.

    Returns:
        LightweightCRNN: A freshly constructed model.

    Raises:
        ValueError: If ``model_type`` is not a supported name.
    """
    # NOTE(review): the default num_classes=66 differs from the 68-symbol
    # alphabet that LightCTCDecoder builds in this file — confirm which value
    # callers relying on the default expect.
    if model_type != 'lightweight_crnn':
        raise ValueError(f"Unknown lightweight model type: {model_type}")
    return LightweightCRNN(img_height, num_classes, hidden_size=hidden_size)
class PlateStabilizer:
    """Smooth per-frame plate recognition results over a short history.

    Detections are associated across frames by the distance between their
    bounding-box centres; each tracked plate keeps a bounded history of
    recognised texts, and the reported text is chosen by a
    confidence-weighted vote over that history.

    Args:
        history_size: Maximum number of entries kept per plate.
        confidence_threshold: Minimum share of the weighted vote the leading
            text needs to be reported as stable.
        stability_frames: Stored on the instance; not read by this class.
        max_match_distance: Largest centre distance (same units as the
            bounding boxes) at which a detection is associated with a
            previously seen plate.
    """

    def __init__(self, history_size=10, confidence_threshold=0.6, stability_frames=5,
                 max_match_distance=100):
        self.history_size = history_size
        self.confidence_threshold = confidence_threshold
        self.stability_frames = stability_frames
        self.max_match_distance = max_match_distance

        # plate_id -> bounded history of recognition entries.
        self.plate_histories = defaultdict(lambda: deque(maxlen=history_size))
        # Kept for compatibility; this class only ever deletes from it.
        self.stable_results = {}
        # Source of fresh "plate_N" identifiers.
        self.plate_id_counter = 0
        # plate_id -> last bounding box seen for that plate.
        self.plate_positions = {}

    def calculate_plate_distance(self, pos1, pos2):
        """Return the distance between the centres of two ``[x1, y1, x2, y2]`` boxes.

        Returns ``inf`` when either box is ``None``.
        """
        if pos1 is None or pos2 is None:
            return float('inf')

        center1 = ((pos1[0] + pos1[2]) / 2, (pos1[1] + pos1[3]) / 2)
        center2 = ((pos2[0] + pos2[2]) / 2, (pos2[1] + pos2[3]) / 2)

        return np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)

    def _assign_plate_ids(self, current_detections):
        """Give every detection a plate id; return ``[(plate_id, index), ...]``."""
        assignments = []
        used_ids = set()

        for index, detection in enumerate(current_detections):
            bbox = detection.get('bbox', [0, 0, 0, 0])
            best_match_id = None
            min_distance = float('inf')

            # Nearest not-yet-claimed plate within the distance threshold.
            for plate_id, last_pos in self.plate_positions.items():
                if plate_id in used_ids:
                    continue
                distance = self.calculate_plate_distance(bbox, last_pos)
                if distance < min_distance and distance < self.max_match_distance:
                    min_distance = distance
                    best_match_id = plate_id

            if best_match_id is None:
                best_match_id = f"plate_{self.plate_id_counter}"
                self.plate_id_counter += 1

            # New ids are marked as used too; otherwise a second nearby
            # detection in the same frame would claim the id just created
            # and silently replace the first detection.
            used_ids.add(best_match_id)
            self.plate_positions[best_match_id] = bbox
            assignments.append((best_match_id, index))

        return assignments

    def match_plates_to_history(self, current_detections):
        """Associate detections with tracked plates.

        Args:
            current_detections: Sequence of detection dicts; the ``'bbox'``
                entry (default ``[0, 0, 0, 0]``) is used for matching.

        Returns:
            dict: ``plate_id -> detection`` in detection order.
        """
        return {plate_id: current_detections[index]
                for plate_id, index in self._assign_plate_ids(current_detections)}

    def calculate_confidence(self, plate_text, detection_quality=1.0):
        """Score a recognised text in ``[0, 1]`` from simple plausibility rules.

        Empty text and the failure marker score 0. Otherwise the score
        starts at 0.5, gains 0.2 for a length of 7 or 8, gains 0.2 when the
        text mixes a Chinese character, an ASCII letter and a digit, and is
        finally multiplied by ``detection_quality``.
        """
        if not plate_text or plate_text == "识别失败":
            return 0.0

        base_confidence = 0.5

        if 7 <= len(plate_text) <= 8:
            base_confidence += 0.2

        has_chinese = any('\u4e00' <= char <= '\u9fff' for char in plate_text)
        # str.isalpha() is also true for Chinese characters, so the letter
        # test has to be restricted to ASCII to mean "Latin letter".
        has_letter = any(char.isascii() and char.isalpha() for char in plate_text)
        has_digit = any(char.isdigit() for char in plate_text)

        if has_chinese and has_letter and has_digit:
            base_confidence += 0.2

        confidence = base_confidence * detection_quality

        return min(confidence, 1.0)

    def update_and_get_stable_result(self, current_detections, corrected_images, plate_texts):
        """Record this frame's results and return the stabilised ones.

        Args:
            current_detections: Detection dicts for the frame; each must
                carry ``'class_name'``.
            corrected_images: Rectified plate images, parallel to the
                detections (missing entries are treated as ``None``).
            plate_texts: Recognised texts, parallel to the detections
                (missing entries are treated as the failure marker).

        Returns:
            list[dict]: One entry per plate whose stabilised text is usable,
            with keys ``id``, ``class_name``, ``corrected_image``,
            ``plate_number`` and ``detection``.
        """
        if not current_detections:
            return []

        stable_results = []

        # The index is carried through matching rather than recovered with
        # list.index(): that lookup is equality-based, so it picks the wrong
        # entry for equal dicts and can raise when values are numpy arrays.
        for plate_id, detection_idx in self._assign_plate_ids(current_detections):
            detection = current_detections[detection_idx]
            corrected_image = corrected_images[detection_idx] if detection_idx < len(corrected_images) else None
            plate_text = plate_texts[detection_idx] if detection_idx < len(plate_texts) else "识别失败"

            confidence = self.calculate_confidence(plate_text)

            self.plate_histories[plate_id].append({
                'text': plate_text,
                'confidence': confidence,
                'detection': detection,
                'corrected_image': corrected_image
            })

            stable_text = self.get_stable_text(plate_id)

            if stable_text and stable_text != "识别失败":
                stable_results.append({
                    'id': plate_id,
                    'class_name': detection['class_name'],
                    'corrected_image': corrected_image,
                    'plate_number': stable_text,
                    'detection': detection
                })

        return stable_results

    def get_stable_text(self, plate_id):
        """Return the stabilised text for ``plate_id``.

        With fewer than three history entries the latest text is returned.
        Otherwise entries vote with their confidence as weight; the leader
        is returned when its share reaches ``confidence_threshold``, else
        the most recent high-confidence text of the last five entries, else
        the latest text. Unknown plates yield the failure marker.
        """
        # .get() avoids creating an empty history for an unknown id.
        history = self.plate_histories.get(plate_id)
        if not history:
            return "识别失败"

        if len(history) < 3:
            return history[-1]['text']

        text_votes = defaultdict(float)
        total_confidence = 0

        for entry in history:
            text = entry['text']
            confidence = entry['confidence']

            if text != "识别失败" and confidence > 0.3:
                text_votes[text] += confidence
                total_confidence += confidence

        if not text_votes:
            return "识别失败"

        best_text, best_votes = max(text_votes.items(), key=lambda item: item[1])

        vote_ratio = best_votes / total_confidence if total_confidence > 0 else 0

        if vote_ratio >= self.confidence_threshold:
            return best_text

        # Not stable enough: prefer the newest confident reading.
        recent_high_conf = [entry for entry in list(history)[-5:]
                            if entry['confidence'] > 0.5 and entry['text'] != "识别失败"]

        if recent_high_conf:
            return recent_high_conf[-1]['text']
        return history[-1]['text']

    def clear_old_plates(self, current_plate_ids):
        """Forget every tracked plate whose id is not in ``current_plate_ids``."""
        # Snapshot the ids first so the dicts are not mutated while iterated.
        tracked_ids = set(self.plate_histories) | set(self.plate_positions)
        for plate_id in tracked_ids:
            if plate_id in current_plate_ids:
                continue
            self.plate_histories.pop(plate_id, None)
            self.plate_positions.pop(plate_id, None)
            self.stable_results.pop(plate_id, None)
def stop_video(self):
    """Stop playback and release the video capture.

    The worker loop is asked to exit and the thread is joined before the
    capture is released, so ``run()`` can never be inside ``cap.read()``
    while the capture is being torn down from the calling thread.
    """
    self.running = False
    self.quit()
    self.wait()
    # Safe to release only now that the worker loop has finished.
    if self.cap:
        self.cap.release()
+ display_height = max(display_height, 25) + + # 设置标签尺寸并缩放图像 + image_label.setFixedSize(display_width, display_height) + scaled_pixmap = pixmap.scaled(display_width, display_height, Qt.KeepAspectRatio, Qt.SmoothTransformation) image_label.setPixmap(scaled_pixmap) + image_label.setAlignment(Qt.AlignCenter) else: + # 当没有图像时,设置固定尺寸显示提示信息 + image_label.setFixedSize(120, 40) image_label.setText("车牌未完全\n进入摄像头") image_label.setAlignment(Qt.AlignCenter) image_label.setStyleSheet("border: 1px solid #ddd; background-color: #f5f5f5; color: #666;") - # 车牌号标签 + # 车牌号标签 - 使用自适应宽度 number_label = QLabel(plate_number) - number_label.setFixedWidth(150) + number_label.setMinimumWidth(120) # 设置最小宽度 + number_label.setMaximumWidth(200) # 设置最大宽度 number_label.setAlignment(Qt.AlignCenter) number_label.setStyleSheet( "QLabel { " @@ -139,6 +403,11 @@ class LicensePlateWidget(QWidget): "font-weight: bold; " "}" ) + # 根据文本长度调整宽度 + font_metrics = number_label.fontMetrics() + text_width = font_metrics.boundingRect(plate_number).width() + optimal_width = max(120, min(200, text_width + 20)) # 加20像素的边距 + number_label.setFixedWidth(optimal_width) layout.addWidget(type_label) layout.addWidget(image_label) @@ -146,6 +415,9 @@ class LicensePlateWidget(QWidget): layout.addStretch() self.setLayout(layout) + # 调整整体组件的最小高度以适应动态图像尺寸 + min_height = max(60, image_label.height() + 20) # 至少60像素高度 + self.setMinimumHeight(min_height) self.setStyleSheet( "QWidget { " "background-color: white; " @@ -162,15 +434,28 @@ class MainWindow(QMainWindow): super().__init__() self.detector = None self.camera_thread = None + self.video_thread = None self.current_frame = None self.detections = [] + self.current_mode = "camera" # 当前模式:camera, video, image + self.is_processing = False # 标志位,表示是否正在处理识别任务 + self.last_plate_results = [] # 存储上一次的车牌识别结果 + self.current_recognition_method = "CRNN" # 当前识别方法 + + # 添加车牌稳定器 + self.plate_stabilizer = PlateStabilizer( + history_size=15, # 保存15帧历史 + confidence_threshold=0.7, # 70%置信度阈值 + 
stability_frames=5 # 需要5帧稳定 + ) self.init_ui() self.init_detector() self.init_camera() + self.init_video() - # 初始化OCR/CRNN模型(函数名改成一样的了,所以不要修改这里了,想用哪个模块直接导入) - LPRNinitialize_model() + # 初始化默认识别方法(CRNN)的模型 + self.change_recognition_method(self.current_recognition_method) def init_ui(self): @@ -197,7 +482,7 @@ class MainWindow(QMainWindow): self.camera_label.setStyleSheet("QLabel { background-color: black; border: 1px solid #ccc; }") self.camera_label.setAlignment(Qt.AlignCenter) self.camera_label.setText("摄像头未启动") - self.camera_label.setScaledContents(True) + self.camera_label.setScaledContents(False) # 控制按钮 button_layout = QHBoxLayout() @@ -206,14 +491,27 @@ class MainWindow(QMainWindow): self.start_button.clicked.connect(self.start_camera) self.stop_button.clicked.connect(self.stop_camera) self.stop_button.setEnabled(False) - self.btn_image = QPushButton('选择图片') - self.btn_video = QPushButton('选择视频') - self.btn_image.clicked.connect(self.open_image_file) - self.btn_video.clicked.connect(self.open_video_file) - button_layout.addWidget(self.btn_image) - button_layout.addWidget(self.btn_video) + + # 视频控制按钮 + self.open_video_button = QPushButton("打开视频") + self.stop_video_button = QPushButton("停止视频") + self.pause_video_button = QPushButton("暂停视频") + self.open_video_button.clicked.connect(self.open_video_file) + self.stop_video_button.clicked.connect(self.stop_video) + self.pause_video_button.clicked.connect(self.pause_video) + self.stop_video_button.setEnabled(False) + self.pause_video_button.setEnabled(False) + + # 图片控制按钮 + self.open_image_button = QPushButton("打开图片") + self.open_image_button.clicked.connect(self.open_image_file) + button_layout.addWidget(self.start_button) button_layout.addWidget(self.stop_button) + button_layout.addWidget(self.open_video_button) + button_layout.addWidget(self.stop_video_button) + button_layout.addWidget(self.pause_video_button) + button_layout.addWidget(self.open_image_button) button_layout.addStretch() 
left_layout.addWidget(self.camera_label) @@ -222,7 +520,7 @@ class MainWindow(QMainWindow): # 右侧结果显示区域 right_frame = QFrame() right_frame.setFrameStyle(QFrame.StyledPanel) - right_frame.setFixedWidth(400) + right_frame.setFixedWidth(460) right_frame.setStyleSheet("QFrame { background-color: #fafafa; border: 2px solid #ddd; }") right_layout = QVBoxLayout(right_frame) @@ -232,6 +530,20 @@ class MainWindow(QMainWindow): title_label.setFont(QFont("Arial", 16, QFont.Bold)) title_label.setStyleSheet("QLabel { color: #333; padding: 10px; }") + # 识别方法选择 + method_layout = QHBoxLayout() + method_label = QLabel("识别方法:") + method_label.setFont(QFont("Arial", 10)) + + self.method_combo = QComboBox() + self.method_combo.addItems(["CRNN", "LightCRNN", "OCR"]) + self.method_combo.setCurrentText("CRNN") # 默认选择CRNN + self.method_combo.currentTextChanged.connect(self.change_recognition_method) + + method_layout.addWidget(method_label) + method_layout.addWidget(self.method_combo) + method_layout.addStretch() + # 车牌数量显示 self.count_label = QLabel("识别到的车牌数量: 0") self.count_label.setAlignment(Qt.AlignCenter) @@ -258,9 +570,17 @@ class MainWindow(QMainWindow): scroll_area.setWidget(self.results_widget) + # 当前识别任务显示 + self.current_method_label = QLabel("当前识别方法: CRNN") + self.current_method_label.setAlignment(Qt.AlignRight) + self.current_method_label.setFont(QFont("Arial", 9)) + self.current_method_label.setStyleSheet("QLabel { color: #666; padding: 5px; }") + right_layout.addWidget(title_label) + right_layout.addLayout(method_layout) right_layout.addWidget(self.count_label) right_layout.addWidget(scroll_area) + right_layout.addWidget(self.current_method_label) # 添加到主布局 main_layout.addWidget(left_frame, 2) @@ -296,14 +616,50 @@ class MainWindow(QMainWindow): model_path = os.path.join(os.path.dirname(__file__), "yolopart", "yolo11s-pose42.pt") self.detector = LicensePlateYOLO(model_path) + def reset_processing_state(self): + """重置处理状态和清理界面""" + # 重置处理标志 + self.is_processing = False + + # 
def open_video_file(self):
    """Let the user pick a video file and start playing it.

    The file dialog is shown first; if it is cancelled nothing changes.
    Only after a file has been chosen is the running camera/video source
    stopped, the processing state reset and the new video loaded.
    """
    # Ask before touching any state: cancelling the dialog must not stop
    # the current source or wipe the results already on screen.
    video_path, _ = QFileDialog.getOpenFileName(self, "选择视频文件", "", "视频文件 (*.mp4 *.avi *.mov *.mkv)")
    if not video_path:
        return

    # Stop whichever source is currently running.
    if self.current_mode == "camera" and self.camera_thread and self.camera_thread.running:
        self.stop_camera()
    elif self.current_mode == "video" and self.video_thread and self.video_thread.running:
        self.stop_video()

    # Start the new source from a clean slate.
    self.reset_processing_state()

    if not self.video_thread.load_video(video_path):
        self.camera_label.setText("视频加载失败")
        return

    self.current_mode = "video"
    self.start_video()
    self.camera_label.setText(f"正在播放视频: {os.path.basename(video_path)}")
+ else: + print(f"图片加载失败: {image_path}") + self.camera_label.setText("图片加载失败") + except Exception as e: + print(f"图片处理异常: {str(e)}") + self.camera_label.setText(f"图片处理错误: {str(e)}") + + def process_image(self, image): + """处理图片""" + try: + print(f"开始处理图片,图片尺寸: {image.shape}") + self.current_frame = image.copy() + + # 进行车牌检测 + print("正在进行车牌检测...") + self.detections = self.detector.detect_license_plates(image) + print(f"检测到 {len(self.detections)} 个车牌") + + # 在图像上绘制检测结果 + print("正在绘制检测结果...") + display_frame = self.draw_detections(image.copy()) + + # 转换为Qt格式并显示 + print("正在显示图片...") + self.display_frame(display_frame) + + # 更新右侧结果显示 + print("正在更新结果显示...") + self.update_results_display() + print("图片处理完成") + except Exception as e: + print(f"图片处理过程中出错: {str(e)}") + import traceback + traceback.print_exc() def process_frame(self, frame): """处理摄像头帧""" + if frame is None: + return + self.current_frame = frame.copy() - # 进行车牌检测 - self.detections = self.detector.detect_license_plates(frame) + # 先显示原始帧,保证视频流畅播放 + self.display_frame(frame) - # 在图像上绘制检测结果 - display_frame = self.draw_detections(frame.copy()) - - # 转换为Qt格式并显示 - self.display_frame(display_frame) - - # 更新右侧结果显示 - self.update_results_display() + # 如果当前没有在处理识别任务,则开始新的识别任务 + if not self.is_processing: + self.is_processing = True + # 异步进行车牌检测和识别 + QTimer.singleShot(0, self.async_detect_and_update) + + def async_detect_and_update(self): + """异步进行车牌检测和识别""" + if self.current_frame is None: + self.is_processing = False # 重置标志位 + return + + try: + # 进行车牌检测 + self.detections = self.detector.detect_license_plates(self.current_frame) + + # 在图像上绘制检测结果 + display_frame = self.draw_detections(self.current_frame.copy()) + + # 更新显示帧(显示带检测结果的帧) + # 无论是摄像头模式还是视频模式,都显示检测框 + self.display_frame(display_frame) + + # 更新右侧结果显示 + self.update_results_display() + except Exception as e: + print(f"异步检测和更新失败: {str(e)}") + import traceback + traceback.print_exc() + finally: + # 无论成功或失败,都要重置标志位 + self.is_processing = False def draw_detections(self, 
frame): """在图像上绘制检测结果""" - return self.detector.draw_detections(frame, self.detections) + # 获取车牌号列表 + plate_numbers = [] + for detection in self.detections: + # 矫正车牌图像 + corrected_image = self.correct_license_plate(detection) + # 获取车牌号 + if corrected_image is not None: + plate_number = self.recognize_plate_number(corrected_image, detection['class_name']) + plate_numbers.append(plate_number) + else: + plate_numbers.append("识别失败") + + return self.detector.draw_detections(frame, self.detections, plate_numbers) def display_frame(self, frame): """显示帧到界面""" - rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - h, w, ch = rgb_frame.shape - bytes_per_line = ch * w - qt_image = QImage(rgb_frame.data, w, h, bytes_per_line, QImage.Format_RGB888) - - pixmap = QPixmap.fromImage(qt_image) - scaled_pixmap = pixmap.scaled(self.camera_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation) - self.camera_label.setPixmap(scaled_pixmap) + try: + print(f"开始显示帧,帧尺寸: {frame.shape}") + + # 方法1: 标准方法 + try: + rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + h, w, ch = rgb_frame.shape + bytes_per_line = ch * w + qt_image = QImage(rgb_frame.data, w, h, bytes_per_line, QImage.Format_RGB888) + + print(f"方法1: 创建QImage,尺寸: {qt_image.width()}x{qt_image.height()}") + if qt_image.isNull(): + print("方法1: QImage为空,尝试方法2") + raise Exception("QImage为空") + + pixmap = QPixmap.fromImage(qt_image) + if pixmap.isNull(): + print("方法1: QPixmap为空,尝试方法2") + raise Exception("QPixmap为空") + + # 手动缩放图片以适应标签大小,保持宽高比 + scaled_pixmap = pixmap.scaled(self.camera_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation) + self.camera_label.setPixmap(scaled_pixmap) + print("方法1: 帧显示完成") + return + except Exception as e1: + print(f"方法1失败: {str(e1)}") + + # 方法2: 使用imencode和imdecode + try: + print("尝试方法2: 使用imencode和imdecode") + _, buffer = cv2.imencode('.jpg', frame) + rgb_frame = cv2.imdecode(buffer, cv2.IMREAD_COLOR) + rgb_frame = cv2.cvtColor(rgb_frame, cv2.COLOR_BGR2RGB) + h, w, ch = rgb_frame.shape + 
bytes_per_line = ch * w + qt_image = QImage(rgb_frame.data, w, h, bytes_per_line, QImage.Format_RGB888) + + print(f"方法2: 创建QImage,尺寸: {qt_image.width()}x{qt_image.height()}") + if qt_image.isNull(): + print("方法2: QImage为空") + raise Exception("QImage为空") + + pixmap = QPixmap.fromImage(qt_image) + if pixmap.isNull(): + print("方法2: QPixmap为空") + raise Exception("QPixmap为空") + + # 手动缩放图片以适应标签大小,保持宽高比 + scaled_pixmap = pixmap.scaled(self.camera_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation) + self.camera_label.setPixmap(scaled_pixmap) + print("方法2: 帧显示完成") + return + except Exception as e2: + print(f"方法2失败: {str(e2)}") + + # 方法3: 直接使用QImage的构造函数 + try: + print("尝试方法3: 直接使用QImage的构造函数") + height, width, channel = frame.shape + bytes_per_line = 3 * width + q_image = QImage(frame.data, width, height, bytes_per_line, QImage.Format_BGR888) + + print(f"方法3: 创建QImage,尺寸: {q_image.width()}x{q_image.height()}") + if q_image.isNull(): + print("方法3: QImage为空") + raise Exception("QImage为空") + + pixmap = QPixmap.fromImage(q_image) + if pixmap.isNull(): + print("方法3: QPixmap为空") + raise Exception("QPixmap为空") + + # 手动缩放图片以适应标签大小,保持宽高比 + scaled_pixmap = pixmap.scaled(self.camera_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation) + self.camera_label.setPixmap(scaled_pixmap) + print("方法3: 帧显示完成") + return + except Exception as e3: + print(f"方法3失败: {str(e3)}") + + # 所有方法都失败 + print("所有显示方法都失败") + self.camera_label.setText("图片显示失败") + + except Exception as e: + print(f"显示帧过程中出错: {str(e)}") + import traceback + traceback.print_exc() + self.camera_label.setText(f"显示错误: {str(e)}") def update_results_display(self): - """更新右侧结果显示""" - # 更新车牌数量 - count = len(self.detections) - self.count_label.setText(f"识别到的车牌数量: {count}") + """更新右侧结果显示(使用稳定化结果)""" + print(f"开始更新结果显示,当前模式: {self.current_mode}, 检测数量: {len(self.detections) if self.detections else 0}") - # 清除之前的结果 - for i in reversed(range(self.results_layout.count())): - child = self.results_layout.itemAt(i).widget() - if 
child: - child.setParent(None) + if not self.detections: + self.count_label.setText("识别到的车牌数量: 0") + # 清除显示 + for i in reversed(range(self.results_layout.count())): + child = self.results_layout.itemAt(i).widget() + if child: + child.setParent(None) + self.last_plate_results = [] + print("无检测结果,已清空界面") + return - # 添加新的结果 - for i, detection in enumerate(self.detections): - # 矫正车牌图像 + # 获取矫正图像和识别文本 + corrected_images = [] + plate_texts = [] + + for detection in self.detections: corrected_image = self.correct_license_plate(detection) + corrected_images.append(corrected_image) - # 获取车牌号,传入车牌类型信息 - plate_number = self.recognize_plate_number(corrected_image, detection['class_name']) + if corrected_image is not None: + plate_text = self.recognize_plate_number(corrected_image, detection['class_name']) + else: + plate_text = "识别失败" + plate_texts.append(plate_text) + + # 使用稳定器获取稳定的识别结果 + stable_results = self.plate_stabilizer.update_and_get_stable_result( + self.detections, corrected_images, plate_texts + ) + + # 更新车牌数量显示 + self.count_label.setText(f"识别到的车牌数量: {len(stable_results)}") + print(f"稳定结果数量: {len(stable_results)}") + + # 检查结果是否发生变化 + results_changed = self.check_results_changed(stable_results) + print(f"结果是否变化: {results_changed}") + + if results_changed: + # 清除之前的结果 + for i in reversed(range(self.results_layout.count())): + child = self.results_layout.itemAt(i).widget() + if child: + child.setParent(None) - # 创建车牌显示组件 - plate_widget = LicensePlateWidget( - i + 1, - detection['class_name'], - corrected_image, - plate_number - ) + # 添加新的稳定结果 + for i, result in enumerate(stable_results): + plate_widget = LicensePlateWidget( + i + 1, # 显示序号 + result['class_name'], + result['corrected_image'], + result['plate_number'] + ) + self.results_layout.addWidget(plate_widget) + print(f"添加车牌widget: {result['plate_number']}") - self.results_layout.addWidget(plate_widget) + # 更新存储的结果 + self.last_plate_results = stable_results + + # 清理旧的车牌记录 + current_plate_ids = [result['id'] for 
result in stable_results] + self.plate_stabilizer.clear_old_plates(current_plate_ids) + print("结果显示更新完成") + + def check_results_changed(self, new_results): + """检查识别结果是否发生变化""" + if len(self.last_plate_results) != len(new_results): + return True + + for i, new_result in enumerate(new_results): + if i >= len(self.last_plate_results): + return True + + old_result = self.last_plate_results[i] + + # 比较关键字段 + if (old_result.get('class_name') != new_result.get('class_name') or + old_result.get('plate_number') != new_result.get('plate_number')): + return True + + return False def correct_license_plate(self, detection): """矫正车牌图像""" @@ -395,77 +1054,70 @@ class MainWindow(QMainWindow): ) def recognize_plate_number(self, corrected_image, class_name): - """识别车牌号""" - if corrected_image is None: - return "识别失败" - - try: - # 预测函数(来自模块) - # 函数名改成一样的了,所以不要修改这里了,想用哪个模块直接导入 - result = LPRNmodel_predict(corrected_image) - - # 将字符列表转换为字符串,支持8位车牌号 - if isinstance(result, list) and len(result) >= 7: - # 根据车牌类型决定显示位数 - if class_name == '绿牌' and len(result) >= 8: - # 绿牌显示8位,过滤掉空字符占位符 - plate_chars = [char for char in result[:8] if char != ''] - # 如果过滤后确实有8位,显示8位;否则显示7位 - if len(plate_chars) == 8: - return ''.join(plate_chars) - else: - return ''.join(plate_chars[:7]) - else: - # 蓝牌或其他类型显示前7位,过滤掉空字符 - plate_chars = [char for char in result[:7] if char != ''] - return ''.join(plate_chars) - else: - return "识别失败" - except Exception as e: - print(f"车牌号识别失败: {e}") - return "识别失败" + """识别车牌号""" + if corrected_image is None: + return "识别失败" + + try: + # 根据当前选择的识别方法调用相应的函数 + if self.current_recognition_method == "CRNN": + from CRNN_part.crnn_interface import LPRNmodel_predict + elif self.current_recognition_method == "LightCRNN": + from lightCRNN_part.lightcrnn_interface import LPRNmodel_predict + elif self.current_recognition_method == "OCR": + from OCR_part.ocr_interface import LPRNmodel_predict + + # 预测函数(来自模块) + result = LPRNmodel_predict(corrected_image) + + # 将字符列表转换为字符串,支持8位车牌号 + if 
isinstance(result, list) and len(result) >= 7: + # 根据车牌类型决定显示位数 + if class_name == '绿牌' and len(result) >= 8: + # 绿牌显示8位,过滤掉空字符占位符 + plate_chars = [char for char in result[:8] if char != ''] + # 如果过滤后确实有8位,显示8位;否则显示7位 + if len(plate_chars) == 8: + return ''.join(plate_chars) + else: + return ''.join(plate_chars[:7]) + else: + # 蓝牌或其他类型显示前7位,过滤掉空字符 + plate_chars = [char for char in result[:7] if char != ''] + return ''.join(plate_chars) + else: + return "识别失败" + except Exception as e: + print(f"车牌号识别失败: {e}") + return "识别失败" + + def change_recognition_method(self, method): + """切换识别方法""" + self.current_recognition_method = method + self.current_method_label.setText(f"当前识别方法: {method}") + + # 初始化对应的模型 + if method == "CRNN": + from CRNN_part.crnn_interface import LPRNinitialize_model + LPRNinitialize_model() + elif method == "LightCRNN": + from lightCRNN_part.lightcrnn_interface import LPRNinitialize_model + LPRNinitialize_model() + elif method == "OCR": + from OCR_part.ocr_interface import LPRNinitialize_model + LPRNinitialize_model() + + # 如果当前有显示的帧,重新处理以更新识别结果 + if self.current_frame is not None: + self.process_frame(self.current_frame) def closeEvent(self, event): """窗口关闭事件""" - if self.camera_thread: + if self.camera_thread and self.camera_thread.running: self.camera_thread.stop_camera() + if self.video_thread and self.video_thread.running: + self.video_thread.stop_video() event.accept() - - def open_image_file(self): - file_path, _ = QFileDialog.getOpenFileName(self, '选择图片', '', '图片文件 (*.jpg *.png)') - if file_path: - image = cv2.imread(file_path) - self.process_image(image) - - def open_video_file(self): - file_path, _ = QFileDialog.getOpenFileName(self, '选择视频', '', '视频文件 (*.mp4 *.avi)') - if file_path: - self.cap = cv2.VideoCapture(file_path) - self.video_timer = QTimer() - self.video_timer.timeout.connect(self.process_video_frame) - self.video_timer.start(30) - - def process_image(self, image): - self.detections = self.detector.detect_license_plates(image) - 
display_image = self.draw_detections(image.copy()) - self.display_static_image(display_image) - self.update_results_display() - - def process_video_frame(self): - ret, frame = self.cap.read() - if ret: - self.process_image(frame) - else: - self.video_timer.stop() - self.cap.release() - - def display_static_image(self, image): - rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - h, w, ch = rgb_image.shape - bytes_per_line = ch * w - qt_image = QImage(rgb_image.data, w, h, bytes_per_line, QImage.Format_RGB888) - pixmap = QPixmap.fromImage(qt_image) - self.camera_label.setPixmap(pixmap.scaled(self.camera_label.size(), Qt.KeepAspectRatio)) def main(): app = QApplication(sys.argv) diff --git a/yolopart/detector.py b/yolopart/detector.py index b95fd80..435c4fb 100644 --- a/yolopart/detector.py +++ b/yolopart/detector.py @@ -2,6 +2,7 @@ import cv2 import numpy as np from ultralytics import YOLO import os +from PIL import Image, ImageDraw, ImageFont class LicensePlateYOLO: """ @@ -113,19 +114,38 @@ class LicensePlateYOLO: print(f"检测过程中出错: {e}") return [] - def draw_detections(self, image, detections): + def draw_detections(self, image, detections, plate_numbers=None): """ 在图像上绘制检测结果 参数: image: 输入图像 detections: 检测结果列表 + plate_numbers: 车牌号列表,与detections对应 返回: numpy.ndarray: 绘制了检测结果的图像 """ draw_image = image.copy() + # 转换为PIL图像以支持中文字符 + pil_image = Image.fromarray(cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB)) + draw = ImageDraw.Draw(pil_image) + + # 尝试加载中文字体 + try: + # Windows系统常见的中文字体 + font_path = "C:/Windows/Fonts/simhei.ttf" # 黑体 + if not os.path.exists(font_path): + font_path = "C:/Windows/Fonts/msyh.ttc" # 微软雅黑 + if not os.path.exists(font_path): + font_path = "C:/Windows/Fonts/simsun.ttc" # 宋体 + + font = ImageFont.truetype(font_path, 20) + except: + # 如果无法加载字体,使用默认字体 + font = ImageFont.load_default() + for i, detection in enumerate(detections): box = detection['box'] keypoints = detection['keypoints'] @@ -133,6 +153,11 @@ class LicensePlateYOLO: confidence = 
detection['confidence'] incomplete = detection.get('incomplete', False) + # 获取对应的车牌号 + plate_number = "" + if plate_numbers and i < len(plate_numbers): + plate_number = plate_numbers[i] + # 绘制边界框 x1, y1, x2, y2 = map(int, box) @@ -140,30 +165,53 @@ class LicensePlateYOLO: if class_name == '绿牌': box_color = (0, 255, 0) # 绿色 elif class_name == '蓝牌': - box_color = (255, 0, 0) # 蓝色 + box_color = (0, 0, 255) # 蓝色 else: box_color = (128, 128, 128) # 灰色 - cv2.rectangle(draw_image, (x1, y1), (x2, y2), box_color, 2) + # 在PIL图像上绘制边界框 + draw.rectangle([(x1, y1), (x2, y2)], outline=box_color, width=2) + + # 构建标签文本 + if plate_number: + label = f"{class_name} {plate_number} {confidence:.2f}" + else: + label = f"{class_name} {confidence:.2f}" - # 绘制标签 - label = f"{class_name} {confidence:.2f}" if incomplete: label += " (不完整)" - # 计算文本大小和位置 - font = cv2.FONT_HERSHEY_SIMPLEX - font_scale = 0.6 - thickness = 2 - (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness) + # 计算文本大小 + bbox = draw.textbbox((0, 0), label, font=font) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] # 绘制文本背景 - cv2.rectangle(draw_image, (x1, y1 - text_height - 10), - (x1 + text_width, y1), box_color, -1) + draw.rectangle([(x1, y1 - text_height - 10), (x1 + text_width, y1)], + fill=box_color) # 绘制文本 - cv2.putText(draw_image, label, (x1, y1 - 5), - font, font_scale, (255, 255, 255), thickness) + draw.text((x1, y1 - text_height - 5), label, fill=(255, 255, 255), font=font) + + # 转换回OpenCV格式 + draw_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR) + + # 绘制关键点和连线(使用OpenCV) + for i, detection in enumerate(detections): + box = detection['box'] + keypoints = detection['keypoints'] + incomplete = detection.get('incomplete', False) + + x1, y1, x2, y2 = map(int, box) + + # 根据车牌类型选择颜色 + class_name = detection['class_name'] + if class_name == '绿牌': + box_color = (0, 255, 0) # 绿色 + elif class_name == '蓝牌': + box_color = (0, 0, 255) # 蓝色 + else: + box_color = (128, 
128, 128) # 灰色 # 绘制关键点和连线 if len(keypoints) >= 4 and not incomplete: