Compare commits

..

No commits in common. "8eef0d9414ef1cfeaecaa87c5ed3f17c9979b85f" and "3d7c7a06e455e7ebdbc70c207f2d772ad4d2608d" have entirely different histories.

8 changed files with 48 additions and 348 deletions

View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="D:\conda_envs\RLP" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="pytorh" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">

2
.idea/misc.xml generated
View File

@ -3,5 +3,5 @@
<component name="Black">
<option name="sdkName" value="pytorh" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="D:\conda_envs\RLP" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="pytorh" project-jdk-type="Python SDK" />
</project>

Binary file not shown.

View File

@ -1,211 +1,4 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
from torchvision import transforms
import os
# 全局变量
crnn_model = None
crnn_decoder = None
crnn_preprocessor = None
device = None
class CRNN(nn.Module):
"""CRNN车牌识别模型"""
def __init__(self, img_height=32, num_classes=68, hidden_size=256):
super(CRNN, self).__init__()
self.img_height = img_height
self.num_classes = num_classes
self.hidden_size = hidden_size
# CNN特征提取部分 - 7层卷积
self.cnn = nn.Sequential(
# 第1层3->64, 3x3卷积
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
# 第2层64->128, 3x3卷积
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
# 第3层128->256, 3x3卷积
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
# 第4层256->256, 3x3卷积
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
# 第5层256->512, 3x3卷积
nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
# 第6层512->512, 3x3卷积
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
# 第7层512->512, 2x2卷积
nn.Conv2d(512, 512, kernel_size=2, stride=1, padding=0),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
)
# RNN序列建模部分 - 2层双向LSTM
self.rnn = nn.LSTM(
input_size=512,
hidden_size=hidden_size,
num_layers=2,
batch_first=True,
bidirectional=True
)
# 全连接分类层
self.fc = nn.Linear(hidden_size * 2, num_classes)
def forward(self, x):
batch_size = x.size(0)
# CNN特征提取
conv_out = self.cnn(x)
# 重塑为RNN输入格式
batch_size, channels, height, width = conv_out.size()
conv_out = conv_out.permute(0, 3, 1, 2)
conv_out = conv_out.contiguous().view(batch_size, width, channels * height)
# RNN序列建模
rnn_out, _ = self.rnn(conv_out)
# 全连接分类
output = self.fc(rnn_out)
# 转换为CTC需要的格式(width, batch_size, num_classes)
output = output.permute(1, 0, 2)
return output
class CTCDecoder:
"""CTC解码器"""
def __init__(self):
# 定义中国车牌字符集68个字符
self.chars = [
# 空白字符CTC需要
'<BLANK>',
# 中文省份简称
'', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '',
# 字母 A-Z
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
# 数字 0-9
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'
]
self.char_to_idx = {char: idx for idx, char in enumerate(self.chars)}
self.idx_to_char = {idx: char for idx, char in enumerate(self.chars)}
self.blank_idx = 0
def decode_greedy(self, predictions):
"""贪婪解码"""
# 获取每个时间步的最大概率索引
indices = torch.argmax(predictions, dim=1)
# CTC解码移除重复字符和空白字符
decoded_chars = []
prev_idx = -1
for idx in indices:
idx = idx.item()
if idx != prev_idx and idx != self.blank_idx:
if idx < len(self.chars):
decoded_chars.append(self.chars[idx])
prev_idx = idx
return ''.join(decoded_chars)
def decode_with_confidence(self, predictions):
"""解码并返回置信度信息"""
# 应用softmax获得概率
probs = torch.softmax(predictions, dim=1)
# 贪婪解码
indices = torch.argmax(probs, dim=1)
max_probs = torch.max(probs, dim=1)[0]
# CTC解码
decoded_chars = []
char_confidences = []
prev_idx = -1
for i, idx in enumerate(indices):
idx = idx.item()
confidence = max_probs[i].item()
if idx != prev_idx and idx != self.blank_idx:
if idx < len(self.chars):
decoded_chars.append(self.chars[idx])
char_confidences.append(confidence)
prev_idx = idx
text = ''.join(decoded_chars)
avg_confidence = np.mean(char_confidences) if char_confidences else 0.0
return text, avg_confidence, char_confidences
class LicensePlatePreprocessor:
"""车牌图像预处理器"""
def __init__(self, target_height=32, target_width=128):
self.target_height = target_height
self.target_width = target_width
# 定义图像变换
self.transform = transforms.Compose([
transforms.Resize((target_height, target_width)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def preprocess_numpy_array(self, image_array):
"""预处理numpy数组格式的图像"""
try:
# 确保图像是RGB格式
if len(image_array.shape) == 3 and image_array.shape[2] == 3:
# 如果是BGR格式转换为RGB
if image_array.dtype == np.uint8:
image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)
# 转换为PIL图像
if image_array.dtype != np.uint8:
image_array = (image_array * 255).astype(np.uint8)
image = Image.fromarray(image_array)
# 应用变换
tensor = self.transform(image)
# 添加batch维度
tensor = tensor.unsqueeze(0)
return tensor
except Exception as e:
print(f"图像预处理失败: {e}")
return None
def initialize_crnn_model():
"""
@ -214,65 +7,12 @@ def initialize_crnn_model():
返回:
bool: 初始化是否成功
"""
global crnn_model, crnn_decoder, crnn_preprocessor, device
# CRNN模型初始化代码
# 例如: 加载预训练模型、设置参数等
try:
# 设置设备
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"CRNN使用设备: {device}")
# 初始化组件
crnn_decoder = CTCDecoder()
crnn_preprocessor = LicensePlatePreprocessor(target_height=32, target_width=128)
# 创建模型实例
crnn_model = CRNN(num_classes=len(crnn_decoder.chars), hidden_size=256)
# 加载模型权重
model_path = os.path.join(os.path.dirname(__file__), 'best_model.pth')
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")
print(f"正在加载CRNN模型: {model_path}")
# 加载检查点
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
# 处理不同的模型保存格式
if isinstance(checkpoint, dict):
if 'model_state_dict' in checkpoint:
# 完整检查点格式
state_dict = checkpoint['model_state_dict']
print(f"检查点信息:")
print(f" - 训练轮次: {checkpoint.get('epoch', 'N/A')}")
print(f" - 最佳验证损失: {checkpoint.get('best_val_loss', 'N/A')}")
else:
# 精简模型格式(只包含权重)
print("加载精简模型(仅权重)")
state_dict = checkpoint
else:
# 直接是状态字典
state_dict = checkpoint
# 加载权重
crnn_model.load_state_dict(state_dict)
crnn_model.to(device)
crnn_model.eval()
print("CRNN模型初始化完成")
# 统计模型参数
total_params = sum(p.numel() for p in crnn_model.parameters())
print(f"CRNN模型参数数量: {total_params:,}")
return True
except Exception as e:
print(f"CRNN模型初始化失败: {e}")
import traceback
traceback.print_exc()
return False
print("CRNN模型初始化完成占位")
return True
def crnn_predict(image_array):
"""
@ -285,47 +25,13 @@ def crnn_predict(image_array):
list: 包含7个字符的列表代表车牌号的每个字符
例如: ['', 'A', '1', '2', '3', '4', '5']
"""
global crnn_model, crnn_decoder, crnn_preprocessor, device
# 这是CRNN部分的占位函数
# 实际实现时,这里应该包含:
# 1. 图像预处理
# 2. CRNN模型推理
# 3. CTC解码
# 4. 后处理和字符识别
if crnn_model is None or crnn_decoder is None or crnn_preprocessor is None:
print("CRNN模型未初始化请先调用initialize_crnn_model()")
return ['', '', '', '0', '0', '0', '0']
try:
# 预处理图像
input_tensor = crnn_preprocessor.preprocess_numpy_array(image_array)
if input_tensor is None:
raise ValueError("图像预处理失败")
input_tensor = input_tensor.to(device)
# 模型推理
with torch.no_grad():
outputs = crnn_model(input_tensor) # (seq_len, batch_size, num_classes)
# 移除batch维度
outputs = outputs.squeeze(1) # (seq_len, num_classes)
# CTC解码
predicted_text, confidence, char_confidences = crnn_decoder.decode_with_confidence(outputs)
print(f"CRNN识别结果: {predicted_text}, 置信度: {confidence:.3f}")
# 将字符串转换为字符列表
char_list = list(predicted_text)
# 确保返回7个字符车牌标准长度
if len(char_list) < 7:
# 如果识别结果少于7个字符用'0'补齐
char_list.extend(['0'] * (7 - len(char_list)))
elif len(char_list) > 7:
# 如果识别结果多于7个字符截取前7个
char_list = char_list[:7]
return char_list
except Exception as e:
print(f"CRNN识别失败: {e}")
import traceback
traceback.print_exc()
return ['', '', '', '', '0', '0', '0']
# 临时返回占位结果
placeholder_result = ['', '', '', '0', '0', '0', '0']
return placeholder_result

View File

@ -1,28 +1,36 @@
import numpy as np
from paddleocr import TextRecognition
import cv2
class OCRProcessor:
def __init__(self):
self.model = TextRecognition(model_name="PP-OCRv5_server_rec")
print("OCR模型初始化完成占位")
def predict(self, image_array):
# 保持原有模型调用方式
output = self.model.predict(input=image_array)
# 结构化输出结果
results = output[0]["rec_text"]
placeholder_result = results.split(',')
return placeholder_result
# 保留原有函数接口
_processor = OCRProcessor()
def initialize_ocr_model():
return _processor
"""
初始化OCR模型
返回:
bool: 初始化是否成功
"""
# OCR模型初始化代码
# 例如: 加载预训练模型、设置参数等
print("OCR模型初始化完成占位")
return True
def ocr_predict(image_array):
return _processor.predict(image_array)
"""
OCR车牌号识别接口函数
参数:
image_array: numpy数组格式的车牌图像已经过矫正处理
返回:
list: 包含7个字符的列表代表车牌号的每个字符
例如: ['', 'A', '1', '2', '3', '4', '5']
"""
# 这是OCR部分的占位函数
# 实际实现时,这里应该包含:
# 1. 图像预处理
# 2. OCR模型推理
# 3. 后处理和字符识别
# 临时返回占位结果
placeholder_result = ['', '', '', '0', '0', '0', '0']
return placeholder_result

View File

@ -15,7 +15,7 @@ License_plate_recognition/
├── OCR_part/ # OCR识别模块
│ └── ocr_interface.py # OCR接口占位
└── CRNN_part/ # CRNN识别模块
└── crnn_interface.py # CRNN
└── crnn_interface.py # CRNN接口(占位)
```
## 功能特性

11
main.py
View File

@ -10,10 +10,7 @@ from PyQt5.QtGui import QImage, QPixmap, QFont, QPainter, QPen, QColor
import os
from yolopart.detector import LicensePlateYOLO
from OCR_part.ocr_interface import ocr_predict
from OCR_part.ocr_interface import initialize_ocr_model
# 使用CRNN进行车牌字符识别
# from CRNN_part.crnn_interface import crnn_predict
from CRNN_part.crnn_interface import initialize_crnn_model
#from CRNN_part.crnn_interface import crnn_predict不使用CRNN
class CameraThread(QThread):
"""摄像头线程类"""
@ -162,11 +159,6 @@ class MainWindow(QMainWindow):
self.init_ui()
self.init_detector()
self.init_camera()
# 初始化OCR/CRNN模型具体用哪个模块识别车牌号就写在这儿
initialize_ocr_model()
# initialize_crnn_model()
def init_ui(self):
"""初始化用户界面"""
@ -393,7 +385,6 @@ class MainWindow(QMainWindow):
# 使用OCR接口进行识别
# 可以根据需要切换为CRNN: crnn_predict(corrected_image)
result = ocr_predict(corrected_image)
# result = crnn_predict(corrected_image)
# 将字符列表转换为字符串
if isinstance(result, list) and len(result) >= 7:

View File

@ -11,11 +11,6 @@ PyQt5>=5.15.0
# 图像处理
Pillow>=8.0.0
#paddleocr
python -m pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
python -m pip install "paddleocr[all]"
# 可选如果需要GPU加速
# torch>=1.9.0
# torchvision>=0.10.0