546 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			546 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import torch
 | 
						||
import torch.nn as nn
 | 
						||
import torch.nn.functional as F
 | 
						||
import numpy as np
 | 
						||
from PIL import Image
 | 
						||
import cv2
 | 
						||
from torchvision import transforms
 | 
						||
import os
 | 
						||
import math
 | 
						||
 | 
						||
# 全局变量
 | 
						||
lightcrnn_model = None
 | 
						||
lightcrnn_decoder = None
 | 
						||
lightcrnn_preprocessor = None
 | 
						||
device = None
 | 
						||
 | 
						||
class DepthwiseSeparableConv(nn.Module):
 | 
						||
    """深度可分离卷积"""
 | 
						||
    
 | 
						||
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
 | 
						||
        super(DepthwiseSeparableConv, self).__init__()
 | 
						||
        # 深度卷积
 | 
						||
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, 
 | 
						||
                                 stride=stride, padding=padding, groups=in_channels, bias=False)
 | 
						||
        # 逐点卷积
 | 
						||
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
 | 
						||
        self.bn = nn.BatchNorm2d(out_channels)
 | 
						||
        self.relu = nn.ReLU6(inplace=True)
 | 
						||
    
 | 
						||
    def forward(self, x):
 | 
						||
        x = self.depthwise(x)
 | 
						||
        x = self.pointwise(x)
 | 
						||
        x = self.bn(x)
 | 
						||
        x = self.relu(x)
 | 
						||
        return x
 | 
						||
 | 
						||
class ChannelAttention(nn.Module):
 | 
						||
    """通道注意力机制"""
 | 
						||
    
 | 
						||
    def __init__(self, in_channels, reduction=16):
 | 
						||
        super(ChannelAttention, self).__init__()
 | 
						||
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
 | 
						||
        self.max_pool = nn.AdaptiveMaxPool2d(1)
 | 
						||
        
 | 
						||
        self.fc = nn.Sequential(
 | 
						||
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
 | 
						||
            nn.ReLU(inplace=True),
 | 
						||
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)
 | 
						||
        )
 | 
						||
        self.sigmoid = nn.Sigmoid()
 | 
						||
    
 | 
						||
    def forward(self, x):
 | 
						||
        avg_out = self.fc(self.avg_pool(x))
 | 
						||
        max_out = self.fc(self.max_pool(x))
 | 
						||
        out = avg_out + max_out
 | 
						||
        return x * self.sigmoid(out)
 | 
						||
 | 
						||
class InvertedResidual(nn.Module):
 | 
						||
    """MobileNetV2的倒残差块"""
 | 
						||
    
 | 
						||
    def __init__(self, in_channels, out_channels, stride=1, expand_ratio=6):
 | 
						||
        super(InvertedResidual, self).__init__()
 | 
						||
        self.stride = stride
 | 
						||
        self.use_residual = stride == 1 and in_channels == out_channels
 | 
						||
        
 | 
						||
        hidden_dim = int(round(in_channels * expand_ratio))
 | 
						||
        
 | 
						||
        layers = []
 | 
						||
        if expand_ratio != 1:
 | 
						||
            # 扩展层
 | 
						||
            layers.extend([
 | 
						||
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
 | 
						||
                nn.BatchNorm2d(hidden_dim),
 | 
						||
                nn.ReLU6(inplace=True)
 | 
						||
            ])
 | 
						||
        
 | 
						||
        # 深度卷积
 | 
						||
        layers.extend([
 | 
						||
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim, bias=False),
 | 
						||
            nn.BatchNorm2d(hidden_dim),
 | 
						||
            nn.ReLU6(inplace=True),
 | 
						||
            # 线性瓶颈
 | 
						||
            nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
 | 
						||
            nn.BatchNorm2d(out_channels)
 | 
						||
        ])
 | 
						||
        
 | 
						||
        self.conv = nn.Sequential(*layers)
 | 
						||
    
 | 
						||
    def forward(self, x):
 | 
						||
        if self.use_residual:
 | 
						||
            return x + self.conv(x)
 | 
						||
        else:
 | 
						||
            return self.conv(x)
 | 
						||
 | 
						||
