332 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
from torchvision import transforms
import os
# 全局变量
crnn_model = None
crnn_decoder = None
crnn_preprocessor = None
device = None
class CRNN(nn.Module):
"""CRNN车牌识别模型"""
def __init__(self, img_height=32, num_classes=68, hidden_size=256):
super(CRNN, self).__init__()
self.img_height = img_height
self.num_classes = num_classes
self.hidden_size = hidden_size
# CNN特征提取部分 - 7层卷积
self.cnn = nn.Sequential(
# 第1层3->64, 3x3卷积
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
# 第2层64->128, 3x3卷积
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
# 第3层128->256, 3x3卷积
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
# 第4层256->256, 3x3卷积
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
# 第5层256->512, 3x3卷积
nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
# 第6层512->512, 3x3卷积
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
# 第7层512->512, 2x2卷积
nn.Conv2d(512, 512, kernel_size=2, stride=1, padding=0),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
)
# RNN序列建模部分 - 2层双向LSTM
self.rnn = nn.LSTM(
input_size=512,
hidden_size=hidden_size,
num_layers=2,
batch_first=True,
bidirectional=True
)
# 全连接分类层
self.fc = nn.Linear(hidden_size * 2, num_classes)
def forward(self, x):
batch_size = x.size(0)
# CNN特征提取
conv_out = self.cnn(x)
# 重塑为RNN输入格式
batch_size, channels, height, width = conv_out.size()
conv_out = conv_out.permute(0, 3, 1, 2)
conv_out = conv_out.contiguous().view(batch_size, width, channels * height)
# RNN序列建模
rnn_out, _ = self.rnn(conv_out)
# 全连接分类
output = self.fc(rnn_out)
# 转换为CTC需要的格式(width, batch_size, num_classes)
output = output.permute(1, 0, 2)
return output
class CTCDecoder:
"""CTC解码器"""
def __init__(self):
# 定义中国车牌字符集68个字符
self.chars = [
# 空白字符CTC需要
'<BLANK>',
# 中文省份简称
'', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '',
# 字母 A-Z
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
# 数字 0-9
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'
]
self.char_to_idx = {char: idx for idx, char in enumerate(self.chars)}
self.idx_to_char = {idx: char for idx, char in enumerate(self.chars)}
self.blank_idx = 0
def decode_greedy(self, predictions):
"""贪婪解码"""
# 获取每个时间步的最大概率索引
indices = torch.argmax(predictions, dim=1)
# CTC解码移除重复字符和空白字符
decoded_chars = []
prev_idx = -1
for idx in indices:
idx = idx.item()
if idx != prev_idx and idx != self.blank_idx:
if idx < len(self.chars):
decoded_chars.append(self.chars[idx])
prev_idx = idx
return ''.join(decoded_chars)
def decode_with_confidence(self, predictions):
"""解码并返回置信度信息"""
# 应用softmax获得概率
probs = torch.softmax(predictions, dim=1)
# 贪婪解码
indices = torch.argmax(probs, dim=1)
max_probs = torch.max(probs, dim=1)[0]
# CTC解码
decoded_chars = []
char_confidences = []
prev_idx = -1
for i, idx in enumerate(indices):
idx = idx.item()
confidence = max_probs[i].item()
if idx != prev_idx and idx != self.blank_idx:
if idx < len(self.chars):
decoded_chars.append(self.chars[idx])
char_confidences.append(confidence)
prev_idx = idx
text = ''.join(decoded_chars)
avg_confidence = np.mean(char_confidences) if char_confidences else 0.0
return text, avg_confidence, char_confidences
class LicensePlatePreprocessor:
"""车牌图像预处理器"""
def __init__(self, target_height=32, target_width=128):
self.target_height = target_height
self.target_width = target_width
# 定义图像变换
self.transform = transforms.Compose([
transforms.Resize((target_height, target_width)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def preprocess_numpy_array(self, image_array):
"""预处理numpy数组格式的图像"""
try:
# 确保图像是RGB格式
if len(image_array.shape) == 3 and image_array.shape[2] == 3:
# 如果是BGR格式转换为RGB
if image_array.dtype == np.uint8:
image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)
# 转换为PIL图像
if image_array.dtype != np.uint8:
image_array = (image_array * 255).astype(np.uint8)
image = Image.fromarray(image_array)
# 应用变换
tensor = self.transform(image)
# 添加batch维度
tensor = tensor.unsqueeze(0)
return tensor
except Exception as e:
print(f"图像预处理失败: {e}")
return None
def initialize_crnn_model():
"""
初始化CRNN模型
返回:
bool: 初始化是否成功
"""
global crnn_model, crnn_decoder, crnn_preprocessor, device
try:
# 设置设备
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"CRNN使用设备: {device}")
# 初始化组件
crnn_decoder = CTCDecoder()
crnn_preprocessor = LicensePlatePreprocessor(target_height=32, target_width=128)
# 创建模型实例
crnn_model = CRNN(num_classes=len(crnn_decoder.chars), hidden_size=256)
# 加载模型权重
model_path = os.path.join(os.path.dirname(__file__), 'best_model.pth')
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")
print(f"正在加载CRNN模型: {model_path}")
# 加载检查点
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
# 处理不同的模型保存格式
if isinstance(checkpoint, dict):
if 'model_state_dict' in checkpoint:
# 完整检查点格式
state_dict = checkpoint['model_state_dict']
print(f"检查点信息:")
print(f" - 训练轮次: {checkpoint.get('epoch', 'N/A')}")
print(f" - 最佳验证损失: {checkpoint.get('best_val_loss', 'N/A')}")
else:
# 精简模型格式(只包含权重)
print("加载精简模型(仅权重)")
state_dict = checkpoint
else:
# 直接是状态字典
state_dict = checkpoint
# 加载权重
crnn_model.load_state_dict(state_dict)
crnn_model.to(device)
crnn_model.eval()
print("CRNN模型初始化完成")
# 统计模型参数
total_params = sum(p.numel() for p in crnn_model.parameters())
print(f"CRNN模型参数数量: {total_params:,}")
return True
except Exception as e:
print(f"CRNN模型初始化失败: {e}")
import traceback
traceback.print_exc()
return False
def crnn_predict(image_array):
"""
CRNN车牌号识别接口函数
参数:
image_array: numpy数组格式的车牌图像已经过矫正处理
返回:
list: 包含7个字符的列表代表车牌号的每个字符
例如: ['', 'A', '1', '2', '3', '4', '5']
"""
global crnn_model, crnn_decoder, crnn_preprocessor, device
if crnn_model is None or crnn_decoder is None or crnn_preprocessor is None:
print("CRNN模型未初始化请先调用initialize_crnn_model()")
return ['', '', '', '0', '0', '0', '0']
try:
# 预处理图像
input_tensor = crnn_preprocessor.preprocess_numpy_array(image_array)
if input_tensor is None:
raise ValueError("图像预处理失败")
input_tensor = input_tensor.to(device)
# 模型推理
with torch.no_grad():
outputs = crnn_model(input_tensor) # (seq_len, batch_size, num_classes)
# 移除batch维度
outputs = outputs.squeeze(1) # (seq_len, num_classes)
# CTC解码
predicted_text, confidence, char_confidences = crnn_decoder.decode_with_confidence(outputs)
print(f"CRNN识别结果: {predicted_text}, 置信度: {confidence:.3f}")
# 将字符串转换为字符列表
char_list = list(predicted_text)
# 确保返回7个字符车牌标准长度
if len(char_list) < 7:
# 如果识别结果少于7个字符用'0'补齐
char_list.extend(['0'] * (7 - len(char_list)))
elif len(char_list) > 7:
# 如果识别结果多于7个字符截取前7个
char_list = char_list[:7]
return char_list
except Exception as e:
print(f"CRNN识别失败: {e}")
import traceback
traceback.print_exc()
return ['', '', '', '', '0', '0', '0']