Compare commits
	
		
			13 Commits
		
	
	
		
			8ace9df86a
			...
			main
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 428b577808 | |||
| 15a83a5f06 | |||
| 418f7f3bc9 | |||
| a99e8fccb2 | |||
| 40f5e1c1be | |||
| c1fbccd7ee | |||
| d649738f6c | |||
| 6831a8cd01 | |||
| cf60d96066 | |||
| 09c3117f12 | |||
| 2a77e6ca8a | |||
| 56e7347c01 | |||
| 1c8e15bcd8 | 
							
								
								
									
										8
									
								
								.idea/.gitignore
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.idea/.gitignore
									
									
									
										generated
									
									
										vendored
									
									
								
							@@ -1,8 +0,0 @@
 | 
			
		||||
# 默认忽略的文件
 | 
			
		||||
/shelf/
 | 
			
		||||
/workspace.xml
 | 
			
		||||
# 基于编辑器的 HTTP 客户端请求
 | 
			
		||||
/httpRequests/
 | 
			
		||||
# Datasource local storage ignored files
 | 
			
		||||
/dataSources/
 | 
			
		||||
/dataSources.local.xml
 | 
			
		||||
							
								
								
									
										12
									
								
								.idea/License_plate_recognition.iml
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										12
									
								
								.idea/License_plate_recognition.iml
									
									
									
										generated
									
									
									
								
							@@ -1,12 +0,0 @@
 | 
			
		||||
<?xml version="1.0" encoding="UTF-8"?>
 | 
			
		||||
<module type="PYTHON_MODULE" version="4">
 | 
			
		||||
  <component name="NewModuleRootManager">
 | 
			
		||||
    <content url="file://$MODULE_DIR$" />
 | 
			
		||||
    <orderEntry type="jdk" jdkName="cnm" jdkType="Python SDK" />
 | 
			
		||||
    <orderEntry type="sourceFolder" forTests="false" />
 | 
			
		||||
  </component>
 | 
			
		||||
  <component name="PyDocumentationSettings">
 | 
			
		||||
    <option name="format" value="PLAIN" />
 | 
			
		||||
    <option name="myDocStringFormat" value="Plain" />
 | 
			
		||||
  </component>
 | 
			
		||||
</module>
 | 
			
		||||
							
								
								
									
										6
									
								
								.idea/inspectionProfiles/profiles_settings.xml
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										6
									
								
								.idea/inspectionProfiles/profiles_settings.xml
									
									
									
										generated
									
									
									
								
							@@ -1,6 +0,0 @@
 | 
			
		||||
<component name="InspectionProjectProfileManager">
 | 
			
		||||
  <settings>
 | 
			
		||||
    <option name="USE_PROJECT_PROFILE" value="false" />
 | 
			
		||||
    <version value="1.0" />
 | 
			
		||||
  </settings>
 | 
			
		||||
</component>
 | 
			
		||||
							
								
								
									
										7
									
								
								.idea/misc.xml
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										7
									
								
								.idea/misc.xml
									
									
									
										generated
									
									
									
								
							@@ -1,7 +0,0 @@
 | 
			
		||||
<?xml version="1.0" encoding="UTF-8"?>
 | 
			
		||||
<project version="4">
 | 
			
		||||
  <component name="Black">
 | 
			
		||||
    <option name="sdkName" value="pytorh" />
 | 
			
		||||
  </component>
 | 
			
		||||
  <component name="ProjectRootManager" version="2" project-jdk-name="cnm" project-jdk-type="Python SDK" />
 | 
			
		||||
</project>
 | 
			
		||||
							
								
								
									
										8
									
								
								.idea/modules.xml
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										8
									
								
								.idea/modules.xml
									
									
									
										generated
									
									
									
								
							@@ -1,8 +0,0 @@
 | 
			
		||||
<?xml version="1.0" encoding="UTF-8"?>
 | 
			
		||||
<project version="4">
 | 
			
		||||
  <component name="ProjectModuleManager">
 | 
			
		||||
    <modules>
 | 
			
		||||
      <module fileurl="file://$PROJECT_DIR$/.idea/License_plate_recognition.iml" filepath="$PROJECT_DIR$/.idea/License_plate_recognition.iml" />
 | 
			
		||||
    </modules>
 | 
			
		||||
  </component>
 | 
			
		||||
</project>
 | 
			
		||||
							
								
								
									
										7
									
								
								.idea/vcs.xml
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										7
									
								
								.idea/vcs.xml
									
									
									
										generated
									
									
									
								
							@@ -1,7 +0,0 @@
 | 
			
		||||
<?xml version="1.0" encoding="UTF-8"?>
 | 
			
		||||
<project version="4">
 | 
			
		||||
  <component name="VcsDirectoryMappings">
 | 
			
		||||
    <mapping directory="$PROJECT_DIR$/.." vcs="Git" />
 | 
			
		||||
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
 | 
			
		||||
  </component>
 | 
			
		||||
</project>
 | 
			
		||||
										
											Binary file not shown.
										
									
								
							@@ -1,328 +0,0 @@
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import cv2
 | 
			
		||||
import numpy as np
 | 
			
		||||
import os
 | 
			
		||||
import sys
 | 
			
		||||
from torch.autograd import Variable
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
# 添加父目录到路径,以便导入模型和数据加载器
 | 
			
		||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 | 
			
		||||
 | 
			
		||||
# LPRNet字符集定义(与训练时保持一致)
 | 
			
		||||
CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
 | 
			
		||||
         '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
 | 
			
		||||
         '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁', '新',
 | 
			
		||||
         '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
 | 
			
		||||
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
 | 
			
		||||
         'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
 | 
			
		||||
         'W', 'X', 'Y', 'Z', 'I', 'O', '-']
 | 
			
		||||
 | 
			
		||||
CHARS_DICT = {char: i for i, char in enumerate(CHARS)}
 | 
			
		||||
 | 
			
		||||
# 简化的LPRNet模型定义
 | 
			
		||||
