Compare commits
29 Commits
8eef0d9414
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 428b577808 | |||
| 15a83a5f06 | |||
| 418f7f3bc9 | |||
| a99e8fccb2 | |||
| 40f5e1c1be | |||
| c1fbccd7ee | |||
| d649738f6c | |||
| 6831a8cd01 | |||
| cf60d96066 | |||
| 09c3117f12 | |||
| 2a77e6ca8a | |||
| 56e7347c01 | |||
| 1c8e15bcd8 | |||
| 6c7f013a0c | |||
| 95aa6b6bba | |||
| 739cd1d914 | |||
| 01df759772 | |||
| cb88e6fccd | |||
| 80e995b47c | |||
| f82df06a68 | |||
| dc651af561 | |||
| 9f9bd25ce7 | |||
| 97ca0d75c2 | |||
| 75cc3b8ea3 | |||
| aca5703b9e | |||
| 2eba46bc40 | |||
| f342d37d63 | |||
| 1c914cf89f | |||
| afba7af80b |
8
.idea/.gitignore
generated
vendored
8
.idea/.gitignore
generated
vendored
@@ -1,8 +0,0 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# 基于编辑器的 HTTP 客户端请求
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
12
.idea/License_plate_recognition.iml
generated
12
.idea/License_plate_recognition.iml
generated
@@ -1,12 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="D:\conda_envs\RLP" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
</module>
|
||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
6
.idea/inspectionProfiles/profiles_settings.xml
generated
@@ -1,6 +0,0 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
7
.idea/misc.xml
generated
7
.idea/misc.xml
generated
@@ -1,7 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="pytorh" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="D:\conda_envs\RLP" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
.idea/modules.xml
generated
8
.idea/modules.xml
generated
@@ -1,8 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/License_plate_recognition.iml" filepath="$PROJECT_DIR$/.idea/License_plate_recognition.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
7
.idea/vcs.xml
generated
7
.idea/vcs.xml
generated
@@ -1,7 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
Binary file not shown.
@@ -207,7 +207,7 @@ class LicensePlatePreprocessor:
|
||||
print(f"图像预处理失败: {e}")
|
||||
return None
|
||||
|
||||
def initialize_crnn_model():
|
||||
def LPRNinitialize_model():
|
||||
"""
|
||||
初始化CRNN模型
|
||||
|
||||
@@ -274,7 +274,7 @@ def initialize_crnn_model():
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def crnn_predict(image_array):
|
||||
def LPRNmodel_predict(image_array):
|
||||
"""
|
||||
CRNN车牌号识别接口函数
|
||||
|
||||
@@ -282,14 +282,15 @@ def crnn_predict(image_array):
|
||||
image_array: numpy数组格式的车牌图像,已经过矫正处理
|
||||
|
||||
返回:
|
||||
list: 包含7个字符的列表,代表车牌号的每个字符
|
||||
例如: ['京', 'A', '1', '2', '3', '4', '5']
|
||||
list: 包含最多8个字符的列表,代表车牌号的每个字符
|
||||
例如: ['京', 'A', '1', '2', '3', '4', '5', ''] (蓝牌7位+占位符)
|
||||
['京', 'A', 'D', '1', '2', '3', '4', '5'] (绿牌8位)
|
||||
"""
|
||||
global crnn_model, crnn_decoder, crnn_preprocessor, device
|
||||
|
||||
if crnn_model is None or crnn_decoder is None or crnn_preprocessor is None:
|
||||
print("CRNN模型未初始化,请先调用initialize_crnn_model()")
|
||||
return ['待', '识', '别', '0', '0', '0', '0']
|
||||
return ['待', '识', '别', '0', '0', '0', '0', '0']
|
||||
|
||||
try:
|
||||
# 预处理图像
|
||||
@@ -314,13 +315,17 @@ def crnn_predict(image_array):
|
||||
# 将字符串转换为字符列表
|
||||
char_list = list(predicted_text)
|
||||
|
||||
# 确保返回7个字符(车牌标准长度)
|
||||
# 确保返回至少7个字符,最多8个字符
|
||||
if len(char_list) < 7:
|
||||
# 如果识别结果少于7个字符,用'0'补齐
|
||||
# 如果识别结果少于7个字符,用'0'补齐到7位
|
||||
char_list.extend(['0'] * (7 - len(char_list)))
|
||||
elif len(char_list) > 7:
|
||||
# 如果识别结果多于7个字符,截取前7个
|
||||
char_list = char_list[:7]
|
||||
elif len(char_list) > 8:
|
||||
# 如果识别结果多于8个字符,截取前8个
|
||||
char_list = char_list[:8]
|
||||
|
||||
# 如果是7位,补齐到8位以保持接口一致性(第8位用空字符或占位符)
|
||||
if len(char_list) == 7:
|
||||
char_list.append('') # 添加空字符作为第8位占位符
|
||||
|
||||
return char_list
|
||||
|
||||
@@ -328,4 +333,4 @@ def crnn_predict(image_array):
|
||||
print(f"CRNN识别失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return ['识', '别', '失', '败', '0', '0', '0']
|
||||
return ['识', '别', '失', '败', '0', '0', '0', '0']
|
||||
|
||||
@@ -5,6 +5,18 @@ import cv2
|
||||
class OCRProcessor:
|
||||
def __init__(self):
|
||||
self.model = TextRecognition(model_name="PP-OCRv5_server_rec")
|
||||
# 定义允许的字符集合(不包含空白字符)
|
||||
self.allowed_chars = [
|
||||
# 中文省份简称
|
||||
'京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
|
||||
'苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
|
||||
'桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁', '新',
|
||||
# 字母 A-Z
|
||||
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
|
||||
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
|
||||
# 数字 0-9
|
||||
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'
|
||||
]
|
||||
print("OCR模型初始化完成(占位)")
|
||||
|
||||
def predict(self, image_array):
|
||||
@@ -15,14 +27,64 @@ class OCRProcessor:
|
||||
placeholder_result = results.split(',')
|
||||
return placeholder_result
|
||||
|
||||
def filter_allowed_chars(self, text):
|
||||
"""只保留允许的字符"""
|
||||
filtered_text = ""
|
||||
for char in text:
|
||||
if char in self.allowed_chars:
|
||||
filtered_text += char
|
||||
return filtered_text
|
||||
|
||||
# 保留原有函数接口
|
||||
_processor = OCRProcessor()
|
||||
|
||||
def initialize_ocr_model():
|
||||
def LPRNinitialize_model():
|
||||
return _processor
|
||||
|
||||
def ocr_predict(image_array):
|
||||
return _processor.predict(image_array)
|
||||
def LPRNmodel_predict(image_array):
|
||||
"""
|
||||
OCR车牌号识别接口函数
|
||||
|
||||
参数:
|
||||
image_array: numpy数组格式的车牌图像,已经过矫正处理
|
||||
|
||||
返回:
|
||||
list: 包含最多8个字符的列表,代表车牌号的每个字符
|
||||
例如: ['京', 'A', '1', '2', '3', '4', '5', ''] (蓝牌7位+占位符)
|
||||
['京', 'A', 'D', '1', '2', '3', '4', '5'] (绿牌8位)
|
||||
"""
|
||||
# 获取原始预测结果
|
||||
raw_result = _processor.predict(image_array)
|
||||
|
||||
# 将结果合并为字符串(如果是列表的话)
|
||||
if isinstance(raw_result, list):
|
||||
result_str = ''.join(raw_result)
|
||||
else:
|
||||
result_str = str(raw_result)
|
||||
|
||||
# 过滤掉'·'和'-'字符
|
||||
filtered_str = result_str.replace('·', '')
|
||||
filtered_str = filtered_str.replace('-', '')
|
||||
|
||||
# 只保留允许的字符
|
||||
filtered_str = _processor.filter_allowed_chars(filtered_str)
|
||||
|
||||
# 转换为字符列表
|
||||
char_list = list(filtered_str)
|
||||
|
||||
# 确保返回至少7个字符,最多8个字符
|
||||
if len(char_list) < 7:
|
||||
# 如果识别结果少于7个字符,用'0'补齐到7位
|
||||
char_list.extend(['0'] * (7 - len(char_list)))
|
||||
elif len(char_list) > 8:
|
||||
# 如果识别结果多于8个字符,截取前8个
|
||||
char_list = char_list[:8]
|
||||
|
||||
# 如果是7位,补齐到8位以保持接口一致性(第8位用空字符或占位符)
|
||||
if len(char_list) == 7:
|
||||
char_list.append('') # 添加空字符作为第8位占位符
|
||||
|
||||
return char_list
|
||||
|
||||
|
||||
|
||||
|
||||
79
README.md
79
README.md
@@ -14,8 +14,11 @@ License_plate_recognition/
|
||||
│ └── yolo11s-pose42.pt # YOLO pose模型文件
|
||||
├── OCR_part/ # OCR识别模块
|
||||
│ └── ocr_interface.py # OCR接口(占位)
|
||||
└── CRNN_part/ # CRNN识别模块
|
||||
└── crnn_interface.py # CRNN
|
||||
├── CRNN_part/ # CRNN识别模块
|
||||
│ └── crnn_interface.py # CRNN接口(占位)
|
||||
└── LPRNET_part/ # LPRNet识别模块
|
||||
├── lpr_interface.py # LPRNet接口(已完成)
|
||||
└── LPRNet__iteration_74000.pth # LPRNet模型权重文件
|
||||
```
|
||||
|
||||
## 功能特性
|
||||
@@ -35,16 +38,21 @@ License_plate_recognition/
|
||||
- 将倾斜的车牌矫正为标准矩形
|
||||
- 输出标准尺寸的车牌图像供识别使用
|
||||
|
||||
### 4. PyQt界面
|
||||
### 4. 多种识别方案
|
||||
- 支持OCR、CRNN和LPRNet三种车牌识别方法
|
||||
- LPRNet模型准确率高达98%
|
||||
- 模块化接口设计,便于切换不同识别算法
|
||||
|
||||
### 5. PyQt界面
|
||||
- 左侧:实时摄像头画面显示
|
||||
- 右侧:检测结果展示区域
|
||||
- 顶部显示识别到的车牌数量
|
||||
- 每行显示:车牌类型、矫正后图像、车牌号
|
||||
- 美观的现代化界面设计
|
||||
|
||||
### 5. 模块化设计
|
||||
### 6. 模块化设计
|
||||
- yolopart:负责车牌定位和矫正
|
||||
- OCR_part/CRNN_part:负责车牌号识别(接口已预留)
|
||||
- OCR_part/CRNN_part/LPRNET_part:负责车牌号识别
|
||||
- 各模块独立,便于维护和扩展
|
||||
|
||||
## 安装和使用
|
||||
@@ -67,7 +75,21 @@ pip install -r requirements.txt
|
||||
python main.py
|
||||
```
|
||||
|
||||
### 5. 使用说明
|
||||
### 5. 选择识别模块
|
||||
在 `main.py` 中修改导入语句来选择不同的识别方案:
|
||||
|
||||
```python
|
||||
# 使用LPRNet(推荐,准确率98%)
|
||||
from LPRNET_part.lpr_interface import LPRNmodel_predict, LPRNinitialize_model
|
||||
|
||||
# 使用OCR
|
||||
from OCR_part.ocr_interface import LPRNmodel_predict, LPRNinitialize_model
|
||||
|
||||
# 使用CRNN
|
||||
from CRNN_part.crnn_interface import LPRNmodel_predict, LPRNinitialize_model
|
||||
```
|
||||
|
||||
### 6. 使用说明
|
||||
1. 点击"启动摄像头"按钮开始检测
|
||||
2. 将车牌对准摄像头
|
||||
3. 系统会自动检测车牌并显示:
|
||||
@@ -89,17 +111,20 @@ YOLO Pose模型输出包含:
|
||||
|
||||
## 接口说明
|
||||
|
||||
### OCR/CRNN接口
|
||||
车牌号识别部分使用统一接口:
|
||||
### 车牌识别接口
|
||||
|
||||
项目为OCR、CRNN和LPRNet识别模块提供了标准接口:
|
||||
|
||||
```python
|
||||
# OCR接口
|
||||
from OCR_part.ocr_interface import ocr_predict
|
||||
result = ocr_predict(corrected_image) # 返回7个字符的列表
|
||||
# 接口函数名(导入所需模块,每个模块统一函数名)
|
||||
|
||||
# CRNN接口
|
||||
from CRNN_part.crnn_interface import crnn_predict
|
||||
result = crnn_predict(corrected_image) # 返回7个字符的列表
|
||||
# 初始化
|
||||
from 对应模块 import LPRNinitialize_model
|
||||
LPRNinitialize_model()
|
||||
|
||||
# 预测主函数
|
||||
from 对应模块 import LPRNmodel_predict
|
||||
result = LPRNmodel_predict(corrected_image) # 返回7个字符的列表
|
||||
```
|
||||
|
||||
### 输入参数
|
||||
@@ -109,34 +134,26 @@ result = crnn_predict(corrected_image) # 返回7个字符的列表
|
||||
- 长度为7的字符列表,包含车牌号的每个字符
|
||||
- 例如:`['京', 'A', '1', '2', '3', '4', '5']`
|
||||
|
||||
### LPRNet模块特性
|
||||
|
||||
- **高准确率**: 模型准确率高达98%
|
||||
- **快速推理**: 基于深度学习的端到端识别
|
||||
- **CTC解码**: 使用CTC(Connectionist Temporal Classification)解码算法
|
||||
- **支持中文**: 完整支持中文省份简称和字母数字组合
|
||||
- **模型权重**: 使用预训练的LPRNet__iteration_74000.pth权重文件
|
||||
|
||||
## 开发说明
|
||||
|
||||
### 添加新的识别算法
|
||||
1. 在对应目录(OCR_part或CRNN_part)实现识别函数
|
||||
2. 确保函数签名与接口一致
|
||||
3. 在main.py中切换调用的函数即可
|
||||
3. 在main.py中导入对应模块即可
|
||||
|
||||
### 自定义模型
|
||||
1. 替换 `yolopart/yolo11s-pose42.pt` 文件
|
||||
2. 确保新模型输出格式与现有接口兼容
|
||||
3. 根据需要调整类别名称和数量
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **模型文件**:确保YOLO模型文件路径正确
|
||||
2. **摄像头权限**:程序需要摄像头访问权限
|
||||
3. **光照条件**:良好的光照有助于提高检测精度
|
||||
4. **车牌角度**:尽量保持车牌完整出现在画面中
|
||||
5. **性能优化**:可根据硬件配置调整检测参数
|
||||
|
||||
## 故障排除
|
||||
|
||||
### 常见问题
|
||||
1. **摄像头无法启动**:检查摄像头是否被其他程序占用
|
||||
2. **模型加载失败**:确认模型文件路径和格式正确
|
||||
3. **检测效果差**:调整光照条件或摄像头角度
|
||||
4. **界面显示异常**:检查PyQt5安装是否完整
|
||||
|
||||
### 调试模式
|
||||
在代码中设置调试标志可以输出更多信息:
|
||||
```python
|
||||
|
||||
69
communicate.py
Normal file
69
communicate.py
Normal file
@@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
向Hi3861设备发送JSON命令
|
||||
"""
|
||||
|
||||
import socket
|
||||
import json
|
||||
import time
|
||||
import pyttsx3
|
||||
import threading
|
||||
|
||||
target_ip = "192.168.43.12"
|
||||
target_port = 8081
|
||||
|
||||
def speak_text(text):
|
||||
"""
|
||||
使用文本转语音播放文本
|
||||
每次调用都创建新的引擎实例以避免并发问题
|
||||
"""
|
||||
def _speak():
|
||||
try:
|
||||
if text and text.strip(): # 确保文本不为空
|
||||
# 在线程内部创建新的引擎实例
|
||||
engine = pyttsx3.init()
|
||||
# 设置语音速度
|
||||
engine.setProperty('rate', 150)
|
||||
# 设置音量(0.0到1.0)
|
||||
engine.setProperty('volume', 0.8)
|
||||
|
||||
engine.say(text)
|
||||
engine.runAndWait()
|
||||
|
||||
# 清理引擎
|
||||
engine.stop()
|
||||
del engine
|
||||
except Exception as e:
|
||||
print(f"语音播放失败: {e}")
|
||||
|
||||
# 在新线程中播放语音,避免阻塞
|
||||
speech_thread = threading.Thread(target=_speak)
|
||||
speech_thread.daemon = True
|
||||
speech_thread.start()
|
||||
|
||||
def send_command(cmd, text):
|
||||
#cmd为1,道闸打开十秒后关闭,oled显示字符串信息(默认使用及cmd为4)
|
||||
#cmd为2,道闸舵机向打开方向旋转90度,oled上不显示(仅在qt界面手动开闸时调用)
|
||||
#cmd为3,道闸舵机向关闭方向旋转90度,oled上不显示(仅在qt界面手动关闸时调用)
|
||||
#cmd为4,oled显示字符串信息,道闸舵机不旋转
|
||||
|
||||
command = {
|
||||
"cmd": cmd,
|
||||
"text": text
|
||||
}
|
||||
|
||||
json_command = json.dumps(command, ensure_ascii=False)
|
||||
try:
|
||||
# 创建UDP socket
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.sendto(json_command.encode('utf-8'), (target_ip, target_port))
|
||||
|
||||
# 发送命令后播放语音
|
||||
if text and text.strip():
|
||||
speak_text(text)
|
||||
|
||||
except Exception as e:
|
||||
print(f"发送命令失败: {e}")
|
||||
finally:
|
||||
sock.close()
|
||||
251
gate_control.py
Normal file
251
gate_control.py
Normal file
@@ -0,0 +1,251 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
道闸控制模块
|
||||
负责与Hi3861设备通信,控制道闸开关
|
||||
"""
|
||||
|
||||
import socket
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from PyQt5.QtCore import QObject, pyqtSignal, QThread
|
||||
|
||||
|
||||
class GateControlThread(QThread):
|
||||
"""道闸控制线程,用于异步发送命令"""
|
||||
command_sent = pyqtSignal(str, bool) # 信号:命令内容,是否成功
|
||||
|
||||
def __init__(self, ip, port, command):
|
||||
super().__init__()
|
||||
self.ip = ip
|
||||
self.port = port
|
||||
self.command = command
|
||||
|
||||
def run(self):
|
||||
"""发送命令到Hi3861设备"""
|
||||
try:
|
||||
# 创建UDP socket
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
|
||||
# 发送命令
|
||||
json_command = json.dumps(self.command, ensure_ascii=False)
|
||||
sock.sendto(json_command.encode('utf-8'), (self.ip, self.port))
|
||||
|
||||
# 发出成功信号
|
||||
self.command_sent.emit(json_command, True)
|
||||
|
||||
except Exception as e:
|
||||
# 发出失败信号
|
||||
self.command_sent.emit(f"发送失败: {e}", False)
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
|
||||
class GateController(QObject):
|
||||
"""道闸控制器"""
|
||||
|
||||
# 信号
|
||||
log_message = pyqtSignal(str) # 日志消息
|
||||
gate_opened = pyqtSignal(str) # 道闸打开信号,附带车牌号
|
||||
|
||||
def __init__(self, ip="192.168.43.12", port=8081):
|
||||
super().__init__()
|
||||
self.ip = ip
|
||||
self.port = port
|
||||
self.last_pass_times = {} # 记录车牌上次通过时间
|
||||
self.thread_pool = [] # 线程池
|
||||
|
||||
def send_command(self, cmd, text=""):
|
||||
"""
|
||||
发送命令到道闸
|
||||
|
||||
参数:
|
||||
cmd: 命令类型 (1-4)
|
||||
text: 显示文本
|
||||
|
||||
返回:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
# 创建JSON命令
|
||||
command = {
|
||||
"cmd": cmd,
|
||||
"text": text
|
||||
}
|
||||
|
||||
# 创建并启动线程发送命令
|
||||
thread = GateControlThread(self.ip, self.port, command)
|
||||
thread.command_sent.connect(self.on_command_sent)
|
||||
thread.start()
|
||||
self.thread_pool.append(thread)
|
||||
|
||||
# 记录日志
|
||||
cmd_desc = {
|
||||
1: "自动开闸(10秒后关闭)",
|
||||
2: "手动开闸",
|
||||
3: "手动关闸",
|
||||
4: "仅显示信息"
|
||||
}
|
||||
self.log_message.emit(f"发送命令: {cmd_desc.get(cmd, '未知命令')} - {text}")
|
||||
|
||||
return True
|
||||
|
||||
def on_command_sent(self, message, success):
|
||||
"""命令发送结果处理"""
|
||||
if success:
|
||||
self.log_message.emit(f"命令发送成功: {message}")
|
||||
else:
|
||||
self.log_message.emit(f"命令发送失败: {message}")
|
||||
|
||||
def auto_open_gate(self, plate_number):
|
||||
"""
|
||||
自动开闸(检测到白名单车牌时调用)
|
||||
|
||||
参数:
|
||||
plate_number: 车牌号
|
||||
"""
|
||||
# 获取当前时间
|
||||
current_time = datetime.now()
|
||||
time_diff_str = ""
|
||||
|
||||
# 检查是否是第一次通行
|
||||
if plate_number in self.last_pass_times:
|
||||
# 第二次或更多次通行,计算时间差
|
||||
last_time = self.last_pass_times[plate_number]
|
||||
time_diff = current_time - last_time
|
||||
|
||||
# 格式化时间差
|
||||
total_seconds = int(time_diff.total_seconds())
|
||||
minutes = total_seconds // 60
|
||||
seconds = total_seconds % 60
|
||||
|
||||
if minutes > 0:
|
||||
time_diff_str = f" {minutes}min{seconds}sec"
|
||||
else:
|
||||
time_diff_str = f" {seconds}sec"
|
||||
|
||||
# 计算时间差后清空之前记录的时间点
|
||||
del self.last_pass_times[plate_number]
|
||||
log_msg = f"检测到白名单车牌: {plate_number},自动开闸{time_diff_str},已清空时间记录"
|
||||
else:
|
||||
# 第一次通行,只记录时间,不计算时间差
|
||||
self.last_pass_times[plate_number] = current_time
|
||||
log_msg = f"检测到白名单车牌: {plate_number},首次通行,已记录时间"
|
||||
|
||||
# 发送开闸命令
|
||||
display_text = f"{plate_number} 通行{time_diff_str}"
|
||||
self.send_command(1, display_text)
|
||||
|
||||
# 发出信号
|
||||
self.gate_opened.emit(plate_number)
|
||||
|
||||
# 记录日志
|
||||
self.log_message.emit(log_msg)
|
||||
|
||||
def manual_open_gate(self):
|
||||
"""手动开闸"""
|
||||
self.send_command(2, "")
|
||||
self.log_message.emit("手动开闸")
|
||||
|
||||
def manual_close_gate(self):
|
||||
"""手动关闸"""
|
||||
self.send_command(3, "")
|
||||
self.log_message.emit("手动关闸")
|
||||
|
||||
def display_message(self, text):
|
||||
"""仅显示信息,不控制道闸"""
|
||||
self.send_command(4, text)
|
||||
self.log_message.emit(f"显示信息: {text}")
|
||||
|
||||
def deny_access(self, plate_number):
|
||||
"""
|
||||
拒绝通行(检测到非白名单车牌时调用)
|
||||
|
||||
参数:
|
||||
plate_number: 车牌号
|
||||
"""
|
||||
self.send_command(4, f"{plate_number} 禁止通行")
|
||||
self.log_message.emit(f"检测到非白名单车牌: {plate_number},拒绝通行")
|
||||
|
||||
|
||||
class WhitelistManager(QObject):
|
||||
"""白名单管理器"""
|
||||
|
||||
# 信号
|
||||
whitelist_changed = pyqtSignal(list) # 白名单变更信号
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.whitelist = [] # 白名单车牌列表
|
||||
|
||||
def add_plate(self, plate_number):
|
||||
"""
|
||||
添加车牌到白名单
|
||||
|
||||
参数:
|
||||
plate_number: 车牌号
|
||||
|
||||
返回:
|
||||
bool: 是否添加成功
|
||||
"""
|
||||
if not plate_number or plate_number in self.whitelist:
|
||||
return False
|
||||
|
||||
self.whitelist.append(plate_number)
|
||||
self.whitelist_changed.emit(self.whitelist.copy())
|
||||
return True
|
||||
|
||||
def remove_plate(self, plate_number):
|
||||
"""
|
||||
从白名单移除车牌
|
||||
|
||||
参数:
|
||||
plate_number: 车牌号
|
||||
|
||||
返回:
|
||||
bool: 是否移除成功
|
||||
"""
|
||||
if plate_number in self.whitelist:
|
||||
self.whitelist.remove(plate_number)
|
||||
self.whitelist_changed.emit(self.whitelist.copy())
|
||||
return True
|
||||
return False
|
||||
|
||||
def edit_plate(self, old_plate, new_plate):
|
||||
"""
|
||||
编辑白名单中的车牌
|
||||
|
||||
参数:
|
||||
old_plate: 原车牌号
|
||||
new_plate: 新车牌号
|
||||
|
||||
返回:
|
||||
bool: 是否编辑成功
|
||||
"""
|
||||
if old_plate in self.whitelist and new_plate not in self.whitelist:
|
||||
index = self.whitelist.index(old_plate)
|
||||
self.whitelist[index] = new_plate
|
||||
self.whitelist_changed.emit(self.whitelist.copy())
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_whitelisted(self, plate_number):
|
||||
"""
|
||||
检查车牌是否在白名单中
|
||||
|
||||
参数:
|
||||
plate_number: 车牌号
|
||||
|
||||
返回:
|
||||
bool: 是否在白名单中
|
||||
"""
|
||||
return plate_number in self.whitelist
|
||||
|
||||
def get_whitelist(self):
|
||||
"""获取白名单副本"""
|
||||
return self.whitelist.copy()
|
||||
|
||||
def clear_whitelist(self):
|
||||
"""清空白名单"""
|
||||
self.whitelist.clear()
|
||||
self.whitelist_changed.emit(self.whitelist.copy())
|
||||
BIN
lightCRNN_part/best_model.pth
Normal file
BIN
lightCRNN_part/best_model.pth
Normal file
Binary file not shown.
546
lightCRNN_part/lightcrnn_interface.py
Normal file
546
lightCRNN_part/lightcrnn_interface.py
Normal file
@@ -0,0 +1,546 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import cv2
|
||||
from torchvision import transforms
|
||||
import os
|
||||
import math
|
||||
|
||||
# 全局变量
|
||||
lightcrnn_model = None
|
||||
lightcrnn_decoder = None
|
||||
lightcrnn_preprocessor = None
|
||||
device = None
|
||||
|
||||
class DepthwiseSeparableConv(nn.Module):
|
||||
"""深度可分离卷积"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
|
||||
super(DepthwiseSeparableConv, self).__init__()
|
||||
# 深度卷积
|
||||
self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
|
||||
stride=stride, padding=padding, groups=in_channels, bias=False)
|
||||
# 逐点卷积
|
||||
self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU6(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.depthwise(x)
|
||||
x = self.pointwise(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
class ChannelAttention(nn.Module):
|
||||
"""通道注意力机制"""
|
||||
|
||||
def __init__(self, in_channels, reduction=16):
|
||||
super(ChannelAttention, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.max_pool = nn.AdaptiveMaxPool2d(1)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)
|
||||
)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
avg_out = self.fc(self.avg_pool(x))
|
||||
max_out = self.fc(self.max_pool(x))
|
||||
out = avg_out + max_out
|
||||
return x * self.sigmoid(out)
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
"""MobileNetV2的倒残差块"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, stride=1, expand_ratio=6):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
self.use_residual = stride == 1 and in_channels == out_channels
|
||||
|
||||
hidden_dim = int(round(in_channels * expand_ratio))
|
||||
|
||||
layers = []
|
||||
if expand_ratio != 1:
|
||||
# 扩展层
|
||||
layers.extend([
|
||||
nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.ReLU6(inplace=True)
|
||||
])
|
||||
|
||||
# 深度卷积
|
||||
layers.extend([
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim, bias=False),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.ReLU6(inplace=True),
|
||||
# 线性瓶颈
|
||||
nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
|
||||
nn.BatchNorm2d(out_channels)
|
||||
])
|
||||
|
||||
self.conv = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_residual:
|
||||
return x + self.conv(x)
|
||||
else:
|
||||
return self.conv(x)
|
||||
|
||||
class LightweightCNN(nn.Module):
|
||||
"""增强版轻量化CNN特征提取器"""
|
||||
|
||||
def __init__(self, num_channels=3):
|
||||
super(LightweightCNN, self).__init__()
|
||||
|
||||
# 初始卷积层 - 适当增加通道数
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv2d(num_channels, 48, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(48),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
# 增强版MobileNet风格的特征提取
|
||||
self.features = nn.Sequential(
|
||||
# 第一组:48 -> 32
|
||||
InvertedResidual(48, 32, stride=1, expand_ratio=2),
|
||||
InvertedResidual(32, 32, stride=1, expand_ratio=2), # 增加一层
|
||||
nn.MaxPool2d(kernel_size=2, stride=2), # 32x128 -> 16x64
|
||||
|
||||
# 第二组:32 -> 48
|
||||
InvertedResidual(32, 48, stride=1, expand_ratio=4),
|
||||
InvertedResidual(48, 48, stride=1, expand_ratio=4),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2), # 16x64 -> 8x32
|
||||
|
||||
# 第三组:48 -> 64
|
||||
InvertedResidual(48, 64, stride=1, expand_ratio=4),
|
||||
InvertedResidual(64, 64, stride=1, expand_ratio=4),
|
||||
|
||||
# 第四组:64 -> 96
|
||||
InvertedResidual(64, 96, stride=1, expand_ratio=4),
|
||||
InvertedResidual(96, 96, stride=1, expand_ratio=4),
|
||||
nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)), # 8x32 -> 4x32
|
||||
|
||||
# 第五组:96 -> 128
|
||||
InvertedResidual(96, 128, stride=1, expand_ratio=4),
|
||||
InvertedResidual(128, 128, stride=1, expand_ratio=4),
|
||||
nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)), # 4x32 -> 2x32
|
||||
|
||||
# 最后的卷积层 - 增加通道数
|
||||
nn.Conv2d(128, 160, kernel_size=2, stride=1, padding=0, bias=False), # 2x32 -> 1x31
|
||||
nn.BatchNorm2d(160),
|
||||
nn.ReLU6(inplace=True)
|
||||
)
|
||||
|
||||
# 通道注意力
|
||||
self.channel_attention = ChannelAttention(160)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.features(x)
|
||||
x = self.channel_attention(x)
|
||||
return x
|
||||
|
||||
class LightweightGRU(nn.Module):
|
||||
"""增强版轻量化GRU层"""
|
||||
|
||||
def __init__(self, input_size, hidden_size, num_layers=2): # 默认增加到2层
|
||||
super(LightweightGRU, self).__init__()
|
||||
self.gru = nn.GRU(input_size, hidden_size, num_layers=num_layers,
|
||||
bidirectional=True, batch_first=True, dropout=0.2 if num_layers > 1 else 0)
|
||||
# 增加一个额外的线性层
|
||||
self.linear1 = nn.Linear(hidden_size * 2, hidden_size * 2)
|
||||
self.linear2 = nn.Linear(hidden_size * 2, hidden_size)
|
||||
self.dropout = nn.Dropout(0.2) # 增加dropout率
|
||||
self.norm = nn.LayerNorm(hidden_size) # 添加层归一化
|
||||
|
||||
def forward(self, x):
|
||||
gru_out, _ = self.gru(x)
|
||||
output = self.linear1(gru_out)
|
||||
output = F.relu(output) # 添加激活函数
|
||||
output = self.dropout(output)
|
||||
output = self.linear2(output)
|
||||
output = self.norm(output) # 应用层归一化
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
class LightweightCRNN(nn.Module):
|
||||
"""增强版轻量化CRNN模型"""
|
||||
|
||||
def __init__(self, img_height, num_classes, num_channels=3, hidden_size=160): # 调整隐藏层大小
|
||||
super(LightweightCRNN, self).__init__()
|
||||
|
||||
self.img_height = img_height
|
||||
self.num_classes = num_classes
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
# 增强版轻量化CNN特征提取器
|
||||
self.cnn = LightweightCNN(num_channels)
|
||||
|
||||
# 增强版轻量化RNN序列建模器
|
||||
self.rnn = LightweightGRU(160, hidden_size, num_layers=2) # 使用更大的输入尺寸和2层GRU
|
||||
|
||||
# 输出层 - 添加额外的全连接层
|
||||
self.fc = nn.Linear(hidden_size, hidden_size // 2)
|
||||
self.dropout = nn.Dropout(0.2)
|
||||
self.classifier = nn.Linear(hidden_size // 2, num_classes)
|
||||
|
||||
# 初始化权重
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
"""初始化模型权重"""
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.normal_(m.weight, 0, 0.01)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, input):
|
||||
"""
|
||||
input: [batch_size, channels, height, width]
|
||||
output: [seq_len, batch_size, num_classes]
|
||||
"""
|
||||
# CNN特征提取
|
||||
conv_features = self.cnn(input) # [batch_size, 160, 1, seq_len]
|
||||
|
||||
# 重塑为RNN输入格式
|
||||
batch_size, channels, height, width = conv_features.size()
|
||||
assert height == 1, f"Height should be 1, got {height}"
|
||||
|
||||
# [batch_size, 160, 1, seq_len] -> [batch_size, seq_len, 160]
|
||||
conv_features = conv_features.squeeze(2) # [batch_size, 160, seq_len]
|
||||
conv_features = conv_features.permute(0, 2, 1) # [batch_size, seq_len, 160]
|
||||
|
||||
# RNN序列建模
|
||||
rnn_output = self.rnn(conv_features) # [batch_size, seq_len, hidden_size]
|
||||
|
||||
# 全连接层处理
|
||||
fc_output = self.fc(rnn_output) # [batch_size, seq_len, hidden_size//2]
|
||||
fc_output = F.relu(fc_output)
|
||||
fc_output = self.dropout(fc_output)
|
||||
|
||||
# 分类
|
||||
output = self.classifier(fc_output) # [batch_size, seq_len, num_classes]
|
||||
|
||||
# 转换为CTC期望的格式: [seq_len, batch_size, num_classes]
|
||||
output = output.permute(1, 0, 2)
|
||||
|
||||
return output
|
||||
|
||||
class LightCTCDecoder:
|
||||
"""轻量化CTC解码器"""
|
||||
def __init__(self):
|
||||
# 中国车牌字符集
|
||||
# 省份简称
|
||||
provinces = ['京', '津', '沪', '渝', '冀', '豫', '云', '辽', '黑', '湘', '皖', '鲁',
|
||||
'新', '苏', '浙', '赣', '鄂', '桂', '甘', '晋', '蒙', '陕', '吉', '闽',
|
||||
'贵', '粤', '青', '藏', '川', '宁', '琼']
|
||||
|
||||
# 字母(包含I和O)
|
||||
letters = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
|
||||
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
|
||||
|
||||
# 数字
|
||||
digits = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
|
||||
|
||||
# 组合所有字符
|
||||
self.character = provinces + letters + digits
|
||||
|
||||
# 添加空白字符用于CTC
|
||||
self.character = ['[blank]'] + self.character
|
||||
|
||||
# 创建字符到索引的映射
|
||||
self.dict = {char: i for i, char in enumerate(self.character)}
|
||||
self.dict_reverse = {i: char for i, char in enumerate(self.character)}
|
||||
|
||||
self.num_classes = len(self.character)
|
||||
self.blank_idx = 0
|
||||
|
||||
def decode_greedy(self, predictions):
|
||||
"""贪婪解码"""
|
||||
# 获取每个时间步的最大概率索引
|
||||
indices = torch.argmax(predictions, dim=1)
|
||||
|
||||
# CTC解码:移除重复字符和空白字符
|
||||
decoded_chars = []
|
||||
prev_idx = -1
|
||||
|
||||
for idx in indices:
|
||||
idx = idx.item()
|
||||
if idx != prev_idx and idx != self.blank_idx:
|
||||
if idx < len(self.character):
|
||||
decoded_chars.append(self.character[idx])
|
||||
prev_idx = idx
|
||||
|
||||
return ''.join(decoded_chars)
|
||||
|
||||
def decode_with_confidence(self, predictions):
|
||||
"""解码并返回置信度信息"""
|
||||
# 应用softmax获得概率
|
||||
probs = torch.softmax(predictions, dim=1)
|
||||
|
||||
# 贪婪解码
|
||||
indices = torch.argmax(probs, dim=1)
|
||||
max_probs = torch.max(probs, dim=1)[0]
|
||||
|
||||
# CTC解码
|
||||
decoded_chars = []
|
||||
char_confidences = []
|
||||
prev_idx = -1
|
||||
|
||||
for i, idx in enumerate(indices):
|
||||
idx = idx.item()
|
||||
confidence = max_probs[i].item()
|
||||
|
||||
if idx != prev_idx and idx != self.blank_idx:
|
||||
if idx < len(self.character):
|
||||
decoded_chars.append(self.character[idx])
|
||||
char_confidences.append(confidence)
|
||||
prev_idx = idx
|
||||
|
||||
text = ''.join(decoded_chars)
|
||||
avg_confidence = np.mean(char_confidences) if char_confidences else 0.0
|
||||
|
||||
return text, avg_confidence, char_confidences
|
||||
|
||||
class LightLicensePlatePreprocessor:
|
||||
"""轻量化车牌图像预处理器"""
|
||||
def __init__(self, target_height=32, target_width=128):
|
||||
self.target_height = target_height
|
||||
self.target_width = target_width
|
||||
|
||||
# 定义图像变换
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Resize((target_height, target_width)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
def preprocess_numpy_array(self, image_array):
|
||||
"""预处理numpy数组格式的图像"""
|
||||
try:
|
||||
# 确保图像是RGB格式
|
||||
if len(image_array.shape) == 3 and image_array.shape[2] == 3:
|
||||
# 如果是BGR格式,转换为RGB
|
||||
if image_array.dtype == np.uint8:
|
||||
image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 转换为PIL图像
|
||||
if image_array.dtype != np.uint8:
|
||||
image_array = (image_array * 255).astype(np.uint8)
|
||||
|
||||
image = Image.fromarray(image_array)
|
||||
|
||||
# 应用变换
|
||||
tensor = self.transform(image)
|
||||
|
||||
# 添加batch维度
|
||||
tensor = tensor.unsqueeze(0)
|
||||
|
||||
return tensor
|
||||
|
||||
except Exception as e:
|
||||
print(f"图像预处理失败: {e}")
|
||||
return None
|
||||
|
||||
def LPRNinitialize_model():
|
||||
"""
|
||||
初始化轻量化CRNN模型
|
||||
|
||||
返回:
|
||||
bool: 初始化是否成功
|
||||
"""
|
||||
global lightcrnn_model, lightcrnn_decoder, lightcrnn_preprocessor, device
|
||||
|
||||
try:
|
||||
# 设置设备
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print(f"LightCRNN使用设备: {device}")
|
||||
|
||||
# 初始化组件
|
||||
lightcrnn_decoder = LightCTCDecoder()
|
||||
lightcrnn_preprocessor = LightLicensePlatePreprocessor(target_height=32, target_width=128)
|
||||
|
||||
# 创建模型实例
|
||||
lightcrnn_model = LightweightCRNN(
|
||||
img_height=32,
|
||||
num_classes=lightcrnn_decoder.num_classes,
|
||||
hidden_size=160
|
||||
)
|
||||
|
||||
# 加载模型权重
|
||||
model_path = os.path.join(os.path.dirname(__file__), 'best_model.pth')
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||
|
||||
print(f"正在加载LightCRNN模型: {model_path}")
|
||||
|
||||
# 加载检查点,处理可能的模块依赖问题
|
||||
try:
|
||||
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
|
||||
except (ModuleNotFoundError, AttributeError) as e:
|
||||
if 'config' in str(e) or 'Config' in str(e):
|
||||
print("检测到模型文件包含config依赖,尝试使用weights_only模式加载...")
|
||||
try:
|
||||
# 尝试使用weights_only=True来避免pickle问题
|
||||
checkpoint = torch.load(model_path, map_location=device, weights_only=True)
|
||||
except Exception:
|
||||
# 如果还是失败,创建一个更完整的mock config
|
||||
import sys
|
||||
import types
|
||||
|
||||
# 创建mock config模块
|
||||
mock_config = types.ModuleType('config')
|
||||
|
||||
# 添加可能需要的Config类
|
||||
class Config:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
mock_config.Config = Config
|
||||
sys.modules['config'] = mock_config
|
||||
|
||||
try:
|
||||
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
|
||||
finally:
|
||||
# 清理临时模块
|
||||
if 'config' in sys.modules:
|
||||
del sys.modules['config']
|
||||
else:
|
||||
raise e
|
||||
|
||||
# 处理不同的模型保存格式
|
||||
if isinstance(checkpoint, dict):
|
||||
if 'model_state_dict' in checkpoint:
|
||||
# 完整检查点格式
|
||||
state_dict = checkpoint['model_state_dict']
|
||||
print(f"检查点信息:")
|
||||
print(f" - 训练轮次: {checkpoint.get('epoch', 'N/A')}")
|
||||
print(f" - 最佳验证损失: {checkpoint.get('best_val_loss', 'N/A')}")
|
||||
else:
|
||||
# 精简模型格式(只包含权重)
|
||||
print("加载精简模型(仅权重)")
|
||||
state_dict = checkpoint
|
||||
else:
|
||||
# 直接是状态字典
|
||||
state_dict = checkpoint
|
||||
|
||||
# 加载权重
|
||||
lightcrnn_model.load_state_dict(state_dict)
|
||||
lightcrnn_model.to(device)
|
||||
lightcrnn_model.eval()
|
||||
|
||||
print("LightCRNN模型初始化完成")
|
||||
|
||||
# 统计模型参数
|
||||
total_params = sum(p.numel() for p in lightcrnn_model.parameters())
|
||||
print(f"LightCRNN模型参数数量: {total_params:,}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"LightCRNN模型初始化失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def LPRNmodel_predict(image_array):
|
||||
"""
|
||||
轻量化CRNN车牌号识别接口函数
|
||||
|
||||
参数:
|
||||
image_array: numpy数组格式的车牌图像,已经过矫正处理
|
||||
|
||||
返回:
|
||||
list: 包含最多8个字符的列表,代表车牌号的每个字符
|
||||
例如: ['京', 'A', '1', '2', '3', '4', '5', ''] (蓝牌7位+占位符)
|
||||
['京', 'A', 'D', '1', '2', '3', '4', '5'] (绿牌8位)
|
||||
"""
|
||||
global lightcrnn_model, lightcrnn_decoder, lightcrnn_preprocessor, device
|
||||
|
||||
if lightcrnn_model is None or lightcrnn_decoder is None or lightcrnn_preprocessor is None:
|
||||
print("LightCRNN模型未初始化,请先调用LPRNinitialize_model()")
|
||||
return ['待', '识', '别', '0', '0', '0', '0', '0']
|
||||
|
||||
try:
|
||||
# 预处理图像
|
||||
input_tensor = lightcrnn_preprocessor.preprocess_numpy_array(image_array)
|
||||
if input_tensor is None:
|
||||
raise ValueError("图像预处理失败")
|
||||
|
||||
input_tensor = input_tensor.to(device)
|
||||
|
||||
# 模型推理
|
||||
with torch.no_grad():
|
||||
outputs = lightcrnn_model(input_tensor) # (seq_len, batch_size, num_classes)
|
||||
|
||||
# 移除batch维度
|
||||
outputs = outputs.squeeze(1) # (seq_len, num_classes)
|
||||
|
||||
# CTC解码
|
||||
predicted_text, confidence, char_confidences = lightcrnn_decoder.decode_with_confidence(outputs)
|
||||
|
||||
print(f"LightCRNN识别结果: {predicted_text}, 置信度: {confidence:.3f}")
|
||||
|
||||
# 将字符串转换为字符列表
|
||||
char_list = list(predicted_text)
|
||||
|
||||
# 确保返回至少7个字符,最多8个字符
|
||||
if len(char_list) < 7:
|
||||
# 如果识别结果少于7个字符,用'0'补齐到7位
|
||||
char_list.extend(['0'] * (7 - len(char_list)))
|
||||
elif len(char_list) > 8:
|
||||
# 如果识别结果多于8个字符,截取前8个
|
||||
char_list = char_list[:8]
|
||||
|
||||
# 如果是7位,补齐到8位以保持接口一致性(第8位用空字符或占位符)
|
||||
if len(char_list) == 7:
|
||||
char_list.append('') # 添加空字符作为第8位占位符
|
||||
|
||||
return char_list
|
||||
|
||||
except Exception as e:
|
||||
print(f"LightCRNN识别失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return ['识', '别', '失', '败', '0', '0', '0', '0']
|
||||
|
||||
def create_lightweight_model(model_type='lightweight_crnn', img_height=32, num_classes=66, hidden_size=160):
|
||||
"""创建增强版轻量化模型"""
|
||||
if model_type == 'lightweight_crnn':
|
||||
return LightweightCRNN(img_height, num_classes, hidden_size=hidden_size)
|
||||
else:
|
||||
raise ValueError(f"Unknown lightweight model type: {model_type}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试轻量化模型
|
||||
print("测试LightCRNN模型...")
|
||||
|
||||
# 初始化模型
|
||||
success = LPRNinitialize_model()
|
||||
if success:
|
||||
print("模型初始化成功")
|
||||
|
||||
# 创建测试输入
|
||||
test_input = np.random.randint(0, 255, (32, 128, 3), dtype=np.uint8)
|
||||
|
||||
# 测试预测
|
||||
result = LPRNmodel_predict(test_input)
|
||||
print(f"测试预测结果: {result}")
|
||||
else:
|
||||
print("模型初始化失败")
|
||||
5
parking_config.json
Normal file
5
parking_config.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"free_parking_duration": 5,
|
||||
"billing_cycle": 3,
|
||||
"price_per_cycle": 5.0
|
||||
}
|
||||
@@ -2,6 +2,7 @@ import cv2
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
import os
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
class LicensePlateYOLO:
|
||||
"""
|
||||
@@ -45,7 +46,7 @@ class LicensePlateYOLO:
|
||||
print(f"YOLO模型加载失败: {e}")
|
||||
return False
|
||||
|
||||
def detect_license_plates(self, image, conf_threshold=0.5):
|
||||
def detect_license_plates(self, image, conf_threshold=0.6):
|
||||
"""
|
||||
检测图像中的车牌
|
||||
|
||||
@@ -113,19 +114,38 @@ class LicensePlateYOLO:
|
||||
print(f"检测过程中出错: {e}")
|
||||
return []
|
||||
|
||||
def draw_detections(self, image, detections):
|
||||
def draw_detections(self, image, detections, plate_numbers=None):
|
||||
"""
|
||||
在图像上绘制检测结果
|
||||
|
||||
参数:
|
||||
image: 输入图像
|
||||
detections: 检测结果列表
|
||||
plate_numbers: 车牌号列表,与detections对应
|
||||
|
||||
返回:
|
||||
numpy.ndarray: 绘制了检测结果的图像
|
||||
"""
|
||||
draw_image = image.copy()
|
||||
|
||||
# 转换为PIL图像以支持中文字符
|
||||
pil_image = Image.fromarray(cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB))
|
||||
draw = ImageDraw.Draw(pil_image)
|
||||
|
||||
# 尝试加载中文字体
|
||||
try:
|
||||
# Windows系统常见的中文字体
|
||||
font_path = "C:/Windows/Fonts/simhei.ttf" # 黑体
|
||||
if not os.path.exists(font_path):
|
||||
font_path = "C:/Windows/Fonts/msyh.ttc" # 微软雅黑
|
||||
if not os.path.exists(font_path):
|
||||
font_path = "C:/Windows/Fonts/simsun.ttc" # 宋体
|
||||
|
||||
font = ImageFont.truetype(font_path, 20)
|
||||
except:
|
||||
# 如果无法加载字体,使用默认字体
|
||||
font = ImageFont.load_default()
|
||||
|
||||
for i, detection in enumerate(detections):
|
||||
box = detection['box']
|
||||
keypoints = detection['keypoints']
|
||||
@@ -133,6 +153,11 @@ class LicensePlateYOLO:
|
||||
confidence = detection['confidence']
|
||||
incomplete = detection.get('incomplete', False)
|
||||
|
||||
# 获取对应的车牌号
|
||||
plate_number = ""
|
||||
if plate_numbers and i < len(plate_numbers):
|
||||
plate_number = plate_numbers[i]
|
||||
|
||||
# 绘制边界框
|
||||
x1, y1, x2, y2 = map(int, box)
|
||||
|
||||
@@ -140,30 +165,53 @@ class LicensePlateYOLO:
|
||||
if class_name == '绿牌':
|
||||
box_color = (0, 255, 0) # 绿色
|
||||
elif class_name == '蓝牌':
|
||||
box_color = (255, 0, 0) # 蓝色
|
||||
box_color = (0, 0, 255) # 蓝色
|
||||
else:
|
||||
box_color = (128, 128, 128) # 灰色
|
||||
|
||||
cv2.rectangle(draw_image, (x1, y1), (x2, y2), box_color, 2)
|
||||
# 在PIL图像上绘制边界框
|
||||
draw.rectangle([(x1, y1), (x2, y2)], outline=box_color, width=2)
|
||||
|
||||
# 构建标签文本
|
||||
if plate_number:
|
||||
label = f"{class_name} {plate_number} {confidence:.2f}"
|
||||
else:
|
||||
label = f"{class_name} {confidence:.2f}"
|
||||
|
||||
# 绘制标签
|
||||
label = f"{class_name} {confidence:.2f}"
|
||||
if incomplete:
|
||||
label += " (不完整)"
|
||||
|
||||
# 计算文本大小和位置
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = 0.6
|
||||
thickness = 2
|
||||
(text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
|
||||
# 计算文本大小
|
||||
bbox = draw.textbbox((0, 0), label, font=font)
|
||||
text_width = bbox[2] - bbox[0]
|
||||
text_height = bbox[3] - bbox[1]
|
||||
|
||||
# 绘制文本背景
|
||||
cv2.rectangle(draw_image, (x1, y1 - text_height - 10),
|
||||
(x1 + text_width, y1), box_color, -1)
|
||||
draw.rectangle([(x1, y1 - text_height - 10), (x1 + text_width, y1)],
|
||||
fill=box_color)
|
||||
|
||||
# 绘制文本
|
||||
cv2.putText(draw_image, label, (x1, y1 - 5),
|
||||
font, font_scale, (255, 255, 255), thickness)
|
||||
draw.text((x1, y1 - text_height - 5), label, fill=(255, 255, 255), font=font)
|
||||
|
||||
# 转换回OpenCV格式
|
||||
draw_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
||||
|
||||
# 绘制关键点和连线(使用OpenCV)
|
||||
for i, detection in enumerate(detections):
|
||||
box = detection['box']
|
||||
keypoints = detection['keypoints']
|
||||
incomplete = detection.get('incomplete', False)
|
||||
|
||||
x1, y1, x2, y2 = map(int, box)
|
||||
|
||||
# 根据车牌类型选择颜色
|
||||
class_name = detection['class_name']
|
||||
if class_name == '绿牌':
|
||||
box_color = (0, 255, 0) # 绿色
|
||||
elif class_name == '蓝牌':
|
||||
box_color = (0, 0, 255) # 蓝色
|
||||
else:
|
||||
box_color = (128, 128, 128) # 灰色
|
||||
|
||||
# 绘制关键点和连线
|
||||
if len(keypoints) >= 4 and not incomplete:
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user