Compare commits
	
		
			15 Commits
		
	
	
		
			3d7c7a06e4
			...
			8eef0d9414
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 8eef0d9414 | |||
| 8e8fda7fe9 | |||
| 9879cb1547 | |||
| 3829cf76ee | |||
| c8a541ec11 | |||
| b5839d2c36 | |||
| afe15b990a | |||
| 7f89965956 | |||
| c7ecc5325e | |||
| 01b286fce1 | |||
| 85c8302fc1 | |||
| 0cd70df215 | |||
| 658560c34f | |||
| c773a12f90 | |||
| a41a4a2236 | 
							
								
								
									
										2
									
								
								.idea/License_plate_recognition.iml
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										2
									
								
								.idea/License_plate_recognition.iml
									
									
									
										generated
									
									
									
								
							@@ -2,7 +2,7 @@
 | 
			
		||||
<module type="PYTHON_MODULE" version="4">
 | 
			
		||||
  <component name="NewModuleRootManager">
 | 
			
		||||
    <content url="file://$MODULE_DIR$" />
 | 
			
		||||
    <orderEntry type="jdk" jdkName="pytorh" jdkType="Python SDK" />
 | 
			
		||||
    <orderEntry type="jdk" jdkName="D:\conda_envs\RLP" jdkType="Python SDK" />
 | 
			
		||||
    <orderEntry type="sourceFolder" forTests="false" />
 | 
			
		||||
  </component>
 | 
			
		||||
  <component name="PyDocumentationSettings">
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										2
									
								
								.idea/misc.xml
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										2
									
								
								.idea/misc.xml
									
									
									
										generated
									
									
									
								
							@@ -3,5 +3,5 @@
 | 
			
		||||
  <component name="Black">
 | 
			
		||||
    <option name="sdkName" value="pytorh" />
 | 
			
		||||
  </component>
 | 
			
		||||
  <component name="ProjectRootManager" version="2" project-jdk-name="pytorh" project-jdk-type="Python SDK" />
 | 
			
		||||
  <component name="ProjectRootManager" version="2" project-jdk-name="D:\conda_envs\RLP" project-jdk-type="Python SDK" />
 | 
			
		||||
</project>
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								CRNN_part/best_model.pth
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								CRNN_part/best_model.pth
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							@@ -1,4 +1,211 @@
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import numpy as np
 | 
			
		||||
from PIL import Image
 | 
			
		||||
import cv2
 | 
			
		||||
from torchvision import transforms
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
# 全局变量
 | 
			
		||||
crnn_model = None
 | 
			
		||||
crnn_decoder = None
 | 
			
		||||
crnn_preprocessor = None
 | 
			
		||||
device = None
 | 
			
		||||
 | 
			
		||||
