Compare commits

...

15 Commits

Author SHA1 Message Date
8eef0d9414 Merge pull request 'yolorestart' (#1) from yolopart_restart into main
Reviewed-on: #1
2025-08-31 18:42:35 +08:00
8e8fda7fe9 Merge remote-tracking branch 'origin/ocr-v1' into ocr-v1
# Conflicts:
#	OCR_part/ocr_interface.py
2025-08-31 18:37:40 +08:00
9879cb1547 Merge pull request 'yolorestart' (#1) from yolopart_restart into main
Reviewed-on: #1
2025-08-31 18:36:36 +08:00
3829cf76ee Merge pull request 'yolorestart' (#1) from yolopart_restart into main
Reviewed-on: #1
2025-08-31 18:28:57 +08:00
c8a541ec11 Merge pull request 'yolorestart' (#1) from yolopart_restart into main
Reviewed-on: #1
2025-08-31 16:11:18 +08:00
b5839d2c36 更新 README.md 2025-08-31 12:53:11 +08:00
afe15b990a 更新 main.py 2025-08-31 12:19:25 +08:00
7f89965956 上传文件至 CRNN_part 2025-08-31 12:18:48 +08:00
c7ecc5325e 删除 CRNN_part/best_model.pth 2025-08-31 12:17:59 +08:00
01b286fce1 更新 CRNN_part/crnn_interface.py 2025-08-31 12:15:38 +08:00
85c8302fc1 Merge pull request 'yolopart_restart' (#3) from yolopart_restart into main
Reviewed-on: #3
2025-08-31 01:26:01 +08:00
0cd70df215 CRNN model 2025-08-31 01:16:08 +08:00
658560c34f Merge pull request 'yolorestart' (#2) from yolopart_restart into main
Reviewed-on: #2
2025-08-30 12:33:05 +08:00
c773a12f90 Merge remote-tracking branch 'origin/main' into yolopart_restart 2025-08-30 12:28:53 +08:00
a41a4a2236 yolorestart 2025-08-30 12:23:01 +08:00
8 changed files with 348 additions and 48 deletions

View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4"> <module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="pytorh" jdkType="Python SDK" /> <orderEntry type="jdk" jdkName="D:\conda_envs\RLP" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" /> <orderEntry type="sourceFolder" forTests="false" />
</component> </component>
<component name="PyDocumentationSettings"> <component name="PyDocumentationSettings">

2
.idea/misc.xml generated
View File

@ -3,5 +3,5 @@
<component name="Black"> <component name="Black">
<option name="sdkName" value="pytorh" /> <option name="sdkName" value="pytorh" />
</component> </component>
<component name="ProjectRootManager" version="2" project-jdk-name="pytorh" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="D:\conda_envs\RLP" project-jdk-type="Python SDK" />
</project> </project>

BIN
CRNN_part/best_model.pth Normal file

Binary file not shown.

View File

@ -1,4 +1,211 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np import numpy as np
from PIL import Image
import cv2
from torchvision import transforms
import os
# 全局变量
crnn_model = None
crnn_decoder = None
crnn_preprocessor = None
device = None
class CRNN(nn.Module):
"""CRNN车牌识别模型"""
def __init__(self, img_height=32, num_classes=68, hidden_size=256):
super(CRNN, self).__init__()
self.img_height = img_height
self.num_classes = num_classes
self.hidden_size = hidden_size
# CNN特征提取部分 - 7层卷积
self.cnn = nn.Sequential(
# 第1层3->64, 3x3卷积
nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
# 第2层64->128, 3x3卷积
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
# 第3层128->256, 3x3卷积
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
# 第4层256->256, 3x3卷积
nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
# 第5层256->512, 3x3卷积
nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
# 第6层512->512, 3x3卷积
nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),
# 第7层512->512, 2x2卷积
nn.Conv2d(512, 512, kernel_size=2, stride=1, padding=0),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
)
# RNN序列建模部分 - 2层双向LSTM
self.rnn = nn.LSTM(
input_size=512,
hidden_size=hidden_size,
num_layers=2,
batch_first=True,
bidirectional=True
)
# 全连接分类层
self.fc = nn.Linear(hidden_size * 2, num_classes)
def forward(self, x):
batch_size = x.size(0)
# CNN特征提取
conv_out = self.cnn(x)
# 重塑为RNN输入格式
batch_size, channels, height, width = conv_out.size()
conv_out = conv_out.permute(0, 3, 1, 2)
conv_out = conv_out.contiguous().view(batch_size, width, channels * height)
# RNN序列建模
rnn_out, _ = self.rnn(conv_out)
# 全连接分类
output = self.fc(rnn_out)
# 转换为CTC需要的格式(width, batch_size, num_classes)
output = output.permute(1, 0, 2)
return output
class CTCDecoder:
"""CTC解码器"""
def __init__(self):
# 定义中国车牌字符集68个字符
self.chars = [
# 空白字符CTC需要
'<BLANK>',
# 中文省份简称
'', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '',
# 字母 A-Z
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
# 数字 0-9
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'
]
self.char_to_idx = {char: idx for idx, char in enumerate(self.chars)}
self.idx_to_char = {idx: char for idx, char in enumerate(self.chars)}
self.blank_idx = 0
def decode_greedy(self, predictions):
"""贪婪解码"""
# 获取每个时间步的最大概率索引
indices = torch.argmax(predictions, dim=1)
# CTC解码移除重复字符和空白字符
decoded_chars = []
prev_idx = -1
for idx in indices:
idx = idx.item()
if idx != prev_idx and idx != self.blank_idx:
if idx < len(self.chars):
decoded_chars.append(self.chars[idx])
prev_idx = idx
return ''.join(decoded_chars)
def decode_with_confidence(self, predictions):
"""解码并返回置信度信息"""
# 应用softmax获得概率
probs = torch.softmax(predictions, dim=1)
# 贪婪解码
indices = torch.argmax(probs, dim=1)
max_probs = torch.max(probs, dim=1)[0]
# CTC解码
decoded_chars = []
char_confidences = []
prev_idx = -1
for i, idx in enumerate(indices):
idx = idx.item()
confidence = max_probs[i].item()
if idx != prev_idx and idx != self.blank_idx:
if idx < len(self.chars):
decoded_chars.append(self.chars[idx])
char_confidences.append(confidence)
prev_idx = idx
text = ''.join(decoded_chars)
avg_confidence = np.mean(char_confidences) if char_confidences else 0.0
return text, avg_confidence, char_confidences
class LicensePlatePreprocessor:
"""车牌图像预处理器"""
def __init__(self, target_height=32, target_width=128):
self.target_height = target_height
self.target_width = target_width
# 定义图像变换
self.transform = transforms.Compose([
transforms.Resize((target_height, target_width)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def preprocess_numpy_array(self, image_array):
"""预处理numpy数组格式的图像"""
try:
# 确保图像是RGB格式
if len(image_array.shape) == 3 and image_array.shape[2] == 3:
# 如果是BGR格式转换为RGB
if image_array.dtype == np.uint8:
image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)
# 转换为PIL图像
if image_array.dtype != np.uint8:
image_array = (image_array * 255).astype(np.uint8)
image = Image.fromarray(image_array)
# 应用变换
tensor = self.transform(image)
# 添加batch维度
tensor = tensor.unsqueeze(0)
return tensor
except Exception as e:
print(f"图像预处理失败: {e}")
return None
def initialize_crnn_model(): def initialize_crnn_model():
""" """
@ -7,12 +214,65 @@ def initialize_crnn_model():
返回: 返回:
bool: 初始化是否成功 bool: 初始化是否成功
""" """
# CRNN模型初始化代码 global crnn_model, crnn_decoder, crnn_preprocessor, device
# 例如: 加载预训练模型、设置参数等
try:
# 设置设备
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"CRNN使用设备: {device}")
# 初始化组件
crnn_decoder = CTCDecoder()
crnn_preprocessor = LicensePlatePreprocessor(target_height=32, target_width=128)
# 创建模型实例
crnn_model = CRNN(num_classes=len(crnn_decoder.chars), hidden_size=256)
# 加载模型权重
model_path = os.path.join(os.path.dirname(__file__), 'best_model.pth')
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")
print(f"正在加载CRNN模型: {model_path}")
# 加载检查点
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
# 处理不同的模型保存格式
if isinstance(checkpoint, dict):
if 'model_state_dict' in checkpoint:
# 完整检查点格式
state_dict = checkpoint['model_state_dict']
print(f"检查点信息:")
print(f" - 训练轮次: {checkpoint.get('epoch', 'N/A')}")
print(f" - 最佳验证损失: {checkpoint.get('best_val_loss', 'N/A')}")
else:
# 精简模型格式(只包含权重)
print("加载精简模型(仅权重)")
state_dict = checkpoint
else:
# 直接是状态字典
state_dict = checkpoint
# 加载权重
crnn_model.load_state_dict(state_dict)
crnn_model.to(device)
crnn_model.eval()
print("CRNN模型初始化完成")
# 统计模型参数
total_params = sum(p.numel() for p in crnn_model.parameters())
print(f"CRNN模型参数数量: {total_params:,}")
print("CRNN模型初始化完成占位")
return True return True
except Exception as e:
print(f"CRNN模型初始化失败: {e}")
import traceback
traceback.print_exc()
return False
def crnn_predict(image_array): def crnn_predict(image_array):
""" """
@ -25,13 +285,47 @@ def crnn_predict(image_array):
list: 包含7个字符的列表代表车牌号的每个字符 list: 包含7个字符的列表代表车牌号的每个字符
例如: ['', 'A', '1', '2', '3', '4', '5'] 例如: ['', 'A', '1', '2', '3', '4', '5']
""" """
# 这是CRNN部分的占位函数 global crnn_model, crnn_decoder, crnn_preprocessor, device
# 实际实现时,这里应该包含:
# 1. 图像预处理
# 2. CRNN模型推理
# 3. CTC解码
# 4. 后处理和字符识别
# 临时返回占位结果 if crnn_model is None or crnn_decoder is None or crnn_preprocessor is None:
placeholder_result = ['', '', '', '0', '0', '0', '0'] print("CRNN模型未初始化请先调用initialize_crnn_model()")
return placeholder_result return ['', '', '', '0', '0', '0', '0']
try:
# 预处理图像
input_tensor = crnn_preprocessor.preprocess_numpy_array(image_array)
if input_tensor is None:
raise ValueError("图像预处理失败")
input_tensor = input_tensor.to(device)
# 模型推理
with torch.no_grad():
outputs = crnn_model(input_tensor) # (seq_len, batch_size, num_classes)
# 移除batch维度
outputs = outputs.squeeze(1) # (seq_len, num_classes)
# CTC解码
predicted_text, confidence, char_confidences = crnn_decoder.decode_with_confidence(outputs)
print(f"CRNN识别结果: {predicted_text}, 置信度: {confidence:.3f}")
# 将字符串转换为字符列表
char_list = list(predicted_text)
# 确保返回7个字符车牌标准长度
if len(char_list) < 7:
# 如果识别结果少于7个字符用'0'补齐
char_list.extend(['0'] * (7 - len(char_list)))
elif len(char_list) > 7:
# 如果识别结果多于7个字符截取前7个
char_list = char_list[:7]
return char_list
except Exception as e:
print(f"CRNN识别失败: {e}")
import traceback
traceback.print_exc()
return ['', '', '', '', '0', '0', '0']

View File

@ -1,36 +1,28 @@
import numpy as np import numpy as np
from paddleocr import TextRecognition
import cv2
def initialize_ocr_model(): class OCRProcessor:
""" def __init__(self):
初始化OCR模型 self.model = TextRecognition(model_name="PP-OCRv5_server_rec")
返回:
bool: 初始化是否成功
"""
# OCR模型初始化代码
# 例如: 加载预训练模型、设置参数等
print("OCR模型初始化完成占位") print("OCR模型初始化完成占位")
return True
def ocr_predict(image_array): def predict(self, image_array):
""" # 保持原有模型调用方式
OCR车牌号识别接口函数 output = self.model.predict(input=image_array)
# 结构化输出结果
参数: results = output[0]["rec_text"]
image_array: numpy数组格式的车牌图像已经过矫正处理 placeholder_result = results.split(',')
返回:
list: 包含7个字符的列表代表车牌号的每个字符
例如: ['', 'A', '1', '2', '3', '4', '5']
"""
# 这是OCR部分的占位函数
# 实际实现时,这里应该包含:
# 1. 图像预处理
# 2. OCR模型推理
# 3. 后处理和字符识别
# 临时返回占位结果
placeholder_result = ['', '', '', '0', '0', '0', '0']
return placeholder_result return placeholder_result
# 保留原有函数接口
_processor = OCRProcessor()
def initialize_ocr_model():
return _processor
def ocr_predict(image_array):
return _processor.predict(image_array)

View File

@ -15,7 +15,7 @@ License_plate_recognition/
├── OCR_part/ # OCR识别模块 ├── OCR_part/ # OCR识别模块
│ └── ocr_interface.py # OCR接口占位 │ └── ocr_interface.py # OCR接口占位
└── CRNN_part/ # CRNN识别模块 └── CRNN_part/ # CRNN识别模块
└── crnn_interface.py # CRNN接口(占位) └── crnn_interface.py # CRNN
``` ```
## 功能特性 ## 功能特性

11
main.py
View File

@ -10,7 +10,10 @@ from PyQt5.QtGui import QImage, QPixmap, QFont, QPainter, QPen, QColor
import os import os
from yolopart.detector import LicensePlateYOLO from yolopart.detector import LicensePlateYOLO
from OCR_part.ocr_interface import ocr_predict from OCR_part.ocr_interface import ocr_predict
#from CRNN_part.crnn_interface import crnn_predict不使用CRNN from OCR_part.ocr_interface import initialize_ocr_model
# 使用CRNN进行车牌字符识别
# from CRNN_part.crnn_interface import crnn_predict
from CRNN_part.crnn_interface import initialize_crnn_model
class CameraThread(QThread): class CameraThread(QThread):
"""摄像头线程类""" """摄像头线程类"""
@ -160,6 +163,11 @@ class MainWindow(QMainWindow):
self.init_detector() self.init_detector()
self.init_camera() self.init_camera()
# 初始化OCR/CRNN模型具体用哪个模块识别车牌号就写在这儿
initialize_ocr_model()
# initialize_crnn_model()
def init_ui(self): def init_ui(self):
"""初始化用户界面""" """初始化用户界面"""
self.setWindowTitle("车牌识别系统") self.setWindowTitle("车牌识别系统")
@ -385,6 +393,7 @@ class MainWindow(QMainWindow):
# 使用OCR接口进行识别 # 使用OCR接口进行识别
# 可以根据需要切换为CRNN: crnn_predict(corrected_image) # 可以根据需要切换为CRNN: crnn_predict(corrected_image)
result = ocr_predict(corrected_image) result = ocr_predict(corrected_image)
# result = crnn_predict(corrected_image)
# 将字符列表转换为字符串 # 将字符列表转换为字符串
if isinstance(result, list) and len(result) >= 7: if isinstance(result, list) and len(result) >= 7:

View File

@ -11,6 +11,11 @@ PyQt5>=5.15.0
# 图像处理 # 图像处理
Pillow>=8.0.0 Pillow>=8.0.0
#paddleocr
python -m pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
python -m pip install "paddleocr[all]"
# 可选如果需要GPU加速 # 可选如果需要GPU加速
# torch>=1.9.0 # torch>=1.9.0
# torchvision>=0.10.0 # torchvision>=0.10.0