更新 CRNN_part/crnn_interface.py
This commit is contained in:
parent
85c8302fc1
commit
01b286fce1
@ -1,4 +1,211 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import cv2
|
||||
from torchvision import transforms
|
||||
import os
|
||||
|
||||
# Module-level inference state. All four names start out as None; they are
# assigned by initialize_crnn_model() and read by crnn_predict().
crnn_model: "CRNN | None" = None  # recognition network
crnn_decoder: "CTCDecoder | None" = None  # maps network output to text
crnn_preprocessor: "LicensePlatePreprocessor | None" = None  # image -> tensor
device: "str | None" = None  # 'cuda' or 'cpu'
|
||||
|
||||
class CRNN(nn.Module):
    """CRNN licence-plate recognition network.

    Seven convolution stages extract features, a two-layer bidirectional
    LSTM models the resulting column sequence, and a linear layer scores
    every class at every time step.

    The attribute names ``cnn``, ``rnn`` and ``fc`` and the order of the
    modules inside ``cnn`` define the ``state_dict`` keys, so they have to
    stay as they are for saved weights to keep loading.

    Args:
        img_height: Input image height; only stored on the instance.
        num_classes: Number of output classes per time step.
        hidden_size: Hidden units per LSTM direction.
    """

    def __init__(self, img_height=32, num_classes=68, hidden_size=256):
        super().__init__()
        self.img_height = img_height
        self.num_classes = num_classes
        self.hidden_size = hidden_size

        # One row per convolution stage:
        # (out_channels, kernel_size, padding, max-pool window or None).
        # The (2, 1) windows halve the height only, keeping the width
        # (the sequence axis) untouched.
        stage_plan = (
            (64, 3, 1, 2),
            (128, 3, 1, 2),
            (256, 3, 1, None),
            (256, 3, 1, (2, 1)),
            (512, 3, 1, None),
            (512, 3, 1, (2, 1)),
            (512, 2, 0, None),
        )

        stages = []
        in_channels = 3
        for out_channels, kernel, pad, pool in stage_plan:
            stages.append(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=kernel, stride=1, padding=pad)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
            if pool is not None:
                stages.append(nn.MaxPool2d(kernel_size=pool, stride=pool))
            in_channels = out_channels
        self.cnn = nn.Sequential(*stages)

        # Sequence model: two stacked bidirectional LSTM layers.
        self.rnn = nn.LSTM(
            input_size=512,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )

        # Both LSTM directions are concatenated, hence hidden_size * 2.
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Score every class at every horizontal position of the input.

        Args:
            x: Image batch with 3 channels, laid out as
                (batch, 3, height, width).

        Returns:
            Tensor of shape (width, batch_size, num_classes), the
            time-major layout CTC expects.
        """
        features = self.cnn(x)

        # Treat each feature-map column as one time step:
        # (batch, channels, height, width) -> (batch, width, channels * height).
        n, channels, height, width = features.size()
        sequence = (
            features.permute(0, 3, 1, 2)
            .contiguous()
            .view(n, width, channels * height)
        )

        recurrent, _ = self.rnn(sequence)
        scores = self.fc(recurrent)

        # (batch, width, classes) -> (width, batch, classes).
        return scores.permute(1, 0, 2)
|
||||
|
||||
class CTCDecoder:
    """Greedy CTC decoder for the Chinese licence-plate alphabet.

    The alphabet has 68 entries: the CTC blank at index 0, followed by 31
    province abbreviations, the letters A-Z and the digits 0-9. The order
    of ``chars`` defines the class indices and must not change.
    """

    def __init__(self):
        provinces = (
            "京沪津渝冀晋蒙辽吉黑"
            "苏浙皖闽赣鲁豫鄂湘粤"
            "桂琼川贵云藏陕甘青宁新"
        )
        letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        digits = "0123456789"

        # Index 0 is reserved for the blank symbol CTC requires.
        self.chars = ['<BLANK>', *provinces, *letters, *digits]

        self.char_to_idx = {char: idx for idx, char in enumerate(self.chars)}
        self.idx_to_char = dict(enumerate(self.chars))
        self.blank_idx = 0

    def _collapse(self, indices):
        """Yield ``(time_step, class_index)`` for symbols kept by CTC collapsing.

        A symbol is kept when it differs from the previous time step, is
        not the blank, and is a valid index into ``self.chars``.

        Args:
            indices: Per-time-step class indices as plain Python ints.
        """
        alphabet_size = len(self.chars)
        previous = -1
        for step, index in enumerate(indices):
            if (index != previous and index != self.blank_idx
                    and index < alphabet_size):
                yield step, index
            previous = index

    def decode_greedy(self, predictions):
        """Decode by taking the best class at every time step.

        Args:
            predictions: 2-D tensor of per-class scores, one row per time
                step (the argmax is taken over dim 1).

        Returns:
            str: The decoded text.
        """
        # One .tolist() moves all indices to the host at once instead of
        # one transfer per time step.
        indices = torch.argmax(predictions, dim=1).tolist()
        return ''.join(self.chars[index] for _, index in self._collapse(indices))

    def decode_with_confidence(self, predictions):
        """Decode greedily and report softmax confidences.

        Args:
            predictions: 2-D tensor of per-class scores, one row per time
                step (softmax is applied over dim 1).

        Returns:
            tuple: ``(text, avg_confidence, char_confidences)`` where
            ``char_confidences`` holds, for every decoded character, the
            winning probability at the time step where it first appears,
            and ``avg_confidence`` is their mean as a ``float`` (``0.0``
            when nothing was decoded).
        """
        probs = torch.softmax(predictions, dim=1)
        best_probs, best_indices = torch.max(probs, dim=1)
        best_probs = best_probs.tolist()

        decoded_chars = []
        char_confidences = []
        for step, index in self._collapse(best_indices.tolist()):
            decoded_chars.append(self.chars[index])
            char_confidences.append(best_probs[step])

        text = ''.join(decoded_chars)
        avg_confidence = float(np.mean(char_confidences)) if char_confidences else 0.0

        return text, avg_confidence, char_confidences
|
||||
|
||||
class LicensePlatePreprocessor:
    """Convert a licence-plate image array into a normalised model input.

    Args:
        target_height: Height the image is resized to.
        target_width: Width the image is resized to.
    """

    def __init__(self, target_height=32, target_width=128):
        self.target_height = target_height
        self.target_width = target_width

        # Resize -> tensor -> per-channel mean/std normalisation.
        self.transform = transforms.Compose([
            transforms.Resize((target_height, target_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def preprocess_numpy_array(self, image_array):
        """Preprocess an image given as a NumPy array.

        Colour inputs are assumed to use OpenCV's BGR(A) channel order and
        are converted to RGB; grayscale inputs are expanded to three
        channels. Non-``uint8`` arrays are assumed to hold values in
        [0, 1] and are scaled to [0, 255].

        Args:
            image_array: Array of shape (H, W), (H, W, 1), (H, W, 3) or
                (H, W, 4).

        Returns:
            The transformed tensor with a leading batch dimension of 1, or
            ``None`` if preprocessing failed (the error is printed).
        """
        # Deliberately best-effort: callers test for None instead of
        # handling exceptions.
        try:
            # Normalise the dtype first so the colour conversion below is
            # applied to every input, not only to uint8 arrays. Clipping
            # keeps out-of-range values from wrapping around in the cast.
            if image_array.dtype != np.uint8:
                image_array = np.clip(image_array * 255, 0, 255).astype(np.uint8)

            if image_array.ndim == 3 and image_array.shape[2] == 1:
                image_array = image_array[:, :, 0]

            # The normalisation step needs exactly three channels.
            if image_array.ndim == 2:
                image_array = cv2.cvtColor(
                    np.ascontiguousarray(image_array), cv2.COLOR_GRAY2RGB
                )
            elif image_array.ndim == 3 and image_array.shape[2] == 3:
                image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)
            elif image_array.ndim == 3 and image_array.shape[2] == 4:
                image_array = cv2.cvtColor(image_array, cv2.COLOR_BGRA2RGB)
            else:
                raise ValueError(f"不支持的图像形状: {image_array.shape}")

            image = Image.fromarray(image_array)
            tensor = self.transform(image)

            # Add the batch dimension.
            return tensor.unsqueeze(0)

        except Exception as e:
            print(f"图像预处理失败: {e}")
            return None
|
||||
|
||||
def initialize_crnn_model():
    """Build the CRNN pipeline and load its weights from ``best_model.pth``.

    The weights file is looked up in the directory of this module. The
    model, decoder, preprocessor and device are published to the
    module-level globals only after the weights have loaded successfully,
    so a failed call never leaves an untrained model behind for
    ``crnn_predict`` to use.

    Returns:
        bool: ``True`` if initialisation succeeded, ``False`` otherwise
        (the error and its traceback are printed).
    """
    global crnn_model, crnn_decoder, crnn_preprocessor, device

    try:
        target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"CRNN使用设备: {target_device}")

        decoder = CTCDecoder()
        preprocessor = LicensePlatePreprocessor(target_height=32, target_width=128)
        model = CRNN(num_classes=len(decoder.chars), hidden_size=256)

        model_path = os.path.join(os.path.dirname(__file__), 'best_model.pth')
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        print(f"正在加载CRNN模型: {model_path}")

        # SECURITY: weights_only=False unpickles arbitrary objects and can
        # execute code embedded in the file. Only load checkpoints from a
        # trusted source; prefer weights_only=True if the checkpoint
        # format allows it.
        checkpoint = torch.load(model_path, map_location=target_device, weights_only=False)

        if isinstance(checkpoint, nn.Module):
            # The whole model object was saved.
            state_dict = checkpoint.state_dict()
        elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            # Full training checkpoint.
            state_dict = checkpoint['model_state_dict']
            print("检查点信息:")
            print(f"  - 训练轮次: {checkpoint.get('epoch', 'N/A')}")
            print(f"  - 最佳验证损失: {checkpoint.get('best_val_loss', 'N/A')}")
        elif isinstance(checkpoint, dict):
            # Slim format: the file is the state dict itself.
            print("加载精简模型(仅权重)")
            state_dict = checkpoint
        else:
            raise TypeError(f"不支持的检查点类型: {type(checkpoint).__name__}")

        model.load_state_dict(state_dict)
        model.to(target_device)
        model.eval()

        total_params = sum(p.numel() for p in model.parameters())

    except Exception as e:
        print(f"CRNN模型初始化失败: {e}")
        import traceback
        traceback.print_exc()
        return False

    # Publish only a fully loaded pipeline.
    crnn_model = model
    crnn_decoder = decoder
    crnn_preprocessor = preprocessor
    device = target_device

    print("CRNN模型初始化完成")
    print(f"CRNN模型参数数量: {total_params:,}")

    return True
|
||||
|
||||
def crnn_predict(image_array):
    """Recognise the licence-plate text in an image.

    Args:
        image_array: Plate image as a NumPy array; it is passed unchanged
            to the module's preprocessor.

    Returns:
        list: Exactly seven single-character strings, e.g.
        ``['京', 'A', '1', '2', '3', '4', '5']``. Shorter results are padded
        with ``'0'`` and longer ones are cut to the first seven characters.
        If the model has not been initialised the result is
        ``['待', '识', '别', '0', '0', '0', '0']``; if recognition fails it
        is ``['识', '别', '失', '败', '0', '0', '0']``.
    """
    plate_length = 7

    if crnn_model is None or crnn_decoder is None or crnn_preprocessor is None:
        print("CRNN模型未初始化,请先调用initialize_crnn_model()")
        return ['待', '识', '别', '0', '0', '0', '0']

    try:
        input_tensor = crnn_preprocessor.preprocess_numpy_array(image_array)
        if input_tensor is None:
            raise ValueError("图像预处理失败")

        input_tensor = input_tensor.to(device)

        with torch.no_grad():
            outputs = crnn_model(input_tensor)  # (seq_len, batch_size, num_classes)

        # Drop the batch dimension: (seq_len, num_classes).
        outputs = outputs.squeeze(1)

        predicted_text, confidence, _ = crnn_decoder.decode_with_confidence(outputs)

        print(f"CRNN识别结果: {predicted_text}, 置信度: {confidence:.3f}")

        # Force the fixed plate length: truncate, then pad with '0'.
        char_list = list(predicted_text)[:plate_length]
        char_list.extend(['0'] * (plate_length - len(char_list)))

        return char_list

    except Exception as e:
        print(f"CRNN识别失败: {e}")
        import traceback
        traceback.print_exc()
        return ['识', '别', '失', '败', '0', '0', '0']
|
||||
|
Loading…
x
Reference in New Issue
Block a user