class LightweightCNN(nn.Module):
 | 
						||
    """增强版轻量化CNN特征提取器"""
 | 
						||
    
 | 
						||
    def __init__(self, num_channels=3):
 | 
						||
        super(LightweightCNN, self).__init__()
 | 
						||
        
 | 
						||
        # 初始卷积层 - 适当增加通道数
 | 
						||
        self.conv1 = nn.Sequential(
 | 
						||
            nn.Conv2d(num_channels, 48, kernel_size=3, stride=1, padding=1, bias=False),
 | 
						||
            nn.BatchNorm2d(48),
 | 
						||
            nn.ReLU6(inplace=True)
 | 
						||
        )
 | 
						||
        
 | 
						||
        # 增强版MobileNet风格的特征提取
 | 
						||
        self.features = nn.Sequential(
 | 
						||
            # 第一组:48 -> 32
 | 
						||
            InvertedResidual(48, 32, stride=1, expand_ratio=2),
 | 
						||
            InvertedResidual(32, 32, stride=1, expand_ratio=2),  # 增加一层
 | 
						||
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32x128 -> 16x64
 | 
						||
            
 | 
						||
            # 第二组:32 -> 48
 | 
						||
            InvertedResidual(32, 48, stride=1, expand_ratio=4),
 | 
						||
            InvertedResidual(48, 48, stride=1, expand_ratio=4),
 | 
						||
            nn.MaxPool2d(kernel_size=2, stride=2),  # 16x64 -> 8x32
 | 
						||
            
 | 
						||
            # 第三组:48 -> 64
 | 
						||
            InvertedResidual(48, 64, stride=1, expand_ratio=4),
 | 
						||
            InvertedResidual(64, 64, stride=1, expand_ratio=4),
 | 
						||
            
 | 
						||
            # 第四组:64 -> 96
 | 
						||
            InvertedResidual(64, 96, stride=1, expand_ratio=4),
 | 
						||
            InvertedResidual(96, 96, stride=1, expand_ratio=4),
 | 
						||
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),  # 8x32 -> 4x32
 | 
						||
            
 | 
						||
            # 第五组:96 -> 128
 | 
						||
            InvertedResidual(96, 128, stride=1, expand_ratio=4),
 | 
						||
            InvertedResidual(128, 128, stride=1, expand_ratio=4),
 | 
						||
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),  # 4x32 -> 2x32
 | 
						||
            
 | 
						||
            # 最后的卷积层 - 增加通道数
 | 
						||
            nn.Conv2d(128, 160, kernel_size=2, stride=1, padding=0, bias=False),  # 2x32 -> 1x31
 | 
						||
            nn.BatchNorm2d(160),
 | 
						||
            nn.ReLU6(inplace=True)
 | 
						||
        )
 | 
						||
        
 | 
						||
        # 通道注意力
 | 
						||
        self.channel_attention = ChannelAttention(160)
 | 
						||
    
 | 
						||
    def forward(self, x):
 | 
						||
        x = self.conv1(x)
 | 
						||
        x = self.features(x)
 | 
						||
        x = self.channel_attention(x)
 | 
						||
        return x
 | 
						||
 | 
						||
class LightweightGRU(nn.Module):
 | 
						||
    """增强版轻量化GRU层"""
 | 
						||
    
 | 
						||
    def __init__(self, input_size, hidden_size, num_layers=2):  # 默认增加到2层
 | 
						||
        super(LightweightGRU, self).__init__()
 | 
						||
        self.gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, 
 | 
						||
                         bidirectional=True, batch_first=True, dropout=0.2 if num_layers > 1 else 0)
 | 
						||
        # 增加一个额外的线性层
 | 
						||
        self.linear1 = nn.Linear(hidden_size * 2, hidden_size * 2)
 | 
						||
        self.linear2 = nn.Linear(hidden_size * 2, hidden_size)
 | 
						||
        self.dropout = nn.Dropout(0.2)  # 增加dropout率
 | 
						||
        self.norm = nn.LayerNorm(hidden_size)  # 添加层归一化
 | 
						||
    
 | 
						||
    def forward(self, x):
 | 
						||
        gru_out, _ = self.gru(x)
 | 
						||
        output = self.linear1(gru_out)
 | 
						||
        output = F.relu(output)  # 添加激活函数
 | 
						||
        output = self.dropout(output)
 | 
						||
        output = self.linear2(output)
 | 
						||
        output = self.norm(output)  # 应用层归一化
 | 
						||
        output = self.dropout(output)
 | 
						||
        return output
 | 
						||
 | 
						||