class CRNN(nn.Module):
    """CRNN (CNN + bidirectional LSTM) model for license plate recognition.

    The CNN halves the image height four times and then applies an unpadded
    2x2 convolution, so an input of height ``H`` leaves ``H // 16 - 1``
    feature rows. The width is halved twice and reduced by one, so an input
    of width ``W`` yields ``W // 4 - 1`` time steps.

    Args:
        img_height: Height of the input images. It determines the LSTM input
            size; the default of 32 gives one feature row and an input size
            of 512, matching the layout of the original fixed-size model.
        num_classes: Number of output classes (including the CTC blank).
        hidden_size: Hidden size of each LSTM direction.
    """

    def __init__(self, img_height=32, num_classes=68, hidden_size=256):
        super().__init__()
        self.img_height = img_height
        self.num_classes = num_classes
        self.hidden_size = hidden_size

        # CNN feature extractor: seven convolution layers.
        # NOTE: the layer order fixes the state-dict keys (cnn.0, cnn.1, ...),
        # so it must not be changed if saved checkpoints are to keep loading.
        self.cnn = nn.Sequential(
            # Layer 1: 3 -> 64, 3x3 convolution
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Layer 2: 64 -> 128, 3x3 convolution
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Layer 3: 128 -> 256, 3x3 convolution
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # Layer 4: 256 -> 256, 3x3 convolution; pool the height only
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),

            # Layer 5: 256 -> 512, 3x3 convolution
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # Layer 6: 512 -> 512, 3x3 convolution; pool the height only
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),

            # Layer 7: 512 -> 512, unpadded 2x2 convolution
            nn.Conv2d(512, 512, kernel_size=2, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        # forward() folds the remaining feature rows into the channel axis,
        # so the LSTM input size is 512 * rows. Deriving it from img_height
        # (instead of hard-coding 512) lets taller inputs work; the max()
        # keeps construction identical to the fixed-size model for small
        # heights.
        feature_height = max(img_height // 16 - 1, 1)

        # Sequence model: two-layer bidirectional LSTM.
        self.rnn = nn.LSTM(
            input_size=512 * feature_height,
            hidden_size=hidden_size,
            num_layers=2,
            batch_first=True,
            bidirectional=True
        )

        # Per-time-step classifier over the concatenated LSTM directions.
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape (batch, 3, height, width).

        Returns:
            Tensor of shape (seq_len, batch, num_classes) holding raw
            (un-normalised) class scores, the layout CTC expects.
        """
        conv_out = self.cnn(x)

        # (batch, channels, height, width) -> (batch, width, channels * height):
        # every image column becomes one time step of the sequence.
        batch_size, channels, height, width = conv_out.size()
        conv_out = conv_out.permute(0, 3, 1, 2)
        conv_out = conv_out.contiguous().view(batch_size, width, channels * height)

        rnn_out, _ = self.rnn(conv_out)
        output = self.fc(rnn_out)

        # (batch, seq_len, classes) -> (seq_len, batch, classes)
        return output.permute(1, 0, 2)
 | 
			
		||||
 | 
			
		||||
class CTCDecoder:
    """Greedy CTC decoder over the Chinese license-plate character set."""

    def __init__(self):
        # 68 symbols in total: index 0 is the CTC blank, followed by the
        # province abbreviations, the letters A-Z and the digits 0-9.
        provinces = (
            '京沪津渝冀晋蒙辽吉黑'
            '苏浙皖闽赣鲁豫鄂湘粤'
            '桂琼川贵云藏陕甘青宁新'
        )
        letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        digits = '0123456789'
        self.chars = ['<BLANK>', *provinces, *letters, *digits]

        self.char_to_idx = {}
        self.idx_to_char = {}
        for position, symbol in enumerate(self.chars):
            self.char_to_idx[symbol] = position
            self.idx_to_char[position] = symbol
        self.blank_idx = 0

    def decode_greedy(self, predictions):
        """Greedily decode per-step class scores into a string.

        Args:
            predictions: 2-D tensor of scores, one row per time step.

        Returns:
            str: The decoded text, with repeated labels collapsed and blanks
            removed.
        """
        # Best class for every time step.
        best_labels = torch.argmax(predictions, dim=1).tolist()

        pieces = []
        last_label = -1
        for label in best_labels:
            repeated = label == last_label
            last_label = label
            # CTC rule: skip repeats of the previous step and blanks.
            if repeated or label == self.blank_idx:
                continue
            if label < len(self.chars):
                pieces.append(self.chars[label])

        return ''.join(pieces)

    def decode_with_confidence(self, predictions):
        """Greedily decode and also report per-character confidences.

        Args:
            predictions: 2-D tensor of scores, one row per time step.

        Returns:
            tuple: ``(text, avg_confidence, char_confidences)`` where
            ``char_confidences`` holds the softmax probability of each kept
            character and ``avg_confidence`` is their mean (0.0 when no
            character was decoded).
        """
        # Turn scores into probabilities.
        probs = torch.softmax(predictions, dim=1)

        best_labels = probs.argmax(dim=1).tolist()
        best_probs = probs.max(dim=1).values.tolist()

        kept_chars = []
        kept_probs = []
        last_label = -1
        for label, prob in zip(best_labels, best_probs):
            repeated = label == last_label
            last_label = label
            if repeated or label == self.blank_idx:
                continue
            if label < len(self.chars):
                kept_chars.append(self.chars[label])
                kept_probs.append(prob)

        text = ''.join(kept_chars)
        if kept_probs:
            avg_confidence = np.mean(kept_probs)
        else:
            avg_confidence = 0.0

        return text, avg_confidence, kept_probs
 | 
			
		||||
 | 
			
		||||
class LicensePlatePreprocessor:
    """Turn a license-plate image into a normalised model input tensor.

    Args:
        target_height: Height the image is resized to.
        target_width: Width the image is resized to.
    """

    def __init__(self, target_height=32, target_width=128):
        self.target_height = target_height
        self.target_width = target_width

        # Resize -> tensor -> per-channel normalisation. The three-element
        # mean/std (the commonly used ImageNet statistics) mean the pipeline
        # only accepts three-channel images.
        self.transform = transforms.Compose([
            transforms.Resize((target_height, target_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def preprocess_numpy_array(self, image_array):
        """Preprocess an image given as a numpy array.

        Args:
            image_array: Image as a numpy array. Three-channel uint8 arrays
                are treated as BGR and converted to RGB. Non-uint8 arrays
                are scaled by 255 (floats are assumed to lie in [0, 1] and
                are clipped to that range first). Grayscale images (2-D, or
                with a single trailing channel) are expanded to RGB.

        Returns:
            A tensor with a leading batch dimension of 1, or ``None`` when
            preprocessing fails (the error is printed).
        """
        try:
            # Three-channel uint8 input is assumed to be BGR: swap to RGB.
            if image_array.ndim == 3 and image_array.shape[2] == 3:
                if image_array.dtype == np.uint8:
                    image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)

            # Bring everything to uint8 for PIL.
            if image_array.dtype != np.uint8:
                if np.issubdtype(image_array.dtype, np.floating):
                    # Without clipping, values slightly outside [0, 1] wrap
                    # around when cast to uint8.
                    image_array = np.clip(image_array, 0.0, 1.0)
                image_array = (image_array * 255).astype(np.uint8)

            # PIL cannot build an image from an (H, W, 1) array.
            if image_array.ndim == 3 and image_array.shape[2] == 1:
                image_array = image_array[:, :, 0]

            image = Image.fromarray(image_array)

            # Grayscale input would otherwise fail in Normalize, which
            # expects three channels.
            if image.mode == 'L':
                image = image.convert('RGB')

            tensor = self.transform(image)

            # Add the batch dimension.
            return tensor.unsqueeze(0)

        except Exception as e:
            print(f"图像预处理失败: {e}")
            return None
 | 
			
		||||
 | 
			
		||||
def initialize_crnn_model():
    """Load the CRNN model and its helpers into the module-level globals.

    Builds a ``CTCDecoder``, a ``LicensePlatePreprocessor`` and a ``CRNN``
    model, loads the weights from ``best_model.pth`` next to this file and
    publishes all three objects through the ``crnn_decoder``,
    ``crnn_preprocessor`` and ``crnn_model`` globals.

    The globals are only assigned after the weights have loaded successfully,
    so a failed initialisation never leaves a randomly initialised model
    behind for ``crnn_predict`` to use.

    Returns:
        bool: True if initialisation succeeded, False otherwise (the error
        and its traceback are printed).
    """
    global crnn_model, crnn_decoder, crnn_preprocessor, device

    try:
        # Pick the device.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"CRNN使用设备: {device}")

        # Build the components in locals first; they are published below.
        decoder = CTCDecoder()
        preprocessor = LicensePlatePreprocessor(target_height=32, target_width=128)
        model = CRNN(num_classes=len(decoder.chars), hidden_size=256)

        # Locate the weights file.
        model_path = os.path.join(os.path.dirname(__file__), 'best_model.pth')
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件不存在: {model_path}")

        print(f"正在加载CRNN模型: {model_path}")

        # SECURITY NOTE(review): weights_only=False unpickles arbitrary
        # objects and can execute code embedded in the file. Only load
        # checkpoints from a trusted source; consider weights_only=True if
        # the checkpoint contains nothing but tensors and plain containers.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)

        # Support both full checkpoints and bare state dicts.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            # Full checkpoint format.
            state_dict = checkpoint['model_state_dict']
            print("检查点信息:")
            print(f"  - 训练轮次: {checkpoint.get('epoch', 'N/A')}")
            print(f"  - 最佳验证损失: {checkpoint.get('best_val_loss', 'N/A')}")
        else:
            if isinstance(checkpoint, dict):
                # Slim format: the dict is the state dict itself.
                print("加载精简模型(仅权重)")
            state_dict = checkpoint

        # Load the weights and switch to inference mode.
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()

        # Everything succeeded: publish the objects.
        crnn_model = model
        crnn_decoder = decoder
        crnn_preprocessor = preprocessor

        print("CRNN模型初始化完成")

        # Report the parameter count.
        total_params = sum(p.numel() for p in crnn_model.parameters())
        print(f"CRNN模型参数数量: {total_params:,}")

        return True

    except Exception as e:
        print(f"CRNN模型初始化失败: {e}")
        import traceback
        traceback.print_exc()
        return False
 | 
			
		||||
 | 
			
		||||
def crnn_predict(image_array):
    """Recognise the characters of a license plate with the CRNN model.

    Args:
        image_array: Plate image as a numpy array; it is passed to
            ``crnn_preprocessor.preprocess_numpy_array``.

    Returns:
        list: Exactly seven characters. Shorter results are padded with
        ``'0'`` and longer ones are truncated. A placeholder list is returned
        when the model has not been initialised or recognition fails.
    """
    # Refuse to run before initialize_crnn_model() has set up the globals.
    components = (crnn_model, crnn_decoder, crnn_preprocessor)
    if any(component is None for component in components):
        print("CRNN模型未初始化,请先调用initialize_crnn_model()")
        return ['待', '识', '别', '0', '0', '0', '0']

    try:
        # Preprocess the image.
        batch = crnn_preprocessor.preprocess_numpy_array(image_array)
        if batch is None:
            raise ValueError("图像预处理失败")
        batch = batch.to(device)

        with torch.no_grad():
            # (seq_len, batch_size, num_classes) -> (seq_len, num_classes)
            scores = crnn_model(batch).squeeze(1)

            # CTC decoding.
            decoded = crnn_decoder.decode_with_confidence(scores)
            predicted_text, confidence, _char_confidences = decoded

            print(f"CRNN识别结果: {predicted_text}, 置信度: {confidence:.3f}")

            # Pad with '0' / truncate so that exactly seven characters
            # (the standard plate length) are returned.
            plate_length = 7
            padded = list(predicted_text) + ['0'] * plate_length
            return padded[:plate_length]

    except Exception as e:
        print(f"CRNN识别失败: {e}")
        import traceback
        traceback.print_exc()
        return ['识', '别', '失', '败', '0', '0', '0']
 | 
			
		||||
 
 | 
			
		||||
@@ -1,36 +1,28 @@
 | 
			
		||||
import numpy as np
 | 
			
		||||
from paddleocr import TextRecognition
 | 
			
		||||
import cv2
 | 
			
		||||
 | 
			
		||||
def initialize_ocr_model():
 | 
			
		||||
    """
 | 
			
		||||
    初始化OCR模型
 | 
			
		||||
    
 | 
			
		||||
    返回:
 | 
			
		||||
        bool: 初始化是否成功
 | 
			
		||||
    """
 | 
			
		||||
    # OCR模型初始化代码
 | 
			
		||||
    # 例如: 加载预训练模型、设置参数等
 | 
			
		||||
    
 | 
			
		||||
class OCRProcessor:
    """Wrapper around a PaddleOCR ``TextRecognition`` model."""

    def __init__(self):
        # Loads the recognition model by name.
        self.model = TextRecognition(model_name="PP-OCRv5_server_rec")
        print("OCR模型初始化完成(占位)")

    def predict(self, image_array):
        """Recognise the text of a plate image.

        Args:
            image_array: Plate image, forwarded unchanged to the model's
                ``predict(input=...)``.

        Returns:
            list: The recognised text as a list of single characters, e.g.
            ``['京', 'A', '1', '2', '3', '4', '5']``. Commas and whitespace
            are dropped. An empty list is returned when the model yields no
            result.
        """
        output = self.model.predict(input=image_array)
        if not output:
            return []

        # The first result's "rec_text" holds the recognised string.
        text = output[0]["rec_text"]

        # Callers expect one list entry per character; splitting on ','
        # returned the whole plate as a single element for ordinary text.
        return [ch for ch in text if ch != ',' and not ch.isspace()]
 | 
			
		||||
 | 
			
		||||
# 保留原有函数接口
 | 
			
		||||
_processor = OCRProcessor()
 | 
			
		||||
 | 
			
		||||
def initialize_ocr_model():
    """Return the module-level OCR processor instance (``_processor``)."""
    return _processor
 | 
			
		||||
 | 
			
		||||
def ocr_predict(image_array):
    """Run OCR on ``image_array`` by delegating to ``_processor.predict``."""
    return _processor.predict(image_array)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -15,7 +15,7 @@ License_plate_recognition/
 | 
			
		||||
├── OCR_part/                 # OCR识别模块
 | 
			
		||||
│   └── ocr_interface.py      # OCR接口(占位)
 | 
			
		||||
└── CRNN_part/                # CRNN识别模块
 | 
			
		||||
    └── crnn_interface.py     # CRNN接口(占位)
 | 
			
		||||
    └── crnn_interface.py     # CRNN
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
## 功能特性
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										11
									
								
								main.py
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								main.py
									
									
									
									
									
								
							@@ -10,7 +10,10 @@ from PyQt5.QtGui import QImage, QPixmap, QFont, QPainter, QPen, QColor
 | 
			
		||||
import os
 | 
			
		||||
from yolopart.detector import LicensePlateYOLO
 | 
			
		||||
from OCR_part.ocr_interface import ocr_predict
 | 
			
		||||
#from CRNN_part.crnn_interface import crnn_predict(不使用CRNN)
 | 
			
		||||
from OCR_part.ocr_interface import initialize_ocr_model
 | 
			
		||||
# 使用CRNN进行车牌字符识别
 | 
			
		||||
# from CRNN_part.crnn_interface import crnn_predict
 | 
			
		||||
from CRNN_part.crnn_interface import initialize_crnn_model
 | 
			
		||||
 | 
			
		||||
class CameraThread(QThread):
 | 
			
		||||
    """摄像头线程类"""
 | 
			
		||||
@@ -160,6 +163,11 @@ class MainWindow(QMainWindow):
 | 
			
		||||
        self.init_detector()
 | 
			
		||||
        self.init_camera()
 | 
			
		||||
 | 
			
		||||
        # 初始化OCR/CRNN模型(具体用哪个模块识别车牌号就写在这儿)
 | 
			
		||||
        initialize_ocr_model()
 | 
			
		||||
        # initialize_crnn_model()
 | 
			
		||||
 | 
			
		||||
    
 | 
			
		||||
    def init_ui(self):
 | 
			
		||||
        """初始化用户界面"""
 | 
			
		||||
        self.setWindowTitle("车牌识别系统")
 | 
			
		||||
@@ -385,6 +393,7 @@ class MainWindow(QMainWindow):
 | 
			
		||||
            # 使用OCR接口进行识别
 | 
			
		||||
            # 可以根据需要切换为CRNN: crnn_predict(corrected_image)
 | 
			
		||||
            result = ocr_predict(corrected_image)
 | 
			
		||||
            # result = crnn_predict(corrected_image)
 | 
			
		||||
            
 | 
			
		||||
            # 将字符列表转换为字符串
 | 
			
		||||
            if isinstance(result, list) and len(result) >= 7:
 | 
			
		||||
 
 | 
			
		||||
@@ -11,6 +11,11 @@ PyQt5>=5.15.0
 | 
			
		||||
# 图像处理
 | 
			
		||||
Pillow>=8.0.0
 | 
			
		||||
 | 
			
		||||
#paddleocr
 | 
			
		||||
# Install manually (shell command, not a requirement line): python -m pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/
 | 
			
		||||
paddleocr[all]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# 可选:如果需要GPU加速
 | 
			
		||||
# torch>=1.9.0
 | 
			
		||||
# torchvision>=0.10.0
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user