import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from PIL import Image import cv2 from torchvision import transforms import os # 全局变量 crnn_model = None crnn_decoder = None crnn_preprocessor = None device = None class CRNN(nn.Module): """CRNN车牌识别模型""" def __init__(self, img_height=32, num_classes=68, hidden_size=256): super(CRNN, self).__init__() self.img_height = img_height self.num_classes = num_classes self.hidden_size = hidden_size # CNN特征提取部分 - 7层卷积 self.cnn = nn.Sequential( # 第1层:3->64, 3x3卷积 nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), # 第2层:64->128, 3x3卷积 nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), # 第3层:128->256, 3x3卷积 nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), # 第4层:256->256, 3x3卷积 nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)), # 第5层:256->512, 3x3卷积 nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), # 第6层:512->512, 3x3卷积 nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)), # 第7层:512->512, 2x2卷积 nn.Conv2d(512, 512, kernel_size=2, stride=1, padding=0), nn.BatchNorm2d(512), nn.ReLU(inplace=True), ) # RNN序列建模部分 - 2层双向LSTM self.rnn = nn.LSTM( input_size=512, hidden_size=hidden_size, num_layers=2, batch_first=True, bidirectional=True ) # 全连接分类层 self.fc = nn.Linear(hidden_size * 2, num_classes) def forward(self, x): batch_size = x.size(0) # CNN特征提取 conv_out = self.cnn(x) # 重塑为RNN输入格式 batch_size, channels, height, width = conv_out.size() conv_out = conv_out.permute(0, 3, 1, 2) conv_out = conv_out.contiguous().view(batch_size, width, channels * height) # RNN序列建模 rnn_out, _ = self.rnn(conv_out) # 全连接分类 output = self.fc(rnn_out) # 转换为CTC需要的格式:(width, batch_size, num_classes) output = output.permute(1, 0, 2) return output class CTCDecoder: """CTC解码器""" def __init__(self): # 定义中国车牌字符集(68个字符) self.chars = [ # 空白字符(CTC需要) '', # 中文省份简称 '京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑', '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤', '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁', '新', # 字母 A-Z 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', # 数字 0-9 '0', '1', '2', '3', '4', '5', '6', '7', '8', '9' ] self.char_to_idx = {char: idx for idx, char in enumerate(self.chars)} self.idx_to_char = {idx: char for idx, char in enumerate(self.chars)} self.blank_idx = 0 def decode_greedy(self, predictions): """贪婪解码""" # 获取每个时间步的最大概率索引 indices = torch.argmax(predictions, dim=1) # CTC解码:移除重复字符和空白字符 decoded_chars = [] prev_idx = -1 for idx in indices: idx = idx.item() if idx != prev_idx and idx != self.blank_idx: if idx < len(self.chars): decoded_chars.append(self.chars[idx]) prev_idx = idx return ''.join(decoded_chars) def decode_with_confidence(self, predictions): """解码并返回置信度信息""" # 应用softmax获得概率 probs = torch.softmax(predictions, dim=1) # 贪婪解码 indices = torch.argmax(probs, dim=1) max_probs = torch.max(probs, dim=1)[0] # CTC解码 decoded_chars = [] char_confidences = [] prev_idx = -1 for i, idx in enumerate(indices): idx = idx.item() confidence = max_probs[i].item() if idx != prev_idx and idx != self.blank_idx: if idx < len(self.chars): decoded_chars.append(self.chars[idx]) char_confidences.append(confidence) prev_idx = idx text = ''.join(decoded_chars) avg_confidence = np.mean(char_confidences) if char_confidences else 0.0 return text, avg_confidence, char_confidences class LicensePlatePreprocessor: """车牌图像预处理器""" def __init__(self, target_height=32, target_width=128): self.target_height = target_height self.target_width = target_width # 定义图像变换 self.transform = transforms.Compose([ transforms.Resize((target_height, target_width)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def preprocess_numpy_array(self, image_array): """预处理numpy数组格式的图像""" try: # 确保图像是RGB格式 if len(image_array.shape) == 3 and image_array.shape[2] == 3: # 如果是BGR格式,转换为RGB if image_array.dtype == np.uint8: image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB) # 转换为PIL图像 if image_array.dtype != np.uint8: image_array = (image_array * 255).astype(np.uint8) image = Image.fromarray(image_array) # 应用变换 tensor = self.transform(image) # 添加batch维度 tensor = tensor.unsqueeze(0) return tensor except Exception as e: print(f"图像预处理失败: {e}") return None def initialize_crnn_model(): """ 初始化CRNN模型 返回: bool: 初始化是否成功 """ global crnn_model, crnn_decoder, crnn_preprocessor, device try: # 设置设备 device = 'cuda' if torch.cuda.is_available() else 'cpu' print(f"CRNN使用设备: {device}") # 初始化组件 crnn_decoder = CTCDecoder() crnn_preprocessor = LicensePlatePreprocessor(target_height=32, target_width=128) # 创建模型实例 crnn_model = CRNN(num_classes=len(crnn_decoder.chars), hidden_size=256) # 加载模型权重 model_path = os.path.join(os.path.dirname(__file__), 'best_model.pth') if not os.path.exists(model_path): raise FileNotFoundError(f"模型文件不存在: {model_path}") print(f"正在加载CRNN模型: {model_path}") # 加载检查点 checkpoint = torch.load(model_path, map_location=device, weights_only=False) # 处理不同的模型保存格式 if isinstance(checkpoint, dict): if 'model_state_dict' in checkpoint: # 完整检查点格式 state_dict = checkpoint['model_state_dict'] print(f"检查点信息:") print(f" - 训练轮次: {checkpoint.get('epoch', 'N/A')}") print(f" - 最佳验证损失: {checkpoint.get('best_val_loss', 'N/A')}") else: # 精简模型格式(只包含权重) print("加载精简模型(仅权重)") state_dict = checkpoint else: # 直接是状态字典 state_dict = checkpoint # 加载权重 crnn_model.load_state_dict(state_dict) crnn_model.to(device) crnn_model.eval() print("CRNN模型初始化完成") # 统计模型参数 total_params = sum(p.numel() for p in crnn_model.parameters()) print(f"CRNN模型参数数量: {total_params:,}") return True except Exception as e: print(f"CRNN模型初始化失败: {e}") import traceback traceback.print_exc() return False def crnn_predict(image_array): """ CRNN车牌号识别接口函数 参数: image_array: numpy数组格式的车牌图像,已经过矫正处理 返回: list: 包含7个字符的列表,代表车牌号的每个字符 例如: ['京', 'A', '1', '2', '3', '4', '5'] """ global crnn_model, crnn_decoder, crnn_preprocessor, device if crnn_model is None or crnn_decoder is None or crnn_preprocessor is None: print("CRNN模型未初始化,请先调用initialize_crnn_model()") return ['待', '识', '别', '0', '0', '0', '0'] try: # 预处理图像 input_tensor = crnn_preprocessor.preprocess_numpy_array(image_array) if input_tensor is None: raise ValueError("图像预处理失败") input_tensor = input_tensor.to(device) # 模型推理 with torch.no_grad(): outputs = crnn_model(input_tensor) # (seq_len, batch_size, num_classes) # 移除batch维度 outputs = outputs.squeeze(1) # (seq_len, num_classes) # CTC解码 predicted_text, confidence, char_confidences = crnn_decoder.decode_with_confidence(outputs) print(f"CRNN识别结果: {predicted_text}, 置信度: {confidence:.3f}") # 将字符串转换为字符列表 char_list = list(predicted_text) # 确保返回7个字符(车牌标准长度) if len(char_list) < 7: # 如果识别结果少于7个字符,用'0'补齐 char_list.extend(['0'] * (7 - len(char_list))) elif len(char_list) > 7: # 如果识别结果多于7个字符,截取前7个 char_list = char_list[:7] return char_list except Exception as e: print(f"CRNN识别失败: {e}") import traceback traceback.print_exc() return ['识', '别', '失', '败', '0', '0', '0']