class LightweightCRNN(nn.Module):
 | 
						||
    """增强版轻量化CRNN模型"""
 | 
						||
    
 | 
						||
    def __init__(self, img_height, num_classes, num_channels=3, hidden_size=160):  # 调整隐藏层大小
 | 
						||
        super(LightweightCRNN, self).__init__()
 | 
						||
        
 | 
						||
        self.img_height = img_height
 | 
						||
        self.num_classes = num_classes
 | 
						||
        self.hidden_size = hidden_size
 | 
						||
        
 | 
						||
        # 增强版轻量化CNN特征提取器
 | 
						||
        self.cnn = LightweightCNN(num_channels)
 | 
						||
        
 | 
						||
        # 增强版轻量化RNN序列建模器
 | 
						||
        self.rnn = LightweightGRU(160, hidden_size, num_layers=2)  # 使用更大的输入尺寸和2层GRU
 | 
						||
        
 | 
						||
        # 输出层 - 添加额外的全连接层
 | 
						||
        self.fc = nn.Linear(hidden_size, hidden_size // 2)
 | 
						||
        self.dropout = nn.Dropout(0.2)
 | 
						||
        self.classifier = nn.Linear(hidden_size // 2, num_classes)
 | 
						||
        
 | 
						||
        # 初始化权重
 | 
						||
        self._initialize_weights()
 | 
						||
    
 | 
						||
    def _initialize_weights(self):
 | 
						||
        """初始化模型权重"""
 | 
						||
        for m in self.modules():
 | 
						||
            if isinstance(m, nn.Conv2d):
 | 
						||
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
 | 
						||
                if m.bias is not None:
 | 
						||
                    nn.init.constant_(m.bias, 0)
 | 
						||
            elif isinstance(m, nn.BatchNorm2d):
 | 
						||
                nn.init.constant_(m.weight, 1)
 | 
						||
                nn.init.constant_(m.bias, 0)
 | 
						||
            elif isinstance(m, nn.Linear):
 | 
						||
                nn.init.normal_(m.weight, 0, 0.01)
 | 
						||
                if m.bias is not None:
 | 
						||
                    nn.init.constant_(m.bias, 0)
 | 
						||
    
 | 
						||
    def forward(self, input):
 | 
						||
        """
 | 
						||
        input: [batch_size, channels, height, width]
 | 
						||
        output: [seq_len, batch_size, num_classes]
 | 
						||
        """
 | 
						||
        # CNN特征提取
 | 
						||
        conv_features = self.cnn(input)  # [batch_size, 160, 1, seq_len]
 | 
						||
        
 | 
						||
        # 重塑为RNN输入格式
 | 
						||
        batch_size, channels, height, width = conv_features.size()
 | 
						||
        assert height == 1, f"Height should be 1, got {height}"
 | 
						||
        
 | 
						||
        # [batch_size, 160, 1, seq_len] -> [batch_size, seq_len, 160]
 | 
						||
        conv_features = conv_features.squeeze(2)  # [batch_size, 160, seq_len]
 | 
						||
        conv_features = conv_features.permute(0, 2, 1)  # [batch_size, seq_len, 160]
 | 
						||
        
 | 
						||
        # RNN序列建模
 | 
						||
        rnn_output = self.rnn(conv_features)  # [batch_size, seq_len, hidden_size]
 | 
						||
        
 | 
						||
        # 全连接层处理
 | 
						||
        fc_output = self.fc(rnn_output)  # [batch_size, seq_len, hidden_size//2]
 | 
						||
        fc_output = F.relu(fc_output)
 | 
						||
        fc_output = self.dropout(fc_output)
 | 
						||
        
 | 
						||
        # 分类
 | 
						||
        output = self.classifier(fc_output)  # [batch_size, seq_len, num_classes]
 | 
						||
        
 | 
						||
        # 转换为CTC期望的格式: [seq_len, batch_size, num_classes]
 | 
						||
        output = output.permute(1, 0, 2)
 | 
						||
        
 | 
						||
        return output
 | 
						||
 | 
						||
class LightCTCDecoder:
 | 
						||
    """轻量化CTC解码器"""
 | 
						||
    def __init__(self):
 | 
						||
        # 中国车牌字符集
 | 
						||
        # 省份简称
 | 
						||
        provinces = ['京', '津', '沪', '渝', '冀', '豫', '云', '辽', '黑', '湘', '皖', '鲁',
 | 
						||
                    '新', '苏', '浙', '赣', '鄂', '桂', '甘', '晋', '蒙', '陕', '吉', '闽',
 | 
						||
                    '贵', '粤', '青', '藏', '川', '宁', '琼']
 | 
						||
        
 | 
						||
        # 字母(包含I和O)
 | 
						||
        letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
 | 
						||
                  'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
 | 
						||
        
 | 
						||
        # 数字
 | 
						||
        digits = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 | 
						||
        
 | 
						||
        # 组合所有字符
 | 
						||
        self.character = provinces + letters + digits
 | 
						||
        
 | 
						||
        # 添加空白字符用于CTC
 | 
						||
        self.character = ['[blank]'] + self.character
 | 
						||
        
 | 
						||
        # 创建字符到索引的映射
 | 
						||
        self.dict = {char: i for i, char in enumerate(self.character)}
 | 
						||
        self.dict_reverse = {i: char for i, char in enumerate(self.character)}
 | 
						||
        
 | 
						||
        self.num_classes = len(self.character)
 | 
						||
        self.blank_idx = 0
 | 
						||
    
 | 
						||
    def decode_greedy(self, predictions):
 | 
						||
        """贪婪解码"""
 | 
						||
        # 获取每个时间步的最大概率索引
 | 
						||
        indices = torch.argmax(predictions, dim=1)
 | 
						||
        
 | 
						||
        # CTC解码:移除重复字符和空白字符
 | 
						||
        decoded_chars = []
 | 
						||
        prev_idx = -1
 | 
						||
        
 | 
						||
        for idx in indices:
 | 
						||
            idx = idx.item()
 | 
						||
            if idx != prev_idx and idx != self.blank_idx:
 | 
						||
                if idx < len(self.character):
 | 
						||
                    decoded_chars.append(self.character[idx])
 | 
						||
            prev_idx = idx
 | 
						||
        
 | 
						||
        return ''.join(decoded_chars)
 | 
						||
    
 | 
						||
    def decode_with_confidence(self, predictions):
 | 
						||
        """解码并返回置信度信息"""
 | 
						||
        # 应用softmax获得概率
 | 
						||
        probs = torch.softmax(predictions, dim=1)
 | 
						||
        
 | 
						||
        # 贪婪解码
 | 
						||
        indices = torch.argmax(probs, dim=1)
 | 
						||
        max_probs = torch.max(probs, dim=1)[0]
 | 
						||
        
 | 
						||
        # CTC解码
 | 
						||
        decoded_chars = []
 | 
						||
        char_confidences = []
 | 
						||
        prev_idx = -1
 | 
						||
        
 | 
						||
        for i, idx in enumerate(indices):
 | 
						||
            idx = idx.item()
 | 
						||
            confidence = max_probs[i].item()
 | 
						||
            
 | 
						||
            if idx != prev_idx and idx != self.blank_idx:
 | 
						||
                if idx < len(self.character):
 | 
						||
                    decoded_chars.append(self.character[idx])
 | 
						||
                    char_confidences.append(confidence)
 | 
						||
            prev_idx = idx
 | 
						||
        
 | 
						||
        text = ''.join(decoded_chars)
 | 
						||
        avg_confidence = np.mean(char_confidences) if char_confidences else 0.0
 | 
						||
        
 | 
						||
        return text, avg_confidence, char_confidences
 | 
						||
 | 
						||
class LightLicensePlatePreprocessor:
 | 
						||
    """轻量化车牌图像预处理器"""
 | 
						||
    def __init__(self, target_height=32, target_width=128):
 | 
						||
        self.target_height = target_height
 | 
						||
        self.target_width = target_width
 | 
						||
        
 | 
						||
        # 定义图像变换
 | 
						||
        self.transform = transforms.Compose([
 | 
						||
            transforms.Resize((target_height, target_width)),
 | 
						||
            transforms.ToTensor(),
 | 
						||
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
 | 
						||
                               std=[0.229, 0.224, 0.225])
 | 
						||
        ])
 | 
						||
    
 | 
						||
    def preprocess_numpy_array(self, image_array):
 | 
						||
        """预处理numpy数组格式的图像"""
 | 
						||
        try:
 | 
						||
            # 确保图像是RGB格式
 | 
						||
            if len(image_array.shape) == 3 and image_array.shape[2] == 3:
 | 
						||
                # 如果是BGR格式,转换为RGB
 | 
						||
                if image_array.dtype == np.uint8:
 | 
						||
                    image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)
 | 
						||
            
 | 
						||
            # 转换为PIL图像
 | 
						||
            if image_array.dtype != np.uint8:
 | 
						||
                image_array = (image_array * 255).astype(np.uint8)
 | 
						||
            
 | 
						||
            image = Image.fromarray(image_array)
 | 
						||
            
 | 
						||
            # 应用变换
 | 
						||
            tensor = self.transform(image)
 | 
						||
            
 | 
						||
            # 添加batch维度
 | 
						||
            tensor = tensor.unsqueeze(0)
 | 
						||
            
 | 
						||
            return tensor
 | 
						||
            
 | 
						||
        except Exception as e:
 | 
						||
            print(f"图像预处理失败: {e}")
 | 
						||
            return None
 | 
						||
 | 
						||
def LPRNinitialize_model():
 | 
						||
    """
 | 
						||
    初始化轻量化CRNN模型
 | 
						||
    
 | 
						||
    返回:
 | 
						||
        bool: 初始化是否成功
 | 
						||
    """
 | 
						||
    global lightcrnn_model, lightcrnn_decoder, lightcrnn_preprocessor, device
 | 
						||
    
 | 
						||
    try:
 | 
						||
        # 设置设备
 | 
						||
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
 | 
						||
        print(f"LightCRNN使用设备: {device}")
 | 
						||
        
 | 
						||
        # 初始化组件
 | 
						||
        lightcrnn_decoder = LightCTCDecoder()
 | 
						||
        lightcrnn_preprocessor = LightLicensePlatePreprocessor(target_height=32, target_width=128)
 | 
						||
        
 | 
						||
        # 创建模型实例
 | 
						||
        lightcrnn_model = LightweightCRNN(
 | 
						||
            img_height=32, 
 | 
						||
            num_classes=lightcrnn_decoder.num_classes, 
 | 
						||
            hidden_size=160
 | 
						||
        )
 | 
						||
        
 | 
						||
        # 加载模型权重
 | 
						||
        model_path = os.path.join(os.path.dirname(__file__), 'best_model.pth')
 | 
						||
        
 | 
						||
        if not os.path.exists(model_path):
 | 
						||
            raise FileNotFoundError(f"模型文件不存在: {model_path}")
 | 
						||
        
 | 
						||
        print(f"正在加载LightCRNN模型: {model_path}")
 | 
						||
        
 | 
						||
        # 加载检查点,处理可能的模块依赖问题
 | 
						||
        try:
 | 
						||
            checkpoint = torch.load(model_path, map_location=device, weights_only=False)
 | 
						||
        except (ModuleNotFoundError, AttributeError) as e:
 | 
						||
            if 'config' in str(e) or 'Config' in str(e):
 | 
						||
                print("检测到模型文件包含config依赖,尝试使用weights_only模式加载...")
 | 
						||
                try:
 | 
						||
                    # 尝试使用weights_only=True来避免pickle问题
 | 
						||
                    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
 | 
						||
                except Exception:
 | 
						||
                    # 如果还是失败,创建一个更完整的mock config
 | 
						||
                    import sys
 | 
						||
                    import types
 | 
						||
                    
 | 
						||
                    # 创建mock config模块
 | 
						||
                    mock_config = types.ModuleType('config')
 | 
						||
                    
 | 
						||
                    # 添加可能需要的Config类
 | 
						||
                    class Config:
 | 
						||
                        def __init__(self):
 | 
						||
                            pass
 | 
						||
                    
 | 
						||
                    mock_config.Config = Config
 | 
						||
                    sys.modules['config'] = mock_config
 | 
						||
                    
 | 
						||
                    try:
 | 
						||
                        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
 | 
						||
                    finally:
 | 
						||
                        # 清理临时模块
 | 
						||
                        if 'config' in sys.modules:
 | 
						||
                            del sys.modules['config']
 | 
						||
            else:
 | 
						||
                raise e
 | 
						||
        
 | 
						||
        # 处理不同的模型保存格式
 | 
						||
        if isinstance(checkpoint, dict):
 | 
						||
            if 'model_state_dict' in checkpoint:
 | 
						||
                # 完整检查点格式
 | 
						||
                state_dict = checkpoint['model_state_dict']
 | 
						||
                print(f"检查点信息:")
 | 
						||
                print(f"  - 训练轮次: {checkpoint.get('epoch', 'N/A')}")
 | 
						||
                print(f"  - 最佳验证损失: {checkpoint.get('best_val_loss', 'N/A')}")
 | 
						||
            else:
 | 
						||
                # 精简模型格式(只包含权重)
 | 
						||
                print("加载精简模型(仅权重)")
 | 
						||
                state_dict = checkpoint
 | 
						||
        else:
 | 
						||
            # 直接是状态字典
 | 
						||
            state_dict = checkpoint
 | 
						||
        
 | 
						||
        # 加载权重
 | 
						||
        lightcrnn_model.load_state_dict(state_dict)
 | 
						||
        lightcrnn_model.to(device)
 | 
						||
        lightcrnn_model.eval()
 | 
						||
        
 | 
						||
        print("LightCRNN模型初始化完成")
 | 
						||
        
 | 
						||
        # 统计模型参数
 | 
						||
        total_params = sum(p.numel() for p in lightcrnn_model.parameters())
 | 
						||
        print(f"LightCRNN模型参数数量: {total_params:,}")
 | 
						||
        
 | 
						||
        return True
 | 
						||
        
 | 
						||
    except Exception as e:
 | 
						||
        print(f"LightCRNN模型初始化失败: {e}")
 | 
						||
        import traceback
 | 
						||
        traceback.print_exc()
 | 
						||
        return False
 | 
						||
 | 
						||