class small_basic_block(nn.Module):
 | 
			
		||||
    def __init__(self, ch_in, ch_out):
 | 
			
		||||
        super(small_basic_block, self).__init__()
 | 
			
		||||
        self.block = nn.Sequential(
 | 
			
		||||
            nn.Conv2d(ch_in, ch_out // 4, kernel_size=1),
 | 
			
		||||
            nn.ReLU(),
 | 
			
		||||
            nn.Conv2d(ch_out // 4, ch_out // 4, kernel_size=(3, 1), padding=(1, 0)),
 | 
			
		||||
            nn.ReLU(),
 | 
			
		||||
            nn.Conv2d(ch_out // 4, ch_out // 4, kernel_size=(1, 3), padding=(0, 1)),
 | 
			
		||||
            nn.ReLU(),
 | 
			
		||||
            nn.Conv2d(ch_out // 4, ch_out, kernel_size=1),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        return self.block(x)
 | 
			
		||||
 | 
			
		||||
class LPRNet(nn.Module):
 | 
			
		||||
    def __init__(self, lpr_max_len, phase, class_num, dropout_rate):
 | 
			
		||||
        super(LPRNet, self).__init__()
 | 
			
		||||
        self.phase = phase
 | 
			
		||||
        self.lpr_max_len = lpr_max_len
 | 
			
		||||
        self.class_num = class_num
 | 
			
		||||
        self.backbone = nn.Sequential(
 | 
			
		||||
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1), # 0
 | 
			
		||||
            nn.BatchNorm2d(num_features=64),
 | 
			
		||||
            nn.ReLU(),  # 2
 | 
			
		||||
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 1, 1)),
 | 
			
		||||
            small_basic_block(ch_in=64, ch_out=128),    # *** 4 ***
 | 
			
		||||
            nn.BatchNorm2d(num_features=128),
 | 
			
		||||
            nn.ReLU(),  # 6
 | 
			
		||||
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(2, 1, 2)),
 | 
			
		||||
            small_basic_block(ch_in=64, ch_out=256),   # 8
 | 
			
		||||
            nn.BatchNorm2d(num_features=256),
 | 
			
		||||
            nn.ReLU(),  # 10
 | 
			
		||||
            small_basic_block(ch_in=256, ch_out=256),   # *** 11 ***
 | 
			
		||||
            nn.BatchNorm2d(num_features=256),
 | 
			
		||||
            nn.ReLU(),  # 13
 | 
			
		||||
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 1, 2)),  # 14
 | 
			
		||||
            nn.Dropout(dropout_rate),
 | 
			
		||||
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=(1, 4), stride=1), # 16
 | 
			
		||||
            nn.BatchNorm2d(num_features=256),
 | 
			
		||||
            nn.ReLU(),  # 18
 | 
			
		||||
            nn.Dropout(dropout_rate),
 | 
			
		||||
            nn.Conv2d(in_channels=256, out_channels=class_num, kernel_size=(13, 1), stride=1), # 20
 | 
			
		||||
            nn.BatchNorm2d(num_features=class_num),
 | 
			
		||||
            nn.ReLU(),  # 22
 | 
			
		||||
        )
 | 
			
		||||
        self.container = nn.Sequential(
 | 
			
		||||
            nn.Conv2d(in_channels=448+self.class_num, out_channels=self.class_num, kernel_size=(1,1), stride=(1,1)),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        keep_features = list()
 | 
			
		||||
        for i, layer in enumerate(self.backbone.children()):
 | 
			
		||||
            x = layer(x)
 | 
			
		||||
            if i in [2, 6, 13, 22]: # [2, 4, 8, 11, 22]
 | 
			
		||||
                keep_features.append(x)
 | 
			
		||||
 | 
			
		||||
        global_context = list()
 | 
			
		||||
        for i, f in enumerate(keep_features):
 | 
			
		||||
            if i in [0, 1]:
 | 
			
		||||
                f = nn.AvgPool2d(kernel_size=5, stride=5)(f)
 | 
			
		||||
            if i in [2]:
 | 
			
		||||
                f = nn.AvgPool2d(kernel_size=(4, 10), stride=(4, 2))(f)
 | 
			
		||||
            f_pow = torch.pow(f, 2)
 | 
			
		||||
            f_mean = torch.mean(f_pow)
 | 
			
		||||
            f = torch.div(f, f_mean)
 | 
			
		||||
            global_context.append(f)
 | 
			
		||||
 | 
			
		||||
        x = torch.cat(global_context, 1)
 | 
			
		||||
        x = self.container(x)
 | 
			
		||||
        logits = torch.mean(x, dim=2)
 | 
			
		||||
 | 
			
		||||
        return logits
 | 
			
		||||
 | 
			
		||||
class LPRNetInference:
 | 
			
		||||
    def __init__(self, model_path=None, img_size=[94, 24], lpr_max_len=8, dropout_rate=0.5):
 | 
			
		||||
        """
 | 
			
		||||
        初始化LPRNet推理类
 | 
			
		||||
        Args:
 | 
			
		||||
            model_path: 训练好的模型权重文件路径
 | 
			
		||||
            img_size: 输入图像尺寸 [width, height]
 | 
			
		||||
            lpr_max_len: 车牌最大长度
 | 
			
		||||
            dropout_rate: dropout率
 | 
			
		||||
        """
 | 
			
		||||
        self.img_size = img_size
 | 
			
		||||
        self.lpr_max_len = lpr_max_len
 | 
			
		||||
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 | 
			
		||||
        
 | 
			
		||||
        # 设置默认模型路径
 | 
			
		||||
        if model_path is None:
 | 
			
		||||
            current_dir = os.path.dirname(os.path.abspath(__file__))
 | 
			
		||||
            model_path = os.path.join(current_dir, 'LPRNet__iteration_74000.pth')
 | 
			
		||||
        
 | 
			
		||||
        # 初始化模型
 | 
			
		||||
        self.model = LPRNet(lpr_max_len=lpr_max_len, phase=False, class_num=len(CHARS), dropout_rate=dropout_rate)
 | 
			
		||||
        
 | 
			
		||||
        # 加载模型权重
 | 
			
		||||
        if model_path and os.path.exists(model_path):
 | 
			
		||||
            print(f"Loading LPRNet model from {model_path}")
 | 
			
		||||
            try:
 | 
			
		||||
                self.model.load_state_dict(torch.load(model_path, map_location=self.device))
 | 
			
		||||
                print("LPRNet模型权重加载成功")
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                print(f"Warning: 加载模型权重失败: {e}. 使用随机权重.")
 | 
			
		||||
        else:
 | 
			
		||||
            print(f"Warning: 模型文件不存在或未指定: {model_path}. 使用随机权重.")
 | 
			
		||||
        
 | 
			
		||||
        self.model.to(self.device)
 | 
			
		||||
        self.model.eval()
 | 
			
		||||
        
 | 
			
		||||
        print(f"LPRNet模型加载完成,设备: {self.device}")
 | 
			
		||||
        print(f"模型参数数量: {sum(p.numel() for p in self.model.parameters()):,}")
 | 
			
		||||
    
 | 
			
		||||
    def preprocess_image(self, image_array):
 | 
			
		||||
        """
 | 
			
		||||
        预处理图像数组 - 使用与训练时相同的预处理方式
 | 
			
		||||
        Args:
 | 
			
		||||
            image_array: numpy数组格式的图像 (H, W, C)
 | 
			
		||||
        Returns:
 | 
			
		||||
            preprocessed_image: 预处理后的图像tensor
 | 
			
		||||
        """
 | 
			
		||||
        if image_array is None:
 | 
			
		||||
            raise ValueError("Input image is None")
 | 
			
		||||
        
 | 
			
		||||
        # 确保图像是numpy数组
 | 
			
		||||
        if not isinstance(image_array, np.ndarray):
 | 
			
		||||
            raise ValueError("Input must be numpy array")
 | 
			
		||||
        
 | 
			
		||||
        # 检查图像维度
 | 
			
		||||
        if len(image_array.shape) != 3:
 | 
			
		||||
            raise ValueError(f"Expected 3D image array, got {len(image_array.shape)}D")
 | 
			
		||||
        
 | 
			
		||||
        height, width, channels = image_array.shape
 | 
			
		||||
        if channels != 3:
 | 
			
		||||
            raise ValueError(f"Expected 3 channels, got {channels}")
 | 
			
		||||
        
 | 
			
		||||
        # 调整图像尺寸到模型要求的尺寸
 | 
			
		||||
        if height != self.img_size[1] or width != self.img_size[0]:
 | 
			
		||||
            image_array = cv2.resize(image_array, tuple(self.img_size))
 | 
			
		||||
        
 | 
			
		||||
        # 使用与训练时相同的预处理方式
 | 
			
		||||
        image_array = image_array.astype('float32')
 | 
			
		||||
        image_array -= 127.5
 | 
			
		||||
        image_array *= 0.0078125
 | 
			
		||||
        image_array = np.transpose(image_array, (2, 0, 1))  # HWC -> CHW
 | 
			
		||||
        
 | 
			
		||||
        # 转换为tensor并添加batch维度
 | 
			
		||||
        image_tensor = torch.from_numpy(image_array).unsqueeze(0)
 | 
			
		||||
        
 | 
			
		||||
        return image_tensor
 | 
			
		||||
    
 | 
			
		||||
    def decode_prediction(self, logits):
 | 
			
		||||
        """
 | 
			
		||||
        解码模型预测结果 - 使用正确的CTC贪婪解码
 | 
			
		||||
        Args:
 | 
			
		||||
            logits: 模型输出的logits [batch_size, num_classes, sequence_length]
 | 
			
		||||
        Returns:
 | 
			
		||||
            predicted_text: 预测的车牌号码
 | 
			
		||||
        """
 | 
			
		||||
        # 转换为numpy进行处理
 | 
			
		||||
        prebs = logits.cpu().detach().numpy()
 | 
			
		||||
        preb = prebs[0, :, :]  # 取第一个batch [num_classes, sequence_length]
 | 
			
		||||
        
 | 
			
		||||
        # 贪婪解码:对每个时间步选择最大概率的字符
 | 
			
		||||
        preb_label = []
 | 
			
		||||
        for j in range(preb.shape[1]):  # 遍历每个时间步
 | 
			
		||||
            preb_label.append(np.argmax(preb[:, j], axis=0))
 | 
			
		||||
        
 | 
			
		||||
        # CTC解码:去除重复字符和空白字符
 | 
			
		||||
        no_repeat_blank_label = []
 | 
			
		||||
        pre_c = preb_label[0]
 | 
			
		||||
        
 | 
			
		||||
        # 处理第一个字符
 | 
			
		||||
        if pre_c != len(CHARS) - 1:  # 不是空白字符
 | 
			
		||||
            no_repeat_blank_label.append(pre_c)
 | 
			
		||||
        
 | 
			
		||||
        # 处理后续字符
 | 
			
		||||
        for c in preb_label:
 | 
			
		||||
            if (pre_c == c) or (c == len(CHARS) - 1):  # 重复字符或空白字符
 | 
			
		||||
                if c == len(CHARS) - 1:
 | 
			
		||||
                    pre_c = c
 | 
			
		||||
                continue
 | 
			
		||||
            no_repeat_blank_label.append(c)
 | 
			
		||||
            pre_c = c
 | 
			
		||||
        
 | 
			
		||||
        # 转换为字符
 | 
			
		||||
        decoded_chars = [CHARS[idx] for idx in no_repeat_blank_label]
 | 
			
		||||
        return ''.join(decoded_chars)
 | 
			
		||||
    
 | 
			
		||||
    def predict(self, image_array):
 | 
			
		||||
        """
 | 
			
		||||
        预测单张图像的车牌号码
 | 
			
		||||
        Args:
 | 
			
		||||
            image_array: numpy数组格式的图像
 | 
			
		||||
        Returns:
 | 
			
		||||
            prediction: 预测的车牌号码
 | 
			
		||||
            confidence: 预测置信度
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            # 预处理图像
 | 
			
		||||
            image = self.preprocess_image(image_array)
 | 
			
		||||
            if image is None:
 | 
			
		||||
                return None, 0.0
 | 
			
		||||
            
 | 
			
		||||
            image = image.to(self.device)
 | 
			
		||||
            
 | 
			
		||||
            # 模型推理
 | 
			
		||||
            with torch.no_grad():
 | 
			
		||||
                logits = self.model(image)
 | 
			
		||||
                # logits shape: [batch_size, class_num, sequence_length]
 | 
			
		||||
                
 | 
			
		||||
                # 计算置信度(使用softmax后的最大概率平均值)
 | 
			
		||||
                probs = torch.softmax(logits, dim=1)
 | 
			
		||||
                max_probs = torch.max(probs, dim=1)[0]
 | 
			
		||||
                confidence = torch.mean(max_probs).item()
 | 
			
		||||
                
 | 
			
		||||
                # 解码预测结果
 | 
			
		||||
                prediction = self.decode_prediction(logits)
 | 
			
		||||
            
 | 
			
		||||
            return prediction, confidence
 | 
			
		||||
            
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            print(f"预测图像失败: {e}")
 | 
			
		||||
            return None, 0.0
 | 
			
		||||
 | 
			
		||||
# 全局变量
 | 
			
		||||
lpr_model = None
 | 
			
		||||
 | 
			
		||||
def LPRNinitialize_model():
 | 
			
		||||
    """
 | 
			
		||||
    初始化LPRNet模型
 | 
			
		||||
    
 | 
			
		||||
    返回:
 | 
			
		||||
        bool: 初始化是否成功
 | 
			
		||||
    """
 | 
			
		||||
    global lpr_model
 | 
			
		||||
    
 | 
			
		||||
    try:
 | 
			
		||||
        # 模型权重文件路径
 | 
			
		||||
        model_path = os.path.join(os.path.dirname(__file__), 'LPRNet__iteration_74000.pth')
 | 
			
		||||
        
 | 
			
		||||
        # 创建推理对象
 | 
			
		||||
        lpr_model = LPRNetInference(model_path)
 | 
			
		||||
        
 | 
			
		||||
        print("LPRNet模型初始化完成")
 | 
			
		||||
        return True
 | 
			
		||||
        
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        print(f"LPRNet模型初始化失败: {e}")
 | 
			
		||||
        import traceback
 | 
			
		||||
        traceback.print_exc()
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
def LPRNmodel_predict(image_array):
 | 
			
		||||
    """
 | 
			
		||||
    LPRNet车牌号识别接口函数
 | 
			
		||||
    
 | 
			
		||||
    参数:
 | 
			
		||||
        image_array: numpy数组格式的车牌图像,已经过矫正处理
 | 
			
		||||
    
 | 
			
		||||
    返回:
 | 
			
		||||
        list: 包含最多8个字符的列表,代表车牌号的每个字符
 | 
			
		||||
              例如: ['京', 'A', '1', '2', '3', '4', '5'] (蓝牌7位)
 | 
			
		||||
                   ['京', 'A', 'D', '1', '2', '3', '4', '5'] (绿牌8位)
 | 
			
		||||
    """
 | 
			
		||||
    global lpr_model
 | 
			
		||||
    
 | 
			
		||||
    if lpr_model is None:
 | 
			
		||||
        print("LPRNet模型未初始化,请先调用LPRNinitialize_model()")
 | 
			
		||||
        return ['待', '识', '别', '0', '0', '0', '0', '0']
 | 
			
		||||
    
 | 
			
		||||
    try:
 | 
			
		||||
        # 预测车牌号
 | 
			
		||||
        predicted_text, confidence = lpr_model.predict(image_array)
 | 
			
		||||
        
 | 
			
		||||
        if predicted_text is None:
 | 
			
		||||
            print("LPRNet识别失败")
 | 
			
		||||
            return ['识', '别', '失', '败', '0', '0', '0', '0']
 | 
			
		||||
        
 | 
			
		||||
        print(f"LPRNet识别结果: {predicted_text}, 置信度: {confidence:.3f}")
 | 
			
		||||
        
 | 
			
		||||
        # 将字符串转换为字符列表
 | 
			
		||||
        char_list = list(predicted_text)
 | 
			
		||||
        
 | 
			
		||||
        # 确保返回至少7个字符,最多8个字符
 | 
			
		||||
        if len(char_list) < 7:
 | 
			
		||||
            # 如果识别结果少于7个字符,用'0'补齐到7位
 | 
			
		||||
            char_list.extend(['0'] * (7 - len(char_list)))
 | 
			
		||||
        elif len(char_list) > 8:
 | 
			
		||||
            # 如果识别结果多于8个字符,截取前8个
 | 
			
		||||
            char_list = char_list[:8]
 | 
			
		||||
        
 | 
			
		||||
        # 如果是7位,补齐到8位以保持接口一致性(第8位用空字符或占位符)
 | 
			
		||||
        if len(char_list) == 7:
 | 
			
		||||
            char_list.append('')  # 添加空字符作为第8位占位符
 | 
			
		||||
        
 | 
			
		||||
        return char_list
 | 
			
		||||
        
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        print(f"LPRNet识别失败: {e}")
 | 
			
		||||
        import traceback
 | 
			
		||||
        traceback.print_exc()
 | 
			
		||||
        return ['识', '别', '失', '败', '0', '0', '0', '0']
 | 
			
		||||
@@ -5,6 +5,18 @@ import cv2
 | 
			
		||||
class OCRProcessor:
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.model = TextRecognition(model_name="PP-OCRv5_server_rec")
 | 
			
		||||
        # 定义允许的字符集合(不包含空白字符)
 | 
			
		||||
        self.allowed_chars = [
 | 
			
		||||
            # 中文省份简称
 | 
			
		||||
            '京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
 | 
			
		||||
            '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
 | 
			
		||||
            '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁', '新',
 | 
			
		||||
            # 字母 A-Z
 | 
			
		||||
            'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 
 | 
			
		||||
            'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
 | 
			
		||||
            # 数字 0-9
 | 
			
		||||
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'
 | 
			
		||||
        ]
 | 
			
		||||
        print("OCR模型初始化完成(占位)")
 | 
			
		||||
 | 
			
		||||
    def predict(self, image_array):
 | 
			
		||||
@@ -15,6 +27,14 @@ class OCRProcessor:
 | 
			
		||||
        placeholder_result = results.split(',')
 | 
			
		||||
        return placeholder_result
 | 
			
		||||
    
 | 
			
		||||
    def filter_allowed_chars(self, text):
 | 
			
		||||
        """只保留允许的字符"""
 | 
			
		||||
        filtered_text = ""
 | 
			
		||||
        for char in text:
 | 
			
		||||
            if char in self.allowed_chars:
 | 
			
		||||
                filtered_text += char
 | 
			
		||||
        return filtered_text
 | 
			
		||||
 | 
			
		||||
# 保留原有函数接口
 | 
			
		||||
_processor = OCRProcessor()
 | 
			
		||||
 | 
			
		||||
@@ -42,8 +62,12 @@ def LPRNmodel_predict(image_array):
 | 
			
		||||
    else:
 | 
			
		||||
        result_str = str(raw_result)
 | 
			
		||||
    
 | 
			
		||||
    # 过滤掉'·'字符
 | 
			
		||||
    # 过滤掉'·'和'-'字符
 | 
			
		||||
    filtered_str = result_str.replace('·', '')
 | 
			
		||||
    filtered_str = filtered_str.replace('-', '')
 | 
			
		||||
    
 | 
			
		||||
    # 只保留允许的字符
 | 
			
		||||
    filtered_str = _processor.filter_allowed_chars(filtered_str)
 | 
			
		||||
    
 | 
			
		||||
    # 转换为字符列表
 | 
			
		||||
    char_list = list(filtered_str)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										69
									
								
								communicate.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								communicate.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,69 @@
 | 
			
		||||
#!/usr/bin/env python3
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
"""
 | 
			
		||||
向Hi3861设备发送JSON命令
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import socket
 | 
			
		||||
import json
 | 
			
		||||
import time
 | 
			
		||||
import pyttsx3
 | 
			
		||||
import threading
 | 
			
		||||
 | 
			
		||||
target_ip = "192.168.43.12"
 | 
			
		||||
target_port = 8081
 | 
			
		||||
 | 
			
		||||
def speak_text(text):
 | 
			
		||||
    """
 | 
			
		||||
    使用文本转语音播放文本
 | 
			
		||||
    每次调用都创建新的引擎实例以避免并发问题
 | 
			
		||||
    """
 | 
			
		||||
    def _speak():
 | 
			
		||||
        try:
 | 
			
		||||
            if text and text.strip():  # 确保文本不为空
 | 
			
		||||
                # 在线程内部创建新的引擎实例
 | 
			
		||||
                engine = pyttsx3.init()
 | 
			
		||||
                # 设置语音速度
 | 
			
		||||
                engine.setProperty('rate', 150)
 | 
			
		||||
                # 设置音量(0.0到1.0)
 | 
			
		||||
                engine.setProperty('volume', 0.8)
 | 
			
		||||
                
 | 
			
		||||
                engine.say(text)
 | 
			
		||||
                engine.runAndWait()
 | 
			
		||||
                
 | 
			
		||||
                # 清理引擎
 | 
			
		||||
                engine.stop()
 | 
			
		||||
                del engine
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            print(f"语音播放失败: {e}")
 | 
			
		||||
    
 | 
			
		||||
    # 在新线程中播放语音,避免阻塞
 | 
			
		||||
    speech_thread = threading.Thread(target=_speak)
 | 
			
		||||
    speech_thread.daemon = True
 | 
			
		||||
    speech_thread.start()
 | 
			
		||||
 | 
			
		||||
def send_command(cmd, text):
 | 
			
		||||
    #cmd为1,道闸打开十秒后关闭,oled显示字符串信息(默认使用及cmd为4)
 | 
			
		||||
    #cmd为2,道闸舵机向打开方向旋转90度,oled上不显示(仅在qt界面手动开闸时调用)
 | 
			
		||||
    #cmd为3,道闸舵机向关闭方向旋转90度,oled上不显示(仅在qt界面手动关闸时调用)
 | 
			
		||||
    #cmd为4,oled显示字符串信息,道闸舵机不旋转
 | 
			
		||||
 | 
			
		||||
    command = {
 | 
			
		||||
        "cmd": cmd,
 | 
			
		||||
        "text": text
 | 
			
		||||
    }
 | 
			
		||||
    
 | 
			
		||||
    json_command = json.dumps(command, ensure_ascii=False)
 | 
			
		||||
    try:
 | 
			
		||||
        # 创建UDP socket
 | 
			
		||||
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 | 
			
		||||
        sock.sendto(json_command.encode('utf-8'), (target_ip, target_port))
 | 
			
		||||
        
 | 
			
		||||
        # 发送命令后播放语音
 | 
			
		||||
        if text and text.strip():
 | 
			
		||||
            speak_text(text)
 | 
			
		||||
            
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        print(f"发送命令失败: {e}")
 | 
			
		||||
    finally:
 | 
			
		||||
        sock.close()
 | 
			
		||||
							
								
								
									
										251
									
								
								gate_control.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										251
									
								
								gate_control.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,251 @@
 | 
			
		||||
#!/usr/bin/env python3
 | 
			
		||||
# -*- coding: utf-8 -*-
 | 
			
		||||
"""
 | 
			
		||||
道闸控制模块
 | 
			
		||||
负责与Hi3861设备通信,控制道闸开关
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import socket
 | 
			
		||||
import json
 | 
			
		||||
import time
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
from PyQt5.QtCore import QObject, pyqtSignal, QThread
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GateControlThread(QThread):
 | 
			
		||||
    """道闸控制线程,用于异步发送命令"""
 | 
			
		||||
    command_sent = pyqtSignal(str, bool)  # 信号:命令内容,是否成功
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self, ip, port, command):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.ip = ip
 | 
			
		||||
        self.port = port
 | 
			
		||||
        self.command = command
 | 
			
		||||
    
 | 
			
		||||
    def run(self):
 | 
			
		||||
        """发送命令到Hi3861设备"""
 | 
			
		||||
        try:
 | 
			
		||||
            # 创建UDP socket
 | 
			
		||||
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 | 
			
		||||
            
 | 
			
		||||
            # 发送命令
 | 
			
		||||
            json_command = json.dumps(self.command, ensure_ascii=False)
 | 
			
		||||
            sock.sendto(json_command.encode('utf-8'), (self.ip, self.port))
 | 
			
		||||
            
 | 
			
		||||
            # 发出成功信号
 | 
			
		||||
            self.command_sent.emit(json_command, True)
 | 
			
		||||
            
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            # 发出失败信号
 | 
			
		||||
            self.command_sent.emit(f"发送失败: {e}", False)
 | 
			
		||||
        finally:
 | 
			
		||||
            sock.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GateController(QObject):
 | 
			
		||||
    """道闸控制器"""
 | 
			
		||||
    
 | 
			
		||||
    # 信号
 | 
			
		||||
    log_message = pyqtSignal(str)  # 日志消息
 | 
			
		||||
    gate_opened = pyqtSignal(str)  # 道闸打开信号,附带车牌号
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self, ip="192.168.43.12", port=8081):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.ip = ip
 | 
			
		||||
        self.port = port
 | 
			
		||||
        self.last_pass_times = {}  # 记录车牌上次通过时间
 | 
			
		||||
        self.thread_pool = []  # 线程池
 | 
			
		||||
        
 | 
			
		||||
    def send_command(self, cmd, text=""):
 | 
			
		||||
        """
 | 
			
		||||
        发送命令到道闸
 | 
			
		||||
        
 | 
			
		||||
        参数:
 | 
			
		||||
            cmd: 命令类型 (1-4)
 | 
			
		||||
            text: 显示文本
 | 
			
		||||
            
 | 
			
		||||
        返回:
 | 
			
		||||
            bool: 是否发送成功
 | 
			
		||||
        """
 | 
			
		||||
        # 创建JSON命令
 | 
			
		||||
        command = {
 | 
			
		||||
            "cmd": cmd,
 | 
			
		||||
            "text": text
 | 
			
		||||
        }
 | 
			
		||||
        
 | 
			
		||||
        # 创建并启动线程发送命令
 | 
			
		||||
        thread = GateControlThread(self.ip, self.port, command)
 | 
			
		||||
        thread.command_sent.connect(self.on_command_sent)
 | 
			
		||||
        thread.start()
 | 
			
		||||
        self.thread_pool.append(thread)
 | 
			
		||||
        
 | 
			
		||||
        # 记录日志
 | 
			
		||||
        cmd_desc = {
 | 
			
		||||
            1: "自动开闸(10秒后关闭)",
 | 
			
		||||
            2: "手动开闸",
 | 
			
		||||
            3: "手动关闸",
 | 
			
		||||
            4: "仅显示信息"
 | 
			
		||||
        }
 | 
			
		||||
        self.log_message.emit(f"发送命令: {cmd_desc.get(cmd, '未知命令')} - {text}")
 | 
			
		||||
        
 | 
			
		||||
        return True
 | 
			
		||||
    
 | 
			
		||||
    def on_command_sent(self, message, success):
 | 
			
		||||
        """命令发送结果处理"""
 | 
			
		||||
        if success:
 | 
			
		||||
            self.log_message.emit(f"命令发送成功: {message}")
 | 
			
		||||
        else:
 | 
			
		||||
            self.log_message.emit(f"命令发送失败: {message}")
 | 
			
		||||
    
 | 
			
		||||
    def auto_open_gate(self, plate_number):
 | 
			
		||||
        """
 | 
			
		||||
        自动开闸(检测到白名单车牌时调用)
 | 
			
		||||
        
 | 
			
		||||
        参数:
 | 
			
		||||
            plate_number: 车牌号
 | 
			
		||||
        """
 | 
			
		||||
        # 获取当前时间
 | 
			
		||||
        current_time = datetime.now()
 | 
			
		||||
        time_diff_str = ""
 | 
			
		||||
        
 | 
			
		||||
        # 检查是否是第一次通行
 | 
			
		||||
        if plate_number in self.last_pass_times:
 | 
			
		||||
            # 第二次或更多次通行,计算时间差
 | 
			
		||||
            last_time = self.last_pass_times[plate_number]
 | 
			
		||||
            time_diff = current_time - last_time
 | 
			
		||||
            
 | 
			
		||||
            # 格式化时间差
 | 
			
		||||
            total_seconds = int(time_diff.total_seconds())
 | 
			
		||||
            minutes = total_seconds // 60
 | 
			
		||||
            seconds = total_seconds % 60
 | 
			
		||||
            
 | 
			
		||||
            if minutes > 0:
 | 
			
		||||
                time_diff_str = f" {minutes}min{seconds}sec"
 | 
			
		||||
            else:
 | 
			
		||||
                time_diff_str = f" {seconds}sec"
 | 
			
		||||
            
 | 
			
		||||
            # 计算时间差后清空之前记录的时间点
 | 
			
		||||
            del self.last_pass_times[plate_number]
 | 
			
		||||
            log_msg = f"检测到白名单车牌: {plate_number},自动开闸{time_diff_str},已清空时间记录"
 | 
			
		||||
        else:
 | 
			
		||||
            # 第一次通行,只记录时间,不计算时间差
 | 
			
		||||
            self.last_pass_times[plate_number] = current_time
 | 
			
		||||
            log_msg = f"检测到白名单车牌: {plate_number},首次通行,已记录时间"
 | 
			
		||||
        
 | 
			
		||||
        # 发送开闸命令
 | 
			
		||||
        display_text = f"{plate_number} 通行{time_diff_str}"
 | 
			
		||||
        self.send_command(1, display_text)
 | 
			
		||||
        
 | 
			
		||||
        # 发出信号
 | 
			
		||||
        self.gate_opened.emit(plate_number)
 | 
			
		||||
        
 | 
			
		||||
        # 记录日志
 | 
			
		||||
        self.log_message.emit(log_msg)
 | 
			
		||||
    
 | 
			
		||||
    def manual_open_gate(self):
 | 
			
		||||
        """手动开闸"""
 | 
			
		||||
        self.send_command(2, "")
 | 
			
		||||
        self.log_message.emit("手动开闸")
 | 
			
		||||
    
 | 
			
		||||
    def manual_close_gate(self):
 | 
			
		||||
        """手动关闸"""
 | 
			
		||||
        self.send_command(3, "")
 | 
			
		||||
        self.log_message.emit("手动关闸")
 | 
			
		||||
    
 | 
			
		||||
    def display_message(self, text):
 | 
			
		||||
        """仅显示信息,不控制道闸"""
 | 
			
		||||
        self.send_command(4, text)
 | 
			
		||||
        self.log_message.emit(f"显示信息: {text}")
 | 
			
		||||
    
 | 
			
		||||
    def deny_access(self, plate_number):
 | 
			
		||||
        """
 | 
			
		||||
        拒绝通行(检测到非白名单车牌时调用)
 | 
			
		||||
        
 | 
			
		||||
        参数:
 | 
			
		||||
            plate_number: 车牌号
 | 
			
		||||
        """
 | 
			
		||||
        self.send_command(4, f"{plate_number} 禁止通行")
 | 
			
		||||
        self.log_message.emit(f"检测到非白名单车牌: {plate_number},拒绝通行")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WhitelistManager(QObject):
 | 
			
		||||
    """白名单管理器"""
 | 
			
		||||
    
 | 
			
		||||
    # 信号
 | 
			
		||||
    whitelist_changed = pyqtSignal(list)  # 白名单变更信号
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.whitelist = []  # 白名单车牌列表
 | 
			
		||||
    
 | 
			
		||||
    def add_plate(self, plate_number):
 | 
			
		||||
        """
 | 
			
		||||
        添加车牌到白名单
 | 
			
		||||
        
 | 
			
		||||
        参数:
 | 
			
		||||
            plate_number: 车牌号
 | 
			
		||||
            
 | 
			
		||||
        返回:
 | 
			
		||||
            bool: 是否添加成功
 | 
			
		||||
        """
 | 
			
		||||
        if not plate_number or plate_number in self.whitelist:
 | 
			
		||||
            return False
 | 
			
		||||
        
 | 
			
		||||
        self.whitelist.append(plate_number)
 | 
			
		||||
        self.whitelist_changed.emit(self.whitelist.copy())
 | 
			
		||||
        return True
 | 
			
		||||
    
 | 
			
		||||
    def remove_plate(self, plate_number):
 | 
			
		||||
        """
 | 
			
		||||
        从白名单移除车牌
 | 
			
		||||
        
 | 
			
		||||
        参数:
 | 
			
		||||
            plate_number: 车牌号
 | 
			
		||||
            
 | 
			
		||||
        返回:
 | 
			
		||||
            bool: 是否移除成功
 | 
			
		||||
        """
 | 
			
		||||
        if plate_number in self.whitelist:
 | 
			
		||||
            self.whitelist.remove(plate_number)
 | 
			
		||||
            self.whitelist_changed.emit(self.whitelist.copy())
 | 
			
		||||
            return True
 | 
			
		||||
        return False
 | 
			
		||||
    
 | 
			
		||||
    def edit_plate(self, old_plate, new_plate):
 | 
			
		||||
        """
 | 
			
		||||
        编辑白名单中的车牌
 | 
			
		||||
        
 | 
			
		||||
        参数:
 | 
			
		||||
            old_plate: 原车牌号
 | 
			
		||||
            new_plate: 新车牌号
 | 
			
		||||
            
 | 
			
		||||
        返回:
 | 
			
		||||
            bool: 是否编辑成功
 | 
			
		||||
        """
 | 
			
		||||
        if old_plate in self.whitelist and new_plate not in self.whitelist:
 | 
			
		||||
            index = self.whitelist.index(old_plate)
 | 
			
		||||
            self.whitelist[index] = new_plate
 | 
			
		||||
            self.whitelist_changed.emit(self.whitelist.copy())
 | 
			
		||||
            return True
 | 
			
		||||
        return False
 | 
			
		||||
    
 | 
			
		||||
    def is_whitelisted(self, plate_number):
 | 
			
		||||
        """
 | 
			
		||||
        检查车牌是否在白名单中
 | 
			
		||||
        
 | 
			
		||||
        参数:
 | 
			
		||||
            plate_number: 车牌号
 | 
			
		||||
            
 | 
			
		||||
        返回:
 | 
			
		||||
            bool: 是否在白名单中
 | 
			
		||||
        """
 | 
			
		||||
        return plate_number in self.whitelist
 | 
			
		||||
    
 | 
			
		||||
    def get_whitelist(self):
 | 
			
		||||
        """获取白名单副本"""
 | 
			
		||||
        return self.whitelist.copy()
 | 
			
		||||
    
 | 
			
		||||
    def clear_whitelist(self):
 | 
			
		||||
        """清空白名单"""
 | 
			
		||||
        self.whitelist.clear()
 | 
			
		||||
        self.whitelist_changed.emit(self.whitelist.copy())
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								lightCRNN_part/best_model.pth
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								lightCRNN_part/best_model.pth
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										546
									
								
								lightCRNN_part/lightcrnn_interface.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										546
									
								
								lightCRNN_part/lightcrnn_interface.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,546 @@
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
import numpy as np
 | 
			
		||||
from PIL import Image
 | 
			
		||||
import cv2
 | 
			
		||||
from torchvision import transforms
 | 
			
		||||
import os
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
# 全局变量
 | 
			
		||||
lightcrnn_model = None
 | 
			
		||||
lightcrnn_decoder = None
 | 
			
		||||
lightcrnn_preprocessor = None
 | 
			
		||||
device = None
 | 
			
		||||
 | 
			
		||||
class DepthwiseSeparableConv(nn.Module):
 | 
			
		||||
    """深度可分离卷积"""
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
 | 
			
		||||
        super(DepthwiseSeparableConv, self).__init__()
 | 
			
		||||
        # 深度卷积
 | 
			
		||||
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, 
 | 
			
		||||
                                 stride=stride, padding=padding, groups=in_channels, bias=False)
 | 
			
		||||
        # 逐点卷积
 | 
			
		||||
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
 | 
			
		||||
        self.bn = nn.BatchNorm2d(out_channels)
 | 
			
		||||
        self.relu = nn.ReLU6(inplace=True)
 | 
			
		||||
    
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        x = self.depthwise(x)
 | 
			
		||||
        x = self.pointwise(x)
 | 
			
		||||
        x = self.bn(x)
 | 
			
		||||
        x = self.relu(x)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
class ChannelAttention(nn.Module):
 | 
			
		||||
    """通道注意力机制"""
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self, in_channels, reduction=16):
 | 
			
		||||
        super(ChannelAttention, self).__init__()
 | 
			
		||||
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
 | 
			
		||||
        self.max_pool = nn.AdaptiveMaxPool2d(1)
 | 
			
		||||
        
 | 
			
		||||
        self.fc = nn.Sequential(
 | 
			
		||||
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
 | 
			
		||||
            nn.ReLU(inplace=True),
 | 
			
		||||
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)
 | 
			
		||||
        )
 | 
			
		||||
        self.sigmoid = nn.Sigmoid()
 | 
			
		||||
    
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        avg_out = self.fc(self.avg_pool(x))
 | 
			
		||||
        max_out = self.fc(self.max_pool(x))
 | 
			
		||||
        out = avg_out + max_out
 | 
			
		||||
        return x * self.sigmoid(out)
 | 
			
		||||
 | 
			
		||||
class InvertedResidual(nn.Module):
 | 
			
		||||
    """MobileNetV2的倒残差块"""
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self, in_channels, out_channels, stride=1, expand_ratio=6):
 | 
			
		||||
        super(InvertedResidual, self).__init__()
 | 
			
		||||
        self.stride = stride
 | 
			
		||||
        self.use_residual = stride == 1 and in_channels == out_channels
 | 
			
		||||
        
 | 
			
		||||
        hidden_dim = int(round(in_channels * expand_ratio))
 | 
			
		||||
        
 | 
			
		||||
        layers = []
 | 
			
		||||
        if expand_ratio != 1:
 | 
			
		||||
            # 扩展层
 | 
			
		||||
            layers.extend([
 | 
			
		||||
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
 | 
			
		||||
                nn.BatchNorm2d(hidden_dim),
 | 
			
		||||
                nn.ReLU6(inplace=True)
 | 
			
		||||
            ])
 | 
			
		||||
        
 | 
			
		||||
        # 深度卷积
 | 
			
		||||
        layers.extend([
 | 
			
		||||
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim, bias=False),
 | 
			
		||||
            nn.BatchNorm2d(hidden_dim),
 | 
			
		||||
            nn.ReLU6(inplace=True),
 | 
			
		||||
            # 线性瓶颈
 | 
			
		||||
            nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
 | 
			
		||||
            nn.BatchNorm2d(out_channels)
 | 
			
		||||
        ])
 | 
			
		||||
        
 | 
			
		||||
        self.conv = nn.Sequential(*layers)
 | 
			
		||||
    
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        if self.use_residual:
 | 
			
		||||
            return x + self.conv(x)
 | 
			
		||||
        else:
 | 
			
		||||
            return self.conv(x)
 | 
			
		||||
 | 
			
		||||
class LightweightCNN(nn.Module):
 | 
			
		||||
    """增强版轻量化CNN特征提取器"""
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self, num_channels=3):
 | 
			
		||||
        super(LightweightCNN, self).__init__()
 | 
			
		||||
        
 | 
			
		||||
        # 初始卷积层 - 适当增加通道数
 | 
			
		||||
        self.conv1 = nn.Sequential(
 | 
			
		||||
            nn.Conv2d(num_channels, 48, kernel_size=3, stride=1, padding=1, bias=False),
 | 
			
		||||
            nn.BatchNorm2d(48),
 | 
			
		||||
            nn.ReLU6(inplace=True)
 | 
			
		||||
        )
 | 
			
		||||
        
 | 
			
		||||
        # 增强版MobileNet风格的特征提取
 | 
			
		||||
        self.features = nn.Sequential(
 | 
			
		||||
            # 第一组:48 -> 32
 | 
			
		||||
            InvertedResidual(48, 32, stride=1, expand_ratio=2),
 | 
			
		||||
            InvertedResidual(32, 32, stride=1, expand_ratio=2),  # 增加一层
 | 
			
		||||
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32x128 -> 16x64
 | 
			
		||||
            
 | 
			
		||||
            # 第二组:32 -> 48
 | 
			
		||||
            InvertedResidual(32, 48, stride=1, expand_ratio=4),
 | 
			
		||||
            InvertedResidual(48, 48, stride=1, expand_ratio=4),
 | 
			
		||||
            nn.MaxPool2d(kernel_size=2, stride=2),  # 16x64 -> 8x32
 | 
			
		||||
            
 | 
			
		||||
            # 第三组:48 -> 64
 | 
			
		||||
            InvertedResidual(48, 64, stride=1, expand_ratio=4),
 | 
			
		||||
            InvertedResidual(64, 64, stride=1, expand_ratio=4),
 | 
			
		||||
            
 | 
			
		||||
            # 第四组:64 -> 96
 | 
			
		||||
            InvertedResidual(64, 96, stride=1, expand_ratio=4),
 | 
			
		||||
            InvertedResidual(96, 96, stride=1, expand_ratio=4),
 | 
			
		||||
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),  # 8x32 -> 4x32
 | 
			
		||||
            
 | 
			
		||||
            # 第五组:96 -> 128
 | 
			
		||||
            InvertedResidual(96, 128, stride=1, expand_ratio=4),
 | 
			
		||||
            InvertedResidual(128, 128, stride=1, expand_ratio=4),
 | 
			
		||||
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),  # 4x32 -> 2x32
 | 
			
		||||
            
 | 
			
		||||
            # 最后的卷积层 - 增加通道数
 | 
			
		||||
            nn.Conv2d(128, 160, kernel_size=2, stride=1, padding=0, bias=False),  # 2x32 -> 1x31
 | 
			
		||||
            nn.BatchNorm2d(160),
 | 
			
		||||
            nn.ReLU6(inplace=True)
 | 
			
		||||
        )
 | 
			
		||||
        
 | 
			
		||||
        # 通道注意力
 | 
			
		||||
        self.channel_attention = ChannelAttention(160)
 | 
			
		||||
    
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        x = self.conv1(x)
 | 
			
		||||
        x = self.features(x)
 | 
			
		||||
        x = self.channel_attention(x)
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
class LightweightGRU(nn.Module):
 | 
			
		||||
    """增强版轻量化GRU层"""
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self, input_size, hidden_size, num_layers=2):  # 默认增加到2层
 | 
			
		||||
        super(LightweightGRU, self).__init__()
 | 
			
		||||
        self.gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, 
 | 
			
		||||
                         bidirectional=True, batch_first=True, dropout=0.2 if num_layers > 1 else 0)
 | 
			
		||||
        # 增加一个额外的线性层
 | 
			
		||||
        self.linear1 = nn.Linear(hidden_size * 2, hidden_size * 2)
 | 
			
		||||
        self.linear2 = nn.Linear(hidden_size * 2, hidden_size)
 | 
			
		||||
        self.dropout = nn.Dropout(0.2)  # 增加dropout率
 | 
			
		||||
        self.norm = nn.LayerNorm(hidden_size)  # 添加层归一化
 | 
			
		||||
    
 | 
			
		||||
    def forward(self, x):
 | 
			
		||||
        gru_out, _ = self.gru(x)
 | 
			
		||||
        output = self.linear1(gru_out)
 | 
			
		||||
        output = F.relu(output)  # 添加激活函数
 | 
			
		||||
        output = self.dropout(output)
 | 
			
		||||
        output = self.linear2(output)
 | 
			
		||||
        output = self.norm(output)  # 应用层归一化
 | 
			
		||||
        output = self.dropout(output)
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
class LightweightCRNN(nn.Module):
 | 
			
		||||
    """增强版轻量化CRNN模型"""
 | 
			
		||||
    
 | 
			
		||||
    def __init__(self, img_height, num_classes, num_channels=3, hidden_size=160):  # 调整隐藏层大小
 | 
			
		||||
        super(LightweightCRNN, self).__init__()
 | 
			
		||||
        
 | 
			
		||||
        self.img_height = img_height
 | 
			
		||||
        self.num_classes = num_classes
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
        
 | 
			
		||||
        # 增强版轻量化CNN特征提取器
 | 
			
		||||
        self.cnn = LightweightCNN(num_channels)
 | 
			
		||||
        
 | 
			
		||||
        # 增强版轻量化RNN序列建模器
 | 
			
		||||
        self.rnn = LightweightGRU(160, hidden_size, num_layers=2)  # 使用更大的输入尺寸和2层GRU
 | 
			
		||||
        
 | 
			
		||||
        # 输出层 - 添加额外的全连接层
 | 
			
		||||
        self.fc = nn.Linear(hidden_size, hidden_size // 2)
 | 
			
		||||
        self.dropout = nn.Dropout(0.2)
 | 
			
		||||
        self.classifier = nn.Linear(hidden_size // 2, num_classes)
 | 
			
		||||
        
 | 
			
		||||
        # 初始化权重
 | 
			
		||||
        self._initialize_weights()
 | 
			
		||||
    
 | 
			
		||||
    def _initialize_weights(self):
 | 
			
		||||
        """初始化模型权重"""
 | 
			
		||||
        for m in self.modules():
 | 
			
		||||
            if isinstance(m, nn.Conv2d):
 | 
			
		||||
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
 | 
			
		||||
                if m.bias is not None:
 | 
			
		||||
                    nn.init.constant_(m.bias, 0)
 | 
			
		||||
            elif isinstance(m, nn.BatchNorm2d):
 | 
			
		||||
                nn.init.constant_(m.weight, 1)
 | 
			
		||||
                nn.init.constant_(m.bias, 0)
 | 
			
		||||
            elif isinstance(m, nn.Linear):
 | 
			
		||||
                nn.init.normal_(m.weight, 0, 0.01)
 | 
			
		||||
                if m.bias is not None:
 | 
			
		||||
                    nn.init.constant_(m.bias, 0)
 | 
			
		||||
    
 | 
			
		||||
    def forward(self, input):
 | 
			
		||||
        """
 | 
			
		||||
        input: [batch_size, channels, height, width]
 | 
			
		||||
        output: [seq_len, batch_size, num_classes]
 | 
			
		||||
        """
 | 
			
		||||
        # CNN特征提取
 | 
			
		||||
        conv_features = self.cnn(input)  # [batch_size, 160, 1, seq_len]
 | 
			
		||||
        
 | 
			
		||||
        # 重塑为RNN输入格式
 | 
			
		||||
        batch_size, channels, height, width = conv_features.size()
 | 
			
		||||
        assert height == 1, f"Height should be 1, got {height}"
 | 
			
		||||
        
 | 
			
		||||
        # [batch_size, 160, 1, seq_len] -> [batch_size, seq_len, 160]
 | 
			
		||||
        conv_features = conv_features.squeeze(2)  # [batch_size, 160, seq_len]
 | 
			
		||||
        conv_features = conv_features.permute(0, 2, 1)  # [batch_size, seq_len, 160]
 | 
			
		||||
        
 | 
			
		||||
        # RNN序列建模
 | 
			
		||||
        rnn_output = self.rnn(conv_features)  # [batch_size, seq_len, hidden_size]
 | 
			
		||||
        
 | 
			
		||||
        # 全连接层处理
 | 
			
		||||
        fc_output = self.fc(rnn_output)  # [batch_size, seq_len, hidden_size//2]
 | 
			
		||||
        fc_output = F.relu(fc_output)
 | 
			
		||||
        fc_output = self.dropout(fc_output)
 | 
			
		||||
        
 | 
			
		||||
        # 分类
 | 
			
		||||
        output = self.classifier(fc_output)  # [batch_size, seq_len, num_classes]
 | 
			
		||||
        
 | 
			
		||||
        # 转换为CTC期望的格式: [seq_len, batch_size, num_classes]
 | 
			
		||||
        output = output.permute(1, 0, 2)
 | 
			
		||||
        
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
class LightCTCDecoder:
 | 
			
		||||
    """轻量化CTC解码器"""
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        # 中国车牌字符集
 | 
			
		||||
        # 省份简称
 | 
			
		||||
        provinces = ['京', '津', '沪', '渝', '冀', '豫', '云', '辽', '黑', '湘', '皖', '鲁',
 | 
			
		||||
                    '新', '苏', '浙', '赣', '鄂', '桂', '甘', '晋', '蒙', '陕', '吉', '闽',
 | 
			
		||||
                    '贵', '粤', '青', '藏', '川', '宁', '琼']
 | 
			
		||||
        
 | 
			
		||||
        # 字母(包含I和O)
 | 
			
		||||
        letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
 | 
			
		||||
                  'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
 | 
			
		||||
        
 | 
			
		||||
        # 数字
 | 
			
		||||
        digits = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
 | 
			
		||||
        
 | 
			
		||||
        # 组合所有字符
 | 
			
		||||
        self.character = provinces + letters + digits
 | 
			
		||||
        
 | 
			
		||||
        # 添加空白字符用于CTC
 | 
			
		||||
        self.character = ['[blank]'] + self.character
 | 
			
		||||
        
 | 
			
		||||
        # 创建字符到索引的映射
 | 
			
		||||
        self.dict = {char: i for i, char in enumerate(self.character)}
 | 
			
		||||
        self.dict_reverse = {i: char for i, char in enumerate(self.character)}
 | 
			
		||||
        
 | 
			
		||||
        self.num_classes = len(self.character)
 | 
			
		||||
        self.blank_idx = 0
 | 
			
		||||
    
 | 
			
		||||
    def decode_greedy(self, predictions):
 | 
			
		||||
        """贪婪解码"""
 | 
			
		||||
        # 获取每个时间步的最大概率索引
 | 
			
		||||
        indices = torch.argmax(predictions, dim=1)
 | 
			
		||||
        
 | 
			
		||||
        # CTC解码:移除重复字符和空白字符
 | 
			
		||||
        decoded_chars = []
 | 
			
		||||
        prev_idx = -1
 | 
			
		||||
        
 | 
			
		||||
        for idx in indices:
 | 
			
		||||
            idx = idx.item()
 | 
			
		||||
            if idx != prev_idx and idx != self.blank_idx:
 | 
			
		||||
                if idx < len(self.character):
 | 
			
		||||
                    decoded_chars.append(self.character[idx])
 | 
			
		||||
            prev_idx = idx
 | 
			
		||||
        
 | 
			
		||||
        return ''.join(decoded_chars)
 | 
			
		||||
    
 | 
			
		||||
    def decode_with_confidence(self, predictions):
 | 
			
		||||
        """解码并返回置信度信息"""
 | 
			
		||||
        # 应用softmax获得概率
 | 
			
		||||
        probs = torch.softmax(predictions, dim=1)
 | 
			
		||||
        
 | 
			
		||||
        # 贪婪解码
 | 
			
		||||
        indices = torch.argmax(probs, dim=1)
 | 
			
		||||
        max_probs = torch.max(probs, dim=1)[0]
 | 
			
		||||
        
 | 
			
		||||
        # CTC解码
 | 
			
		||||
        decoded_chars = []
 | 
			
		||||
        char_confidences = []
 | 
			
		||||
        prev_idx = -1
 | 
			
		||||
        
 | 
			
		||||
        for i, idx in enumerate(indices):
 | 
			
		||||
            idx = idx.item()
 | 
			
		||||
            confidence = max_probs[i].item()
 | 
			
		||||
            
 | 
			
		||||
            if idx != prev_idx and idx != self.blank_idx:
 | 
			
		||||
                if idx < len(self.character):
 | 
			
		||||
                    decoded_chars.append(self.character[idx])
 | 
			
		||||
                    char_confidences.append(confidence)
 | 
			
		||||
            prev_idx = idx
 | 
			
		||||
        
 | 
			
		||||
        text = ''.join(decoded_chars)
 | 
			
		||||
        avg_confidence = np.mean(char_confidences) if char_confidences else 0.0
 | 
			
		||||
        
 | 
			
		||||
        return text, avg_confidence, char_confidences
 | 
			
		||||
 | 
			
		||||
class LightLicensePlatePreprocessor:
 | 
			
		||||
    """轻量化车牌图像预处理器"""
 | 
			
		||||
    def __init__(self, target_height=32, target_width=128):
 | 
			
		||||
        self.target_height = target_height
 | 
			
		||||
        self.target_width = target_width
 | 
			
		||||
        
 | 
			
		||||
        # 定义图像变换
 | 
			
		||||
        self.transform = transforms.Compose([
 | 
			
		||||
            transforms.Resize((target_height, target_width)),
 | 
			
		||||
            transforms.ToTensor(),
 | 
			
		||||
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
 | 
			
		||||
                               std=[0.229, 0.224, 0.225])
 | 
			
		||||
        ])
 | 
			
		||||
    
 | 
			
		||||
    def preprocess_numpy_array(self, image_array):
 | 
			
		||||
        """预处理numpy数组格式的图像"""
 | 
			
		||||
        try:
 | 
			
		||||
            # 确保图像是RGB格式
 | 
			
		||||
            if len(image_array.shape) == 3 and image_array.shape[2] == 3:
 | 
			
		||||
                # 如果是BGR格式,转换为RGB
 | 
			
		||||
                if image_array.dtype == np.uint8:
 | 
			
		||||
                    image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)
 | 
			
		||||
            
 | 
			
		||||
            # 转换为PIL图像
 | 
			
		||||
            if image_array.dtype != np.uint8:
 | 
			
		||||
                image_array = (image_array * 255).astype(np.uint8)
 | 
			
		||||
            
 | 
			
		||||
            image = Image.fromarray(image_array)
 | 
			
		||||
            
 | 
			
		||||
            # 应用变换
 | 
			
		||||
            tensor = self.transform(image)
 | 
			
		||||
            
 | 
			
		||||
            # 添加batch维度
 | 
			
		||||
            tensor = tensor.unsqueeze(0)
 | 
			
		||||
            
 | 
			
		||||
            return tensor
 | 
			
		||||
            
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            print(f"图像预处理失败: {e}")
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
def LPRNinitialize_model():
 | 
			
		||||
    """
 | 
			
		||||
    初始化轻量化CRNN模型
 | 
			
		||||
    
 | 
			
		||||
    返回:
 | 
			
		||||
        bool: 初始化是否成功
 | 
			
		||||
    """
 | 
			
		||||
    global lightcrnn_model, lightcrnn_decoder, lightcrnn_preprocessor, device
 | 
			
		||||
    
 | 
			
		||||
    try:
 | 
			
		||||
        # 设置设备
 | 
			
		||||
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
 | 
			
		||||
        print(f"LightCRNN使用设备: {device}")
 | 
			
		||||
        
 | 
			
		||||
        # 初始化组件
 | 
			
		||||
        lightcrnn_decoder = LightCTCDecoder()
 | 
			
		||||
        lightcrnn_preprocessor = LightLicensePlatePreprocessor(target_height=32, target_width=128)
 | 
			
		||||
        
 | 
			
		||||
        # 创建模型实例
 | 
			
		||||
        lightcrnn_model = LightweightCRNN(
 | 
			
		||||
            img_height=32, 
 | 
			
		||||
            num_classes=lightcrnn_decoder.num_classes, 
 | 
			
		||||
            hidden_size=160
 | 
			
		||||
        )
 | 
			
		||||
        
 | 
			
		||||
        # 加载模型权重
 | 
			
		||||
        model_path = os.path.join(os.path.dirname(__file__), 'best_model.pth')
 | 
			
		||||
        
 | 
			
		||||
        if not os.path.exists(model_path):
 | 
			
		||||
            raise FileNotFoundError(f"模型文件不存在: {model_path}")
 | 
			
		||||
        
 | 
			
		||||
        print(f"正在加载LightCRNN模型: {model_path}")
 | 
			
		||||
        
 | 
			
		||||
        # 加载检查点,处理可能的模块依赖问题
 | 
			
		||||
        try:
 | 
			
		||||
            checkpoint = torch.load(model_path, map_location=device, weights_only=False)
 | 
			
		||||
        except (ModuleNotFoundError, AttributeError) as e:
 | 
			
		||||
            if 'config' in str(e) or 'Config' in str(e):
 | 
			
		||||
                print("检测到模型文件包含config依赖,尝试使用weights_only模式加载...")
 | 
			
		||||
                try:
 | 
			
		||||
                    # 尝试使用weights_only=True来避免pickle问题
 | 
			
		||||
                    checkpoint = torch.load(model_path, map_location=device, weights_only=True)
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    # 如果还是失败,创建一个更完整的mock config
 | 
			
		||||
                    import sys
 | 
			
		||||
                    import types
 | 
			
		||||
                    
 | 
			
		||||
                    # 创建mock config模块
 | 
			
		||||
                    mock_config = types.ModuleType('config')
 | 
			
		||||
                    
 | 
			
		||||
                    # 添加可能需要的Config类
 | 
			
		||||
                    class Config:
 | 
			
		||||
                        def __init__(self):
 | 
			
		||||
                            pass
 | 
			
		||||
                    
 | 
			
		||||
                    mock_config.Config = Config
 | 
			
		||||
                    sys.modules['config'] = mock_config
 | 
			
		||||
                    
 | 
			
		||||
                    try:
 | 
			
		||||
                        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
 | 
			
		||||
                    finally:
 | 
			
		||||
                        # 清理临时模块
 | 
			
		||||
                        if 'config' in sys.modules:
 | 
			
		||||
                            del sys.modules['config']
 | 
			
		||||
            else:
 | 
			
		||||
                raise e
 | 
			
		||||
        
 | 
			
		||||
        # 处理不同的模型保存格式
 | 
			
		||||
        if isinstance(checkpoint, dict):
 | 
			
		||||
            if 'model_state_dict' in checkpoint:
 | 
			
		||||
                # 完整检查点格式
 | 
			
		||||
                state_dict = checkpoint['model_state_dict']
 | 
			
		||||
                print(f"检查点信息:")
 | 
			
		||||
                print(f"  - 训练轮次: {checkpoint.get('epoch', 'N/A')}")
 | 
			
		||||
                print(f"  - 最佳验证损失: {checkpoint.get('best_val_loss', 'N/A')}")
 | 
			
		||||
            else:
 | 
			
		||||
                # 精简模型格式(只包含权重)
 | 
			
		||||
                print("加载精简模型(仅权重)")
 | 
			
		||||
                state_dict = checkpoint
 | 
			
		||||
        else:
 | 
			
		||||
            # 直接是状态字典
 | 
			
		||||
            state_dict = checkpoint
 | 
			
		||||
        
 | 
			
		||||
        # 加载权重
 | 
			
		||||
        lightcrnn_model.load_state_dict(state_dict)
 | 
			
		||||
        lightcrnn_model.to(device)
 | 
			
		||||
        lightcrnn_model.eval()
 | 
			
		||||
        
 | 
			
		||||
        print("LightCRNN模型初始化完成")
 | 
			
		||||
        
 | 
			
		||||
        # 统计模型参数
 | 
			
		||||
        total_params = sum(p.numel() for p in lightcrnn_model.parameters())
 | 
			
		||||
        print(f"LightCRNN模型参数数量: {total_params:,}")
 | 
			
		||||
        
 | 
			
		||||
        return True
 | 
			
		||||
        
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        print(f"LightCRNN模型初始化失败: {e}")
 | 
			
		||||
        import traceback
 | 
			
		||||
        traceback.print_exc()
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
def LPRNmodel_predict(image_array):
 | 
			
		||||
    """
 | 
			
		||||
    轻量化CRNN车牌号识别接口函数
 | 
			
		||||
    
 | 
			
		||||
    参数:
 | 
			
		||||
        image_array: numpy数组格式的车牌图像,已经过矫正处理
 | 
			
		||||
    
 | 
			
		||||
    返回:
 | 
			
		||||
        list: 包含最多8个字符的列表,代表车牌号的每个字符
 | 
			
		||||
              例如: ['京', 'A', '1', '2', '3', '4', '5', ''] (蓝牌7位+占位符)
 | 
			
		||||
                   ['京', 'A', 'D', '1', '2', '3', '4', '5'] (绿牌8位)
 | 
			
		||||
    """
 | 
			
		||||
    global lightcrnn_model, lightcrnn_decoder, lightcrnn_preprocessor, device
 | 
			
		||||
    
 | 
			
		||||
    if lightcrnn_model is None or lightcrnn_decoder is None or lightcrnn_preprocessor is None:
 | 
			
		||||
        print("LightCRNN模型未初始化,请先调用LPRNinitialize_model()")
 | 
			
		||||
        return ['待', '识', '别', '0', '0', '0', '0', '0']
 | 
			
		||||
    
 | 
			
		||||
    try:
 | 
			
		||||
        # 预处理图像
 | 
			
		||||
        input_tensor = lightcrnn_preprocessor.preprocess_numpy_array(image_array)
 | 
			
		||||
        if input_tensor is None:
 | 
			
		||||
            raise ValueError("图像预处理失败")
 | 
			
		||||
        
 | 
			
		||||
        input_tensor = input_tensor.to(device)
 | 
			
		||||
        
 | 
			
		||||
        # 模型推理
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            outputs = lightcrnn_model(input_tensor)  # (seq_len, batch_size, num_classes)
 | 
			
		||||
            
 | 
			
		||||
            # 移除batch维度
 | 
			
		||||
            outputs = outputs.squeeze(1)  # (seq_len, num_classes)
 | 
			
		||||
            
 | 
			
		||||
            # CTC解码
 | 
			
		||||
            predicted_text, confidence, char_confidences = lightcrnn_decoder.decode_with_confidence(outputs)
 | 
			
		||||
            
 | 
			
		||||
            print(f"LightCRNN识别结果: {predicted_text}, 置信度: {confidence:.3f}")
 | 
			
		||||
            
 | 
			
		||||
            # 将字符串转换为字符列表
 | 
			
		||||
            char_list = list(predicted_text)
 | 
			
		||||
            
 | 
			
		||||
            # 确保返回至少7个字符,最多8个字符
 | 
			
		||||
            if len(char_list) < 7:
 | 
			
		||||
                # 如果识别结果少于7个字符,用'0'补齐到7位
 | 
			
		||||
                char_list.extend(['0'] * (7 - len(char_list)))
 | 
			
		||||
            elif len(char_list) > 8:
 | 
			
		||||
                # 如果识别结果多于8个字符,截取前8个
 | 
			
		||||
                char_list = char_list[:8]
 | 
			
		||||
            
 | 
			
		||||
            # 如果是7位,补齐到8位以保持接口一致性(第8位用空字符或占位符)
 | 
			
		||||
            if len(char_list) == 7:
 | 
			
		||||
                char_list.append('')  # 添加空字符作为第8位占位符
 | 
			
		||||
            
 | 
			
		||||
            return char_list
 | 
			
		||||
            
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        print(f"LightCRNN识别失败: {e}")
 | 
			
		||||
        import traceback
 | 
			
		||||
        traceback.print_exc()
 | 
			
		||||
        return ['识', '别', '失', '败', '0', '0', '0', '0']
 | 
			
		||||
 | 
			
		||||
def create_lightweight_model(model_type='lightweight_crnn', img_height=32, num_classes=66, hidden_size=160):
 | 
			
		||||
    """创建增强版轻量化模型"""
 | 
			
		||||
    if model_type == 'lightweight_crnn':
 | 
			
		||||
        return LightweightCRNN(img_height, num_classes, hidden_size=hidden_size)
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError(f"Unknown lightweight model type: {model_type}")
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # 测试轻量化模型
 | 
			
		||||
    print("测试LightCRNN模型...")
 | 
			
		||||
    
 | 
			
		||||
    # 初始化模型
 | 
			
		||||
    success = LPRNinitialize_model()
 | 
			
		||||
    if success:
 | 
			
		||||
        print("模型初始化成功")
 | 
			
		||||
        
 | 
			
		||||
        # 创建测试输入
 | 
			
		||||
        test_input = np.random.randint(0, 255, (32, 128, 3), dtype=np.uint8)
 | 
			
		||||
        
 | 
			
		||||
        # 测试预测
 | 
			
		||||
        result = LPRNmodel_predict(test_input)
 | 
			
		||||
        print(f"测试预测结果: {result}")
 | 
			
		||||
    else:
 | 
			
		||||
        print("模型初始化失败")
 | 
			
		||||
							
								
								
									
										5
									
								
								parking_config.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								parking_config.json
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,5 @@
 | 
			
		||||
{
 | 
			
		||||
    "free_parking_duration": 5,
 | 
			
		||||
    "billing_cycle": 3,
 | 
			
		||||
    "price_per_cycle": 5.0
 | 
			
		||||
}
 | 
			
		||||
@@ -2,6 +2,7 @@ import cv2
 | 
			
		||||
import numpy as np
 | 
			
		||||
from ultralytics import YOLO
 | 
			
		||||
import os
 | 
			
		||||
from PIL import Image, ImageDraw, ImageFont
 | 
			
		||||
 | 
			
		||||
class LicensePlateYOLO:
 | 
			
		||||
    """
 | 
			
		||||
@@ -45,7 +46,7 @@ class LicensePlateYOLO:
 | 
			
		||||
            print(f"YOLO模型加载失败: {e}")
 | 
			
		||||
            return False
 | 
			
		||||
    
 | 
			
		||||
    def detect_license_plates(self, image, conf_threshold=0.5):
 | 
			
		||||
    def detect_license_plates(self, image, conf_threshold=0.6):
 | 
			
		||||
        """
 | 
			
		||||
        检测图像中的车牌
 | 
			
		||||
        
 | 
			
		||||
@@ -113,19 +114,38 @@ class LicensePlateYOLO:
 | 
			
		||||
            print(f"检测过程中出错: {e}")
 | 
			
		||||
            return []
 | 
			
		||||
    
 | 
			
		||||
    def draw_detections(self, image, detections):
 | 
			
		||||
    def draw_detections(self, image, detections, plate_numbers=None):
 | 
			
		||||
        """
 | 
			
		||||
        在图像上绘制检测结果
 | 
			
		||||
        
 | 
			
		||||
        参数:
 | 
			
		||||
            image: 输入图像
 | 
			
		||||
            detections: 检测结果列表
 | 
			
		||||
            plate_numbers: 车牌号列表,与detections对应
 | 
			
		||||
        
 | 
			
		||||
        返回:
 | 
			
		||||
            numpy.ndarray: 绘制了检测结果的图像
 | 
			
		||||
        """
 | 
			
		||||
        draw_image = image.copy()
 | 
			
		||||
        
 | 
			
		||||
        # 转换为PIL图像以支持中文字符
 | 
			
		||||
        pil_image = Image.fromarray(cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB))
 | 
			
		||||
        draw = ImageDraw.Draw(pil_image)
 | 
			
		||||
        
 | 
			
		||||
        # 尝试加载中文字体
 | 
			
		||||
        try:
 | 
			
		||||
            # Windows系统常见的中文字体
 | 
			
		||||
            font_path = "C:/Windows/Fonts/simhei.ttf"  # 黑体
 | 
			
		||||
            if not os.path.exists(font_path):
 | 
			
		||||
                font_path = "C:/Windows/Fonts/msyh.ttc"  # 微软雅黑
 | 
			
		||||
            if not os.path.exists(font_path):
 | 
			
		||||
                font_path = "C:/Windows/Fonts/simsun.ttc"  # 宋体
 | 
			
		||||
            
 | 
			
		||||
            font = ImageFont.truetype(font_path, 20)
 | 
			
		||||
        except:
 | 
			
		||||
            # 如果无法加载字体,使用默认字体
 | 
			
		||||
            font = ImageFont.load_default()
 | 
			
		||||
        
 | 
			
		||||
        for i, detection in enumerate(detections):
 | 
			
		||||
            box = detection['box']
 | 
			
		||||
            keypoints = detection['keypoints']
 | 
			
		||||
@@ -133,6 +153,11 @@ class LicensePlateYOLO:
 | 
			
		||||
            confidence = detection['confidence']
 | 
			
		||||
            incomplete = detection.get('incomplete', False)
 | 
			
		||||
            
 | 
			
		||||
            # 获取对应的车牌号
 | 
			
		||||
            plate_number = ""
 | 
			
		||||
            if plate_numbers and i < len(plate_numbers):
 | 
			
		||||
                plate_number = plate_numbers[i]
 | 
			
		||||
            
 | 
			
		||||
            # 绘制边界框
 | 
			
		||||
            x1, y1, x2, y2 = map(int, box)
 | 
			
		||||
            
 | 
			
		||||
@@ -140,30 +165,53 @@ class LicensePlateYOLO:
 | 
			
		||||
            if class_name == '绿牌':
 | 
			
		||||
                box_color = (0, 255, 0)  # 绿色
 | 
			
		||||
            elif class_name == '蓝牌':
 | 
			
		||||
                box_color = (255, 0, 0)  # 蓝色
 | 
			
		||||
                box_color = (0, 0, 255)  # 蓝色
 | 
			
		||||
            else:
 | 
			
		||||
                box_color = (128, 128, 128)  # 灰色
 | 
			
		||||
            
 | 
			
		||||
            cv2.rectangle(draw_image, (x1, y1), (x2, y2), box_color, 2)
 | 
			
		||||
            # 在PIL图像上绘制边界框
 | 
			
		||||
            draw.rectangle([(x1, y1), (x2, y2)], outline=box_color, width=2)
 | 
			
		||||
            
 | 
			
		||||
            # 绘制标签
 | 
			
		||||
            # 构建标签文本
 | 
			
		||||
            if plate_number:
 | 
			
		||||
                label = f"{class_name} {plate_number} {confidence:.2f}"
 | 
			
		||||
            else:
 | 
			
		||||
                label = f"{class_name} {confidence:.2f}"
 | 
			
		||||
            
 | 
			
		||||
            if incomplete:
 | 
			
		||||
                label += " (不完整)"
 | 
			
		||||
            
 | 
			
		||||
            # 计算文本大小和位置
 | 
			
		||||
            font = cv2.FONT_HERSHEY_SIMPLEX
 | 
			
		||||
            font_scale = 0.6
 | 
			
		||||
            thickness = 2
 | 
			
		||||
            (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
 | 
			
		||||
            # 计算文本大小
 | 
			
		||||
            bbox = draw.textbbox((0, 0), label, font=font)
 | 
			
		||||
            text_width = bbox[2] - bbox[0]
 | 
			
		||||
            text_height = bbox[3] - bbox[1]
 | 
			
		||||
            
 | 
			
		||||
            # 绘制文本背景
 | 
			
		||||
            cv2.rectangle(draw_image, (x1, y1 - text_height - 10), 
 | 
			
		||||
                         (x1 + text_width, y1), box_color, -1)
 | 
			
		||||
            draw.rectangle([(x1, y1 - text_height - 10), (x1 + text_width, y1)], 
 | 
			
		||||
                         fill=box_color)
 | 
			
		||||
            
 | 
			
		||||
            # 绘制文本
 | 
			
		||||
            cv2.putText(draw_image, label, (x1, y1 - 5), 
 | 
			
		||||
                       font, font_scale, (255, 255, 255), thickness)
 | 
			
		||||
            draw.text((x1, y1 - text_height - 5), label, fill=(255, 255, 255), font=font)
 | 
			
		||||
        
 | 
			
		||||
        # 转换回OpenCV格式
 | 
			
		||||
        draw_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
 | 
			
		||||
        
 | 
			
		||||
        # 绘制关键点和连线(使用OpenCV)
 | 
			
		||||
        for i, detection in enumerate(detections):
 | 
			
		||||
            box = detection['box']
 | 
			
		||||
            keypoints = detection['keypoints']
 | 
			
		||||
            incomplete = detection.get('incomplete', False)
 | 
			
		||||
            
 | 
			
		||||
            x1, y1, x2, y2 = map(int, box)
 | 
			
		||||
            
 | 
			
		||||
            # 根据车牌类型选择颜色
 | 
			
		||||
            class_name = detection['class_name']
 | 
			
		||||
            if class_name == '绿牌':
 | 
			
		||||
                box_color = (0, 255, 0)  # 绿色
 | 
			
		||||
            elif class_name == '蓝牌':
 | 
			
		||||
                box_color = (0, 0, 255)  # 蓝色
 | 
			
		||||
            else:
 | 
			
		||||
                box_color = (128, 128, 128)  # 灰色
 | 
			
		||||
            
 | 
			
		||||
            # 绘制关键点和连线
 | 
			
		||||
            if len(keypoints) >= 4 and not incomplete:
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user