def LPRNmodel_predict(image_array):
 | 
						||
    """
 | 
						||
    轻量化CRNN车牌号识别接口函数
 | 
						||
    
 | 
						||
    参数:
 | 
						||
        image_array: numpy数组格式的车牌图像,已经过矫正处理
 | 
						||
    
 | 
						||
    返回:
 | 
						||
        list: 包含最多8个字符的列表,代表车牌号的每个字符
 | 
						||
              例如: ['京', 'A', '1', '2', '3', '4', '5', ''] (蓝牌7位+占位符)
 | 
						||
                   ['京', 'A', 'D', '1', '2', '3', '4', '5'] (绿牌8位)
 | 
						||
    """
 | 
						||
    global lightcrnn_model, lightcrnn_decoder, lightcrnn_preprocessor, device
 | 
						||
    
 | 
						||
    if lightcrnn_model is None or lightcrnn_decoder is None or lightcrnn_preprocessor is None:
 | 
						||
        print("LightCRNN模型未初始化,请先调用LPRNinitialize_model()")
 | 
						||
        return ['待', '识', '别', '0', '0', '0', '0', '0']
 | 
						||
    
 | 
						||
    try:
 | 
						||
        # 预处理图像
 | 
						||
        input_tensor = lightcrnn_preprocessor.preprocess_numpy_array(image_array)
 | 
						||
        if input_tensor is None:
 | 
						||
            raise ValueError("图像预处理失败")
 | 
						||
        
 | 
						||
        input_tensor = input_tensor.to(device)
 | 
						||
        
 | 
						||
        # 模型推理
 | 
						||
        with torch.no_grad():
 | 
						||
            outputs = lightcrnn_model(input_tensor)  # (seq_len, batch_size, num_classes)
 | 
						||
            
 | 
						||
            # 移除batch维度
 | 
						||
            outputs = outputs.squeeze(1)  # (seq_len, num_classes)
 | 
						||
            
 | 
						||
            # CTC解码
 | 
						||
            predicted_text, confidence, char_confidences = lightcrnn_decoder.decode_with_confidence(outputs)
 | 
						||
            
 | 
						||
            print(f"LightCRNN识别结果: {predicted_text}, 置信度: {confidence:.3f}")
 | 
						||
            
 | 
						||
            # 将字符串转换为字符列表
 | 
						||
            char_list = list(predicted_text)
 | 
						||
            
 | 
						||
            # 确保返回至少7个字符,最多8个字符
 | 
						||
            if len(char_list) < 7:
 | 
						||
                # 如果识别结果少于7个字符,用'0'补齐到7位
 | 
						||
                char_list.extend(['0'] * (7 - len(char_list)))
 | 
						||
            elif len(char_list) > 8:
 | 
						||
                # 如果识别结果多于8个字符,截取前8个
 | 
						||
                char_list = char_list[:8]
 | 
						||
            
 | 
						||
            # 如果是7位,补齐到8位以保持接口一致性(第8位用空字符或占位符)
 | 
						||
            if len(char_list) == 7:
 | 
						||
                char_list.append('')  # 添加空字符作为第8位占位符
 | 
						||
            
 | 
						||
            return char_list
 | 
						||
            
 | 
						||
    except Exception as e:
 | 
						||
        print(f"LightCRNN识别失败: {e}")
 | 
						||
        import traceback
 | 
						||
        traceback.print_exc()
 | 
						||
        return ['识', '别', '失', '败', '0', '0', '0', '0']
 | 
						||
 | 
						||
def create_lightweight_model(model_type='lightweight_crnn', img_height=32, num_classes=66, hidden_size=160):
 | 
						||
    """创建增强版轻量化模型"""
 | 
						||
    if model_type == 'lightweight_crnn':
 | 
						||
        return LightweightCRNN(img_height, num_classes, hidden_size=hidden_size)
 | 
						||
    else:
 | 
						||
        raise ValueError(f"Unknown lightweight model type: {model_type}")
 | 
						||
 | 
						||
if __name__ == "__main__":
 | 
						||
    # 测试轻量化模型
 | 
						||
    print("测试LightCRNN模型...")
 | 
						||
    
 | 
						||
    # 初始化模型
 | 
						||
    success = LPRNinitialize_model()
 | 
						||
    if success:
 | 
						||
        print("模型初始化成功")
 | 
						||
        
 | 
						||
        # 创建测试输入
 | 
						||
        test_input = np.random.randint(0, 255, (32, 128, 3), dtype=np.uint8)
 | 
						||
        
 | 
						||
        # 测试预测
 | 
						||
        result = LPRNmodel_predict(test_input)
 | 
						||
        print(f"测试预测结果: {result}")
 | 
						||
    else:
 | 
						||
        print("模型初始化失败") |