diff --git a/.idea/License_plate_recognition.iml b/.idea/License_plate_recognition.iml index 07abf20..2328911 100644 --- a/.idea/License_plate_recognition.iml +++ b/.idea/License_plate_recognition.iml @@ -2,7 +2,7 @@ - + diff --git a/.idea/misc.xml b/.idea/misc.xml index 060d2c5..9475f0b 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,7 @@ - + + + \ No newline at end of file diff --git a/CRNN_part/crnn_interface.py b/CRNN_part/crnn_interface.py new file mode 100644 index 0000000..31e8c0f --- /dev/null +++ b/CRNN_part/crnn_interface.py @@ -0,0 +1,37 @@ +import numpy as np + +def initialize_crnn_model(): + """ + 初始化CRNN模型 + + 返回: + bool: 初始化是否成功 + """ + # CRNN模型初始化代码 + # 例如: 加载预训练模型、设置参数等 + + print("CRNN模型初始化完成(占位)") + return True + + +def crnn_predict(image_array): + """ + CRNN车牌号识别接口函数 + + 参数: + image_array: numpy数组格式的车牌图像,已经过矫正处理 + + 返回: + list: 包含7个字符的列表,代表车牌号的每个字符 + 例如: ['京', 'A', '1', '2', '3', '4', '5'] + """ + # 这是CRNN部分的占位函数 + # 实际实现时,这里应该包含: + # 1. 图像预处理 + # 2. CRNN模型推理 + # 3. CTC解码 + # 4. 后处理和字符识别 + + # 临时返回占位结果 + placeholder_result = ['待', '识', '别', '0', '0', '0', '0'] + return placeholder_result diff --git a/OCR_part/ocr_interface.py b/OCR_part/ocr_interface.py new file mode 100644 index 0000000..294e568 --- /dev/null +++ b/OCR_part/ocr_interface.py @@ -0,0 +1,36 @@ +import numpy as np + +def initialize_ocr_model(): + """ + 初始化OCR模型 + + 返回: + bool: 初始化是否成功 + """ + # OCR模型初始化代码 + # 例如: 加载预训练模型、设置参数等 + + print("OCR模型初始化完成(占位)") + return True + +def ocr_predict(image_array): + """ + OCR车牌号识别接口函数 + + 参数: + image_array: numpy数组格式的车牌图像,已经过矫正处理 + + 返回: + list: 包含7个字符的列表,代表车牌号的每个字符 + 例如: ['京', 'A', '1', '2', '3', '4', '5'] + """ + # 这是OCR部分的占位函数 + # 实际实现时,这里应该包含: + # 1. 图像预处理 + # 2. OCR模型推理 + # 3. 后处理和字符识别 + + # 临时返回占位结果 + placeholder_result = ['待', '识', '别', '0', '0', '0', '0'] + return placeholder_result + diff --git a/README.md b/README.md new file mode 100644 index 0000000..7876d34 --- /dev/null +++ b/README.md @@ -0,0 +1,155 @@ +# 车牌识别系统 + +基于YOLO11 Pose模型的实时车牌检测与识别系统,支持蓝牌和绿牌的检测、四角点定位、透视矫正和车牌号识别。 + +## 项目结构 + +``` +License_plate_recognition/ +├── main.py # 主程序入口,PyQt界面 +├── requirements.txt # 依赖包列表 +├── README.md # 项目说明文档 +├── yolopart/ # YOLO检测模块 +│ ├── detector.py # YOLO检测器类 +│ └── yolo11s-pose42.pt # YOLO pose模型文件 +├── OCR_part/ # OCR识别模块 +│ └── ocr_interface.py # OCR接口(占位) +└── CRNN_part/ # CRNN识别模块 + └── crnn_interface.py # CRNN接口(占位) +``` + +## 功能特性 + +### 1. 实时车牌检测 +- 基于YOLO11 Pose模型进行车牌检测 +- 支持蓝牌(类别0)和绿牌(类别1)识别 +- 实时摄像头画面处理 + +### 2. 四角点定位 +- 检测车牌的四个角点:right_bottom, left_bottom, left_top, right_top +- 只有检测到完整四个角点的车牌才进行后续处理 +- 用黄色线条连接四个角点显示检测结果 + +### 3. 透视矫正 +- 使用四个角点进行透视变换 +- 将倾斜的车牌矫正为标准矩形 +- 输出标准尺寸的车牌图像供识别使用 + +### 4. PyQt界面 +- 左侧:实时摄像头画面显示 +- 右侧:检测结果展示区域 + - 顶部显示识别到的车牌数量 + - 每行显示:车牌类型、矫正后图像、车牌号 +- 美观的现代化界面设计 + +### 5. 模块化设计 +- yolopart:负责车牌定位和矫正 +- OCR_part/CRNN_part:负责车牌号识别(接口已预留) +- 各模块独立,便于维护和扩展 + +## 安装和使用 + +### 1. 环境要求 +- Python 3.7+ +- Windows/Linux/macOS +- 摄像头设备 + +### 2. 安装依赖 +```bash +pip install -r requirements.txt +``` + +### 3. 模型文件 +确保 `yolopart/yolo11s-pose42.pt` 模型文件存在。这是一个YOLO11 Pose模型,专门训练用于车牌的四角点检测。 + +### 4. 运行程序 +```bash +python main.py +``` + +### 5. 使用说明 +1. 点击"启动摄像头"按钮开始检测 +2. 将车牌对准摄像头 +3. 系统会自动检测车牌并显示: + - 检测框和角点连线 + - 右侧显示车牌类型、矫正图像和车牌号 +4. 点击"停止摄像头"结束检测 + +## 模型输出格式 + +YOLO Pose模型输出包含: +- **检测框**:车牌的边界框坐标 +- **类别**:0=蓝牌,1=绿牌 +- **置信度**:检测置信度分数 +- **关键点**:四个角点坐标 + - right_bottom:右下角 + - left_bottom:左下角 + - left_top:左上角 + - right_top:右上角 + +## 接口说明 + +### OCR/CRNN接口 +车牌号识别部分使用统一接口: + +```python +# OCR接口 +from OCR_part.ocr_interface import ocr_predict +result = ocr_predict(corrected_image) # 返回7个字符的列表 + +# CRNN接口 +from CRNN_part.crnn_interface import crnn_predict +result = crnn_predict(corrected_image) # 返回7个字符的列表 +``` + +### 输入参数 +- `corrected_image`:numpy数组格式的矫正后车牌图像 + +### 返回值 +- 长度为7的字符列表,包含车牌号的每个字符 +- 例如:`['京', 'A', '1', '2', '3', '4', '5']` + +## 开发说明 + +### 添加新的识别算法 +1. 在对应目录(OCR_part或CRNN_part)实现识别函数 +2. 确保函数签名与接口一致 +3. 在main.py中切换调用的函数即可 + +### 自定义模型 +1. 替换 `yolopart/yolo11s-pose42.pt` 文件 +2. 确保新模型输出格式与现有接口兼容 +3. 根据需要调整类别名称和数量 + +## 注意事项 + +1. **模型文件**:确保YOLO模型文件路径正确 +2. **摄像头权限**:程序需要摄像头访问权限 +3. **光照条件**:良好的光照有助于提高检测精度 +4. **车牌角度**:尽量保持车牌完整出现在画面中 +5. **性能优化**:可根据硬件配置调整检测参数 + +## 故障排除 + +### 常见问题 +1. **摄像头无法启动**:检查摄像头是否被其他程序占用 +2. **模型加载失败**:确认模型文件路径和格式正确 +3. **检测效果差**:调整光照条件或摄像头角度 +4. **界面显示异常**:检查PyQt5安装是否完整 + +### 调试模式 +在代码中设置调试标志可以输出更多信息: +```python +# 在detector.py中设置verbose=True +results = self.model(image, conf=conf_threshold, verbose=True) +``` + +## 扩展功能 + +系统设计支持以下扩展: +- 多摄像头支持 +- 批量图像处理 +- 检测结果保存 +- 网络API接口 +- 数据库集成 +- 性能统计和分析 \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..b5b95c8 --- /dev/null +++ b/main.py @@ -0,0 +1,411 @@ +import sys +import cv2 +import numpy as np +from PyQt5.QtWidgets import ( + QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, + QLabel, QPushButton, QScrollArea, QFrame, QSizePolicy +) +from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread +from PyQt5.QtGui import QImage, QPixmap, QFont, QPainter, QPen, QColor +import os +from yolopart.detector import LicensePlateYOLO +from OCR_part.ocr_interface import ocr_predict +#from CRNN_part.crnn_interface import crnn_predict(不使用CRNN) + +class CameraThread(QThread): + """摄像头线程类""" + frame_ready = pyqtSignal(np.ndarray) + + def __init__(self): + super().__init__() + self.camera = None + self.running = False + + def start_camera(self): + """启动摄像头""" + self.camera = cv2.VideoCapture(0) + if self.camera.isOpened(): + self.running = True + self.start() + return True + return False + + def stop_camera(self): + """停止摄像头""" + self.running = False + if self.camera: + self.camera.release() + self.quit() + self.wait() + + def run(self): + """线程运行函数""" + while self.running: + if self.camera and self.camera.isOpened(): + ret, frame = self.camera.read() + if ret: + self.frame_ready.emit(frame) + self.msleep(30) # 约30fps + +class LicensePlateWidget(QWidget): + """单个车牌结果显示组件""" + + def __init__(self, plate_id, class_name, corrected_image, plate_number): + super().__init__() + self.plate_id = plate_id + self.init_ui(class_name, corrected_image, plate_number) + + def init_ui(self, class_name, corrected_image, plate_number): + layout = QHBoxLayout() + layout.setContentsMargins(10, 5, 10, 5) + + # 车牌类型标签 + type_label = QLabel(class_name) + type_label.setFixedWidth(60) + type_label.setAlignment(Qt.AlignCenter) + type_label.setStyleSheet( + "QLabel { " + "background-color: #4CAF50 if class_name == '绿牌' else #2196F3; " + "color: white; " + "border-radius: 5px; " + "padding: 5px; " + "font-weight: bold; " + "}" + ) + if class_name == '绿牌': + type_label.setStyleSheet( + "QLabel { " + "background-color: #4CAF50; " + "color: white; " + "border-radius: 5px; " + "padding: 5px; " + "font-weight: bold; " + "}" + ) + else: + type_label.setStyleSheet( + "QLabel { " + "background-color: #2196F3; " + "color: white; " + "border-radius: 5px; " + "padding: 5px; " + "font-weight: bold; " + "}" + ) + + # 矫正后的车牌图像 + image_label = QLabel() + image_label.setFixedSize(120, 40) + image_label.setStyleSheet("border: 1px solid #ddd; background-color: white;") + + if corrected_image is not None: + # 转换numpy数组为QPixmap + h, w = corrected_image.shape[:2] + if len(corrected_image.shape) == 3: + bytes_per_line = 3 * w + q_image = QImage(corrected_image.data, w, h, bytes_per_line, QImage.Format_RGB888).rgbSwapped() + else: + bytes_per_line = w + q_image = QImage(corrected_image.data, w, h, bytes_per_line, QImage.Format_Grayscale8) + + pixmap = QPixmap.fromImage(q_image) + scaled_pixmap = pixmap.scaled(120, 40, Qt.KeepAspectRatio, Qt.SmoothTransformation) + image_label.setPixmap(scaled_pixmap) + else: + image_label.setText("车牌未完全\n进入摄像头") + image_label.setAlignment(Qt.AlignCenter) + image_label.setStyleSheet("border: 1px solid #ddd; background-color: #f5f5f5; color: #666;") + + # 车牌号标签 + number_label = QLabel(plate_number) + number_label.setFixedWidth(150) + number_label.setAlignment(Qt.AlignCenter) + number_label.setStyleSheet( + "QLabel { " + "border: 1px solid #ddd; " + "background-color: white; " + "padding: 8px; " + "font-family: 'Courier New'; " + "font-size: 14px; " + "font-weight: bold; " + "}" + ) + + layout.addWidget(type_label) + layout.addWidget(image_label) + layout.addWidget(number_label) + layout.addStretch() + + self.setLayout(layout) + self.setStyleSheet( + "QWidget { " + "background-color: white; " + "border: 1px solid #e0e0e0; " + "border-radius: 8px; " + "margin: 2px; " + "}" + ) + +class MainWindow(QMainWindow): + """主窗口类""" + + def __init__(self): + super().__init__() + self.detector = None + self.camera_thread = None + self.current_frame = None + self.detections = [] + + self.init_ui() + self.init_detector() + self.init_camera() + + def init_ui(self): + """初始化用户界面""" + self.setWindowTitle("车牌识别系统") + self.setGeometry(100, 100, 1200, 800) + + # 创建中央widget + central_widget = QWidget() + self.setCentralWidget(central_widget) + + # 创建主布局 + main_layout = QHBoxLayout(central_widget) + + # 左侧摄像头显示区域 + left_frame = QFrame() + left_frame.setFrameStyle(QFrame.StyledPanel) + left_frame.setStyleSheet("QFrame { background-color: #f0f0f0; border: 2px solid #ddd; }") + left_layout = QVBoxLayout(left_frame) + + # 摄像头显示标签 + self.camera_label = QLabel() + self.camera_label.setMinimumSize(640, 480) + self.camera_label.setStyleSheet("QLabel { background-color: black; border: 1px solid #ccc; }") + self.camera_label.setAlignment(Qt.AlignCenter) + self.camera_label.setText("摄像头未启动") + self.camera_label.setScaledContents(True) + + # 控制按钮 + button_layout = QHBoxLayout() + self.start_button = QPushButton("启动摄像头") + self.stop_button = QPushButton("停止摄像头") + self.start_button.clicked.connect(self.start_camera) + self.stop_button.clicked.connect(self.stop_camera) + self.stop_button.setEnabled(False) + + button_layout.addWidget(self.start_button) + button_layout.addWidget(self.stop_button) + button_layout.addStretch() + + left_layout.addWidget(self.camera_label) + left_layout.addLayout(button_layout) + + # 右侧结果显示区域 + right_frame = QFrame() + right_frame.setFrameStyle(QFrame.StyledPanel) + right_frame.setFixedWidth(400) + right_frame.setStyleSheet("QFrame { background-color: #fafafa; border: 2px solid #ddd; }") + right_layout = QVBoxLayout(right_frame) + + # 标题 + title_label = QLabel("检测结果") + title_label.setAlignment(Qt.AlignCenter) + title_label.setFont(QFont("Arial", 16, QFont.Bold)) + title_label.setStyleSheet("QLabel { color: #333; padding: 10px; }") + + # 车牌数量显示 + self.count_label = QLabel("识别到的车牌数量: 0") + self.count_label.setAlignment(Qt.AlignCenter) + self.count_label.setFont(QFont("Arial", 12)) + self.count_label.setStyleSheet( + "QLabel { " + "background-color: #e3f2fd; " + "border: 1px solid #2196f3; " + "border-radius: 5px; " + "padding: 8px; " + "color: #1976d2; " + "font-weight: bold; " + "}" + ) + + # 滚动区域用于显示车牌结果 + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + scroll_area.setStyleSheet("QScrollArea { border: none; background-color: transparent; }") + + self.results_widget = QWidget() + self.results_layout = QVBoxLayout(self.results_widget) + self.results_layout.setAlignment(Qt.AlignTop) + + scroll_area.setWidget(self.results_widget) + + right_layout.addWidget(title_label) + right_layout.addWidget(self.count_label) + right_layout.addWidget(scroll_area) + + # 添加到主布局 + main_layout.addWidget(left_frame, 2) + main_layout.addWidget(right_frame, 1) + + # 设置样式 + self.setStyleSheet(""" + QMainWindow { + background-color: #f5f5f5; + } + QPushButton { + background-color: #2196F3; + color: white; + border: none; + padding: 8px 16px; + border-radius: 4px; + font-weight: bold; + } + QPushButton:hover { + background-color: #1976D2; + } + QPushButton:pressed { + background-color: #0D47A1; + } + QPushButton:disabled { + background-color: #cccccc; + color: #666666; + } + """) + + def init_detector(self): + """初始化检测器""" + model_path = os.path.join(os.path.dirname(__file__), "yolopart", "yolo11s-pose42.pt") + self.detector = LicensePlateYOLO(model_path) + + def init_camera(self): + """初始化摄像头线程""" + self.camera_thread = CameraThread() + self.camera_thread.frame_ready.connect(self.process_frame) + + def start_camera(self): + """启动摄像头""" + if self.camera_thread.start_camera(): + self.start_button.setEnabled(False) + self.stop_button.setEnabled(True) + self.camera_label.setText("摄像头启动中...") + else: + self.camera_label.setText("摄像头启动失败") + + def stop_camera(self): + """停止摄像头""" + self.camera_thread.stop_camera() + self.start_button.setEnabled(True) + self.stop_button.setEnabled(False) + self.camera_label.setText("摄像头已停止") + self.camera_label.clear() + + def process_frame(self, frame): + """处理摄像头帧""" + self.current_frame = frame.copy() + + # 进行车牌检测 + self.detections = self.detector.detect_license_plates(frame) + + # 在图像上绘制检测结果 + display_frame = self.draw_detections(frame.copy()) + + # 转换为Qt格式并显示 + self.display_frame(display_frame) + + # 更新右侧结果显示 + self.update_results_display() + + def draw_detections(self, frame): + """在图像上绘制检测结果""" + return self.detector.draw_detections(frame, self.detections) + + def display_frame(self, frame): + """显示帧到界面""" + rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + h, w, ch = rgb_frame.shape + bytes_per_line = ch * w + qt_image = QImage(rgb_frame.data, w, h, bytes_per_line, QImage.Format_RGB888) + + pixmap = QPixmap.fromImage(qt_image) + scaled_pixmap = pixmap.scaled(self.camera_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation) + self.camera_label.setPixmap(scaled_pixmap) + + def update_results_display(self): + """更新右侧结果显示""" + # 更新车牌数量 + count = len(self.detections) + self.count_label.setText(f"识别到的车牌数量: {count}") + + # 清除之前的结果 + for i in reversed(range(self.results_layout.count())): + child = self.results_layout.itemAt(i).widget() + if child: + child.setParent(None) + + # 添加新的结果 + for i, detection in enumerate(self.detections): + # 矫正车牌图像 + corrected_image = self.correct_license_plate(detection) + + # 获取车牌号(占位) + plate_number = self.recognize_plate_number(corrected_image) + + # 创建车牌显示组件 + plate_widget = LicensePlateWidget( + i + 1, + detection['class_name'], + corrected_image, + plate_number + ) + + self.results_layout.addWidget(plate_widget) + + def correct_license_plate(self, detection): + """矫正车牌图像""" + if self.current_frame is None: + return None + + # 检查是否为不完整检测 + if detection.get('incomplete', False): + return None + + # 使用检测器的矫正方法 + return self.detector.correct_license_plate( + self.current_frame, + detection['keypoints'] + ) + + def recognize_plate_number(self, corrected_image): + """识别车牌号""" + if corrected_image is None: + return "识别失败" + + try: + # 使用OCR接口进行识别 + # 可以根据需要切换为CRNN: crnn_predict(corrected_image) + result = ocr_predict(corrected_image) + + # 将字符列表转换为字符串 + if isinstance(result, list) and len(result) >= 7: + return ''.join(result[:7]) + else: + return "识别失败" + except Exception as e: + print(f"车牌号识别失败: {e}") + return "识别失败" + + def closeEvent(self, event): + """窗口关闭事件""" + if self.camera_thread: + self.camera_thread.stop_camera() + event.accept() + +def main(): + app = QApplication(sys.argv) + window = MainWindow() + window.show() + sys.exit(app.exec_()) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..57cb5ac --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +# 车牌识别系统依赖包 + +# 深度学习和计算机视觉 +ultralytics>=8.0.0 +opencv-python>=4.5.0 +numpy>=1.21.0 + +# PyQt5界面 +PyQt5>=5.15.0 + +# 图像处理 +Pillow>=8.0.0 + +# 可选:如果需要GPU加速 +# torch>=1.9.0 +# torchvision>=0.10.0 + +# 可选:如果需要其他功能 +# matplotlib>=3.3.0 # 用于调试和可视化 +# scipy>=1.7.0 # 科学计算 \ No newline at end of file diff --git a/yolopart/.idea/.gitignore b/yolopart/.idea/.gitignore deleted file mode 100644 index 35410ca..0000000 --- a/yolopart/.idea/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -# 默认忽略的文件 -/shelf/ -/workspace.xml -# 基于编辑器的 HTTP 客户端请求 -/httpRequests/ -# Datasource local storage ignored files -/dataSources/ -/dataSources.local.xml diff --git a/yolopart/.idea/inspectionProfiles/profiles_settings.xml b/yolopart/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index 105ce2d..0000000 --- a/yolopart/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/yolopart/.idea/misc.xml b/yolopart/.idea/misc.xml deleted file mode 100644 index 9475f0b..0000000 --- a/yolopart/.idea/misc.xml +++ /dev/null @@ -1,7 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/yolopart/.idea/modules.xml b/yolopart/.idea/modules.xml deleted file mode 100644 index 526c1b4..0000000 --- a/yolopart/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/yolopart/.idea/vcs.xml b/yolopart/.idea/vcs.xml deleted file mode 100644 index 6c0b863..0000000 --- a/yolopart/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/yolopart/.idea/yolopart.iml b/yolopart/.idea/yolopart.iml deleted file mode 100644 index 2328911..0000000 --- a/yolopart/.idea/yolopart.iml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - - - - - - \ No newline at end of file diff --git a/yolopart/README.md b/yolopart/README.md deleted file mode 100644 index 7aa9c6a..0000000 --- a/yolopart/README.md +++ /dev/null @@ -1,177 +0,0 @@ -# 车牌检测系统 - -基于YOLO11s模型的实时车牌检测应用,支持摄像头和视频文件输入,具备GPU加速和车牌识别接口。 - -## 功能特性 - -- ✅ **实时车牌检测**: 基于YOLO11s ONNX模型 -- ✅ **GPU加速**: 支持CUDA GPU推理加速 -- ✅ **多视频源**: 支持摄像头和视频文件切换 -- ✅ **实时显示**: 显示检测框、置信度和实时FPS -- ✅ **图像切割**: 自动切割检测到的车牌区域 -- ✅ **识别接口**: 预留车牌号识别接口,可接入OCR模型 -- ✅ **友好界面**: 基于PyQt5的现代化用户界面 - -## 系统要求 - -- Python 3.7+ -- Windows/Linux/macOS -- 摄像头(可选) -- NVIDIA GPU(可选,用于加速) - -## 安装依赖 - -```bash -# 安装基础依赖 -pip install -r requirements.txt - -# 如果需要CPU版本的onnxruntime -pip uninstall onnxruntime-gpu -pip install onnxruntime - -# 可选:安装车牌识别依赖 -# PaddleOCR -pip install paddlepaddle paddleocr - -# 或者 Tesseract -pip install pytesseract -``` - -## 使用方法 - -### 1. 准备模型文件 - -确保项目根目录下有以下文件: -- `last.onnx`: YOLO11s车牌检测模型 -- `video.mp4`: 测试视频文件(可选) - -### 2. 运行应用 - -```bash -python main.py -``` - -### 3. 界面操作 - -- **开始检测**: 点击"开始检测"按钮启动实时检测 -- **切换视频源**: 勾选/取消"使用摄像头"切换视频源 -- **启用检测**: 勾选/取消"启用检测"开关检测功能 -- **查看结果**: 右侧面板显示检测信息和车牌识别结果 - -## 项目结构 - -``` -yolopart/ -├── main.py # 主程序入口 -├── requirements.txt # 依赖包列表 -├── README.md # 项目说明 -├── last.onnx # YOLO11s模型文件 -├── video.mp4 # 测试视频文件 -├── ui/ # 用户界面模块 -│ ├── __init__.py -│ ├── main_window.py # 主窗口 -│ └── video_widget.py # 视频显示组件 -├── models/ # 模型模块 -│ ├── __init__.py -│ ├── yolo_detector.py # YOLO检测器 -│ └── plate_recognizer.py # 车牌识别接口 -└── utils/ # 工具模块 - ├── __init__.py - └── video_capture.py # 视频捕获管理 -``` - -## 核心功能说明 - -### YOLO检测器 (`models/yolo_detector.py`) - -- 支持ONNX格式的YOLO11s模型 -- 自动GPU/CPU推理选择 -- 640x640输入尺寸 -- NMS后处理 -- 检测框绘制和车牌切割 - -### 视频捕获 (`utils/video_capture.py`) - -- 摄像头自动检测和配置 -- 视频文件循环播放 -- 实时FPS计算和显示 -- 线程安全的帧获取 - -### 车牌识别接口 (`models/plate_recognizer.py`) - -提供了多种识别器实现: -- `MockPlateRecognizer`: 模拟识别器(用于测试) -- `PaddleOCRRecognizer`: PaddleOCR识别器 -- `TesseractRecognizer`: Tesseract识别器 - -可通过`PlateRecognizerManager`轻松切换不同的识别引擎。 - -## 配置说明 - -### 检测参数调整 - -在`models/yolo_detector.py`中可以调整: -- `conf_threshold`: 置信度阈值(默认0.5) -- `nms_threshold`: NMS阈值(默认0.4) -- `input_size`: 输入尺寸(默认640x640) - -### 视频参数调整 - -在`utils/video_capture.py`中可以调整: -- 摄像头分辨率和帧率 -- FPS计算窗口大小 -- 视频文件路径 - -## 扩展开发 - -### 添加新的车牌识别器 - -1. 继承`PlateRecognizerInterface`基类 -2. 实现`recognize`和`batch_recognize`方法 -3. 在`PlateRecognizerManager`中注册新识别器 - -### 添加新功能 - -- 检测结果保存 -- 车牌数据库管理 -- 网络接口API -- 多摄像头支持 - -## 故障排除 - -### 常见问题 - -1. **模型加载失败** - - 检查`last.onnx`文件是否存在 - - 确认onnxruntime版本兼容性 - -2. **摄像头无法打开** - - 检查摄像头是否被其他程序占用 - - 尝试不同的摄像头索引 - -3. **GPU加速不生效** - - 确认安装了`onnxruntime-gpu` - - 检查CUDA环境配置 - -4. **车牌识别失败** - - 检查OCR依赖是否正确安装 - - 尝试切换不同的识别器 - -### 性能优化 - -- 使用GPU加速推理 -- 调整检测阈值减少误检 -- 优化图像预处理流程 -- 使用多线程处理 - -## 许可证 - -本项目仅供学习和研究使用。 - -## 更新日志 - -### v1.0.0 -- 初始版本发布 -- 支持YOLO11s车牌检测 -- 实现基础UI界面 -- 预留车牌识别接口 \ No newline at end of file diff --git a/yolopart/detector.py b/yolopart/detector.py new file mode 100644 index 0000000..b95fd80 --- /dev/null +++ b/yolopart/detector.py @@ -0,0 +1,275 @@ +import cv2 +import numpy as np +from ultralytics import YOLO +import os + +class LicensePlateYOLO: + """ + 车牌YOLO检测器类 + 负责加载YOLO pose模型并进行车牌检测和角点提取 + """ + + def __init__(self, model_path=None): + """ + 初始化YOLO检测器 + + 参数: + model_path: 模型文件路径,如果为None则使用默认路径 + """ + self.model = None + self.model_path = model_path or self._get_default_model_path() + self.class_names = {0: '蓝牌', 1: '绿牌'} + self.load_model() + + def _get_default_model_path(self): + """获取默认模型路径""" + current_dir = os.path.dirname(__file__) + return os.path.join(current_dir, "yolo11s-pose42.pt") + + def load_model(self): + """ + 加载YOLO pose模型 + + 返回: + bool: 加载是否成功 + """ + try: + if os.path.exists(self.model_path): + self.model = YOLO(self.model_path) + print(f"YOLO模型加载成功: {self.model_path}") + return True + else: + print(f"模型文件不存在: {self.model_path}") + return False + except Exception as e: + print(f"YOLO模型加载失败: {e}") + return False + + def detect_license_plates(self, image, conf_threshold=0.5): + """ + 检测图像中的车牌 + + 参数: + image: 输入图像 (numpy数组) + conf_threshold: 置信度阈值 + + 返回: + list: 检测结果列表,每个元素包含: + - box: 边界框坐标 [x1, y1, x2, y2] + - keypoints: 四个角点坐标 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] + - confidence: 置信度 + - class_id: 类别ID (0=蓝牌, 1=绿牌) + - class_name: 类别名称 + """ + if self.model is None: + print("模型未加载") + return [] + + try: + # 进行推理 + results = self.model(image, conf=conf_threshold, verbose=False) + detections = [] + + for result in results: + # 检查是否有检测结果 + if result.boxes is None or result.keypoints is None: + continue + + # 提取检测信息 + boxes = result.boxes.xyxy.cpu().numpy() # 边界框 + keypoints = result.keypoints.xy.cpu().numpy() # 关键点 + confidences = result.boxes.conf.cpu().numpy() # 置信度 + classes = result.boxes.cls.cpu().numpy() # 类别 + + # 处理每个检测结果 + for i in range(len(boxes)): + # 检查关键点数量是否为4个 + if len(keypoints[i]) == 4: + class_id = int(classes[i]) + detection = { + 'box': boxes[i], + 'keypoints': keypoints[i], + 'confidence': confidences[i], + 'class_id': class_id, + 'class_name': self.class_names.get(class_id, '未知') + } + detections.append(detection) + else: + # 关键点不足4个,记录但标记为不完整 + class_id = int(classes[i]) + detection = { + 'box': boxes[i], + 'keypoints': keypoints[i] if len(keypoints[i]) > 0 else [], + 'confidence': confidences[i], + 'class_id': class_id, + 'class_name': self.class_names.get(class_id, '未知'), + 'incomplete': True # 标记为不完整 + } + detections.append(detection) + + return detections + + except Exception as e: + print(f"检测过程中出错: {e}") + return [] + + def draw_detections(self, image, detections): + """ + 在图像上绘制检测结果 + + 参数: + image: 输入图像 + detections: 检测结果列表 + + 返回: + numpy.ndarray: 绘制了检测结果的图像 + """ + draw_image = image.copy() + + for i, detection in enumerate(detections): + box = detection['box'] + keypoints = detection['keypoints'] + class_name = detection['class_name'] + confidence = detection['confidence'] + incomplete = detection.get('incomplete', False) + + # 绘制边界框 + x1, y1, x2, y2 = map(int, box) + + # 根据车牌类型选择颜色 + if class_name == '绿牌': + box_color = (0, 255, 0) # 绿色 + elif class_name == '蓝牌': + box_color = (255, 0, 0) # 蓝色 + else: + box_color = (128, 128, 128) # 灰色 + + cv2.rectangle(draw_image, (x1, y1), (x2, y2), box_color, 2) + + # 绘制标签 + label = f"{class_name} {confidence:.2f}" + if incomplete: + label += " (不完整)" + + # 计算文本大小和位置 + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.6 + thickness = 2 + (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness) + + # 绘制文本背景 + cv2.rectangle(draw_image, (x1, y1 - text_height - 10), + (x1 + text_width, y1), box_color, -1) + + # 绘制文本 + cv2.putText(draw_image, label, (x1, y1 - 5), + font, font_scale, (255, 255, 255), thickness) + + # 绘制关键点和连线 + if len(keypoints) >= 4 and not incomplete: + # 四个角点完整,用黄色连线 + points = [(int(kp[0]), int(kp[1])) for kp in keypoints[:4]] + + # 绘制关键点 + for point in points: + cv2.circle(draw_image, point, 5, (0, 255, 255), -1) + + # 连接关键点形成四边形(按顺序连接) + # 假设关键点顺序为: right_bottom, left_bottom, left_top, right_top + for j in range(4): + cv2.line(draw_image, points[j], points[(j+1)%4], (0, 255, 255), 2) + + elif len(keypoints) > 0: + # 关键点不完整,用红色标记现有点 + for kp in keypoints: + point = (int(kp[0]), int(kp[1])) + cv2.circle(draw_image, point, 5, (0, 0, 255), -1) + + return draw_image + + def correct_license_plate(self, image, keypoints, target_size=(240, 80)): + """ + 使用四个角点对车牌进行透视变换矫正 + + 参数: + image: 原始图像 + keypoints: 四个角点坐标 + target_size: 目标尺寸 (width, height) + + 返回: + numpy.ndarray: 矫正后的车牌图像,如果失败返回None + """ + if len(keypoints) != 4: + return None + + try: + # 将关键点转换为numpy数组 + src_points = np.array(keypoints, dtype=np.float32) + + # 定义目标矩形的四个角点 + # 假设关键点顺序为: right_bottom, left_bottom, left_top, right_top + # 重新排序为标准顺序: left_top, right_top, right_bottom, left_bottom + width, height = target_size + dst_points = np.array([ + [0, 0], # left_top + [width, 0], # right_top + [width, height], # right_bottom + [0, height] # left_bottom + ], dtype=np.float32) + + # 重新排序源点以匹配目标点 + # 原顺序: right_bottom, left_bottom, left_top, right_top + # 目标顺序: left_top, right_top, right_bottom, left_bottom + reordered_src = np.array([ + src_points[2], # left_top + src_points[3], # right_top + src_points[0], # right_bottom + src_points[1] # left_bottom + ], dtype=np.float32) + + # 计算透视变换矩阵 + matrix = cv2.getPerspectiveTransform(reordered_src, dst_points) + + # 应用透视变换 + corrected = cv2.warpPerspective(image, matrix, target_size) + + return corrected + + except Exception as e: + print(f"车牌矫正失败: {e}") + return None + + def get_model_info(self): + """ + 获取模型信息 + + 返回: + dict: 模型信息字典 + """ + if self.model is None: + return {"status": "未加载", "path": self.model_path} + + return { + "status": "已加载", + "path": self.model_path, + "model_type": "YOLO11 Pose", + "classes": self.class_names + } + +def initialize_yolo_detector(model_path=None): + """ + 初始化YOLO检测器的便捷函数 + + 参数: + model_path: 模型文件路径 + + 返回: + LicensePlateYOLO: 初始化后的检测器实例 + """ + detector = LicensePlateYOLO(model_path) + return detector + +if __name__ == "__main__": + # 测试代码 + detector = initialize_yolo_detector() + print("检测器信息:", detector.get_model_info()) \ No newline at end of file diff --git a/yolopart/main.py b/yolopart/main.py deleted file mode 100644 index 9549df0..0000000 --- a/yolopart/main.py +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -车牌检测系统主程序 -基于YOLO11s模型的实时车牌检测应用 -""" - -import sys -import os -from PyQt5.QtWidgets import QApplication -from PyQt5.QtCore import Qt -from ui.main_window import MainWindow - -def main(): - """主函数""" - # 创建QApplication实例 - app = QApplication(sys.argv) - app.setAttribute(Qt.AA_EnableHighDpiScaling, True) - app.setAttribute(Qt.AA_UseHighDpiPixmaps, True) - - # 设置应用信息 - app.setApplicationName("车牌检测系统") - app.setApplicationVersion("1.0.0") - app.setOrganizationName("License Plate Detection") - - # 创建主窗口 - main_window = MainWindow() - main_window.show() - - # 运行应用 - sys.exit(app.exec_()) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/yolopart/models/__init__.py b/yolopart/models/__init__.py deleted file mode 100644 index 2138ee8..0000000 --- a/yolopart/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# 模型模块初始化文件 \ No newline at end of file diff --git a/yolopart/models/plate_recognizer.py b/yolopart/models/plate_recognizer.py deleted file mode 100644 index a4d711e..0000000 --- a/yolopart/models/plate_recognizer.py +++ /dev/null @@ -1,490 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -车牌识别接口模块 -预留接口,可接入各种OCR模型进行车牌号识别 -""" - -import cv2 -import numpy as np -from typing import List, Optional, Dict, Any -from abc import ABC, abstractmethod - -class PlateRecognizerInterface(ABC): - """车牌识别接口基类""" - - @abstractmethod - def recognize(self, plate_image: np.ndarray) -> Dict[str, Any]: - """ - 识别车牌号 - - Args: - plate_image: 车牌图像 (BGR格式) - - Returns: - 识别结果字典,包含: - { - 'text': str, # 识别的车牌号 - 'confidence': float, # 置信度 (0-1) - 'success': bool # 是否识别成功 - } - """ - pass - - @abstractmethod - def batch_recognize(self, plate_images: List[np.ndarray]) -> List[Dict[str, Any]]: - """ - 批量识别车牌号 - - Args: - plate_images: 车牌图像列表 - - Returns: - 识别结果列表 - """ - pass - -class MockPlateRecognizer(PlateRecognizerInterface): - """模拟车牌识别器(用于测试)""" - - def __init__(self): - self.mock_plates = [ - "京A12345", "沪B67890", "粤C11111", "川D22222", - "鲁E33333", "苏F44444", "浙G55555", "闽H66666" - ] - self.call_count = 0 - - def recognize(self, plate_image: np.ndarray) -> Dict[str, Any]: - """ - 模拟识别单个车牌 - - Args: - plate_image: 车牌图像 - - Returns: - 模拟识别结果 - """ - # 模拟处理时间 - import time - time.sleep(0.01) # 10ms模拟处理时间 - - # 简单的图像质量检查 - if plate_image is None or plate_image.size == 0: - return { - 'text': '', - 'confidence': 0.0, - 'success': False - } - - # 检查图像尺寸 - height, width = plate_image.shape[:2] - if width < 50 or height < 20: - return { - 'text': '', - 'confidence': 0.3, - 'success': False - } - - # 模拟识别结果 - plate_text = self.mock_plates[self.call_count % len(self.mock_plates)] - confidence = 0.85 + (self.call_count % 10) * 0.01 # 0.85-0.94 - - self.call_count += 1 - - return { - 'text': plate_text, - 'confidence': confidence, - 'success': True - } - - def batch_recognize(self, plate_images: List[np.ndarray]) -> List[Dict[str, Any]]: - """ - 批量识别车牌 - - Args: - plate_images: 车牌图像列表 - - Returns: - 识别结果列表 - """ - results = [] - for plate_image in plate_images: - result = self.recognize(plate_image) - results.append(result) - return results - -class PaddleOCRRecognizer(PlateRecognizerInterface): - """PaddleOCR车牌识别器(示例实现)""" - - def __init__(self, use_gpu: bool = True): - """ - 初始化PaddleOCR识别器 - - Args: - use_gpu: 是否使用GPU - """ - self.use_gpu = use_gpu - self.ocr = None - self._init_ocr() - - def _init_ocr(self): - """初始化OCR模型""" - try: - # 这里可以接入PaddleOCR - # from paddleocr import PaddleOCR - # self.ocr = PaddleOCR(use_angle_cls=True, lang='ch', use_gpu=self.use_gpu) - print("PaddleOCR初始化完成(示例代码,需要安装PaddleOCR)") - except ImportError: - print("PaddleOCR未安装,使用模拟识别器") - self.ocr = None - - def recognize(self, plate_image: np.ndarray) -> Dict[str, Any]: - """ - 使用PaddleOCR识别车牌 - - Args: - plate_image: 车牌图像 - - Returns: - 识别结果 - """ - if self.ocr is None: - # 回退到模拟识别 - mock_recognizer = MockPlateRecognizer() - return mock_recognizer.recognize(plate_image) - - try: - # 使用PaddleOCR进行识别 - results = self.ocr.ocr(plate_image, cls=True) - - if results and len(results) > 0 and results[0]: - # 提取文本和置信度 - text_results = [] - for line in results[0]: - text = line[1][0] - confidence = line[1][1] - text_results.append((text, confidence)) - - # 选择置信度最高的结果 - if text_results: - best_result = max(text_results, key=lambda x: x[1]) - return { - 'text': best_result[0], - 'confidence': best_result[1], - 'success': True - } - - except Exception as e: - print(f"PaddleOCR识别失败: {e}") - - return { - 'text': '', - 'confidence': 0.0, - 'success': False - } - - def batch_recognize(self, plate_images: List[np.ndarray]) -> List[Dict[str, Any]]: - """ - 批量识别 - - Args: - plate_images: 车牌图像列表 - - Returns: - 识别结果列表 - """ - results = [] - for plate_image in plate_images: - result = self.recognize(plate_image) - results.append(result) - return results - -class TesseractRecognizer(PlateRecognizerInterface): - """Tesseract车牌识别器(示例实现)""" - - def __init__(self, lang: str = 'chi_sim+eng'): - """ - 初始化Tesseract识别器 - - Args: - lang: 识别语言 - """ - self.lang = lang - self.tesseract_available = self._check_tesseract() - - def _check_tesseract(self) -> bool: - """检查Tesseract是否可用""" - try: - import pytesseract - return True - except ImportError: - print("pytesseract未安装,使用模拟识别器") - return False - - def recognize(self, plate_image: np.ndarray) -> Dict[str, Any]: - """ - 使用Tesseract识别车牌 - - Args: - plate_image: 车牌图像 - - Returns: - 识别结果 - """ - if not self.tesseract_available: - # 回退到模拟识别 - mock_recognizer = MockPlateRecognizer() - return mock_recognizer.recognize(plate_image) - - try: - import pytesseract - - # 图像预处理 - processed_image = self._preprocess_image(plate_image) - - # 使用Tesseract识别 - text = pytesseract.image_to_string( - processed_image, - lang=self.lang, - config='--psm 8 --oem 3 -c tessedit_char_whitelist=0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ京沪粤川鲁苏浙闽' - ) - - # 清理识别结果 - text = text.strip().replace(' ', '').replace('\n', '') - - if text and len(text) >= 5: # 车牌号至少5位 - return { - 'text': text, - 'confidence': 0.8, # Tesseract不直接提供置信度 - 'success': True - } - - except Exception as e: - print(f"Tesseract识别失败: {e}") - - return { - 'text': '', - 'confidence': 0.0, - 'success': False - } - - def _preprocess_image(self, image: np.ndarray) -> np.ndarray: - """图像预处理""" - # 转换为灰度图 - if len(image.shape) == 3: - gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) - else: - gray = image - - # 调整尺寸 - height, width = gray.shape - if width < 200: - scale = 200 / width - new_width = int(width * scale) - new_height = int(height * scale) - gray = cv2.resize(gray, (new_width, new_height)) - - # 二值化 - _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) - - return binary - - def batch_recognize(self, plate_images: List[np.ndarray]) -> List[Dict[str, Any]]: - """ - 批量识别 - - Args: - plate_images: 车牌图像列表 - - Returns: - 识别结果列表 - """ - results = [] - for plate_image in plate_images: - result = self.recognize(plate_image) - results.append(result) - return results - -class PlateRecognizerManager: - """车牌识别管理器""" - - def __init__(self, recognizer_type: str = 'mock'): - """ - 初始化识别管理器 - - Args: - recognizer_type: 识别器类型 ('mock', 'paddleocr', 'tesseract') - """ - self.recognizer_type = recognizer_type - self.recognizer = self._create_recognizer(recognizer_type) - - def _create_recognizer(self, recognizer_type: str) -> PlateRecognizerInterface: - """创建识别器""" - if recognizer_type == 'mock': - return MockPlateRecognizer() - elif recognizer_type == 'paddleocr': - return PaddleOCRRecognizer() - elif recognizer_type == 'tesseract': - return TesseractRecognizer() - else: - print(f"未知的识别器类型: {recognizer_type},使用模拟识别器") - return MockPlateRecognizer() - - def recognize_plates(self, plate_images: List[np.ndarray]) -> List[Dict[str, Any]]: - """ - 识别车牌列表 - - Args: - plate_images: 车牌图像列表 - - Returns: - 识别结果列表 - """ - if not plate_images: - return [] - - return self.recognizer.batch_recognize(plate_images) - - def switch_recognizer(self, recognizer_type: str): - """ - 切换识别器 - - Args: - recognizer_type: 新的识别器类型 - """ - self.recognizer_type = recognizer_type - self.recognizer = self._create_recognizer(recognizer_type) - print(f"已切换到识别器: {recognizer_type}") - - def get_recognizer_info(self) -> Dict[str, Any]: - """ - 获取识别器信息 - - Returns: - 识别器信息 - """ - return { - 'type': self.recognizer_type, - 'class': self.recognizer.__class__.__name__ - } - - def preprocess_blue_plate(self, plate_image: np.ndarray, original_image: np.ndarray, bbox: List[int]) -> np.ndarray: - """ - 蓝色车牌预处理:倾斜矫正 - - Args: - plate_image: 切割后的车牌图像 - original_image: 原始图像 - bbox: 边界框坐标 [x1, y1, x2, y2] - - Returns: - 矫正后的车牌图像 - """ - try: - # 从原图中提取车牌区域 - x1, y1, x2, y2 = bbox - roi = original_image[y1:y2, x1:x2] - - # 获取蓝色车牌的二值图像 - bin_img = self._get_blue_img_bin(roi) - - # 倾斜矫正 - corrected_img = self._deskew_plate(bin_img, roi) - - return corrected_img - except Exception as e: - print(f"蓝色车牌预处理失败: {e}") - return plate_image - - def _get_blue_img_bin(self, img: np.ndarray) -> np.ndarray: - """ - 获取蓝色车牌的二值图像 - """ - # 掩膜:BGR通道,若像素B分量在 100~255 且 G分量在 0~190 且 R分量在 0~140 置255(白色),否则置0(黑色) - mask_bgr = cv2.inRange(img, (100, 0, 0), (255, 190, 140)) - - # 转换成 HSV 颜色空间 - img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) - h, s, v = cv2.split(img_hsv) # 分离通道 色调(H),饱和度(S),明度(V) - mask_s = cv2.inRange(s, 80, 255) # 取饱和度通道进行掩膜得到二值图像 - - # 与操作,两个二值图像都为白色才保留,否则置黑 - rgbs = mask_bgr & mask_s - - # 核的横向分量大,使车牌数字尽量连在一起 - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (15, 3)) - img_rgbs_dilate = cv2.dilate(rgbs, kernel, 3) # 膨胀,减小车牌空洞 - - return img_rgbs_dilate - - def _order_points(self, pts: np.ndarray) -> np.ndarray: - """ - 将四点按 左上、右上、右下、左下 排序 - """ - rect = np.zeros((4, 2), dtype="float32") - s = pts.sum(axis=1) - rect[0] = pts[np.argmin(s)] # 左上 - rect[2] = pts[np.argmax(s)] # 右下 - - diff = np.diff(pts, axis=1) - rect[1] = pts[np.argmin(diff)] # 右上 - rect[3] = pts[np.argmax(diff)] # 左下 - - return rect - - def _deskew_plate(self, bin_img: np.ndarray, original_roi: np.ndarray) -> np.ndarray: - """ - 车牌倾斜矫正 - - Args: - bin_img: 二值图像 - original_roi: 原始ROI区域 - - Returns: - 矫正后的原始图像(未被掩模,但经过旋转和切割) - """ - try: - # 找最大轮廓 - cnts, _ = cv2.findContours(bin_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - if not cnts: - return original_roi - - c = max(cnts, key=cv2.contourArea) - - # 最小外接矩形 - rect = cv2.minAreaRect(c) - box = cv2.boxPoints(rect) - box = np.array(box, dtype="float32") - - # 排序四个点 - pts_src = self._order_points(box) - - # 计算目标矩形宽高 - (tl, tr, br, bl) = pts_src - widthA = np.linalg.norm(br - bl) - widthB = np.linalg.norm(tr - tl) - maxWidth = int(max(widthA, widthB)) - - heightA = np.linalg.norm(tr - br) - heightB = np.linalg.norm(tl - bl) - maxHeight = int(max(heightA, heightB)) - - # 确保尺寸合理 - if maxWidth < 10 or maxHeight < 10: - return original_roi - - # 目标点集合 - pts_dst = np.array([ - [0, 0], - [maxWidth - 1, 0], - [maxWidth - 1, maxHeight - 1], - [0, maxHeight - 1]], dtype="float32") - - # 透视变换 - M = cv2.getPerspectiveTransform(pts_src, pts_dst) - warped = cv2.warpPerspective(original_roi, M, (maxWidth, maxHeight)) - - return warped - except Exception as e: - print(f"车牌矫正失败: {e}") - return original_roi \ No newline at end of file diff --git a/yolopart/models/yolo_detector.py b/yolopart/models/yolo_detector.py deleted file mode 100644 index 2b02960..0000000 --- a/yolopart/models/yolo_detector.py +++ /dev/null @@ -1,368 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -YOLO车牌检测器 -基于ONNX Runtime的YOLO11s模型推理 -""" - -import cv2 -import numpy as np -import onnxruntime as ort -import time -from typing import List, Tuple, Optional - -class YOLODetector: - """YOLO车牌检测器""" - - def __init__(self, model_path: str, conf_threshold: float = 0.25, nms_threshold: float = 0.4): - """ - 初始化YOLO检测器 - - Args: - model_path: ONNX模型文件路径 - conf_threshold: 置信度阈值 - nms_threshold: NMS阈值 - """ - self.model_path = model_path - self.conf_threshold = conf_threshold - self.nms_threshold = nms_threshold - self.input_size = (640, 640) # YOLO11s输入尺寸 - self.use_gpu = False - - # 初始化ONNX Runtime会话 - self._init_session() - - # 获取模型输入输出信息 - self.input_name = self.session.get_inputs()[0].name - self.output_names = [output.name for output in self.session.get_outputs()] - - print(f"YOLO检测器初始化完成") - print(f"模型路径: {model_path}") - print(f"输入尺寸: {self.input_size}") - print(f"GPU加速: {self.use_gpu}") - - def _init_session(self): - """初始化ONNX Runtime会话""" - # 获取可用的providers - available_providers = ort.get_available_providers() - print(f"可用的执行提供者: {available_providers}") - - # 优先使用GPU,如果可用的话 - providers = [] - if 'CUDAExecutionProvider' in available_providers: - providers.append('CUDAExecutionProvider') - self.use_gpu = True - print("检测到CUDA支持,将使用GPU加速") - elif 'TensorrtExecutionProvider' in available_providers: - providers.append('TensorrtExecutionProvider') - self.use_gpu = True - print("检测到TensorRT支持,将使用GPU加速") - else: - self.use_gpu = False - print("未检测到GPU支持,将使用CPU") - - # 添加CPU作为备选 - providers.append('CPUExecutionProvider') - - print(f"使用的执行提供者: {providers}") - - # 创建会话 - session_options = ort.SessionOptions() - session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL - - try: - self.session = ort.InferenceSession( - self.model_path, - sess_options=session_options, - providers=providers - ) - - # 检查实际使用的provider - actual_providers = self.session.get_providers() - print(f"实际使用的执行提供者: {actual_providers}") - - if 'CUDAExecutionProvider' in actual_providers or 'TensorrtExecutionProvider' in actual_providers: - self.use_gpu = True - print("✅ GPU加速已启用") - else: - self.use_gpu = False - print("⚠️ 使用CPU执行") - - except Exception as e: - print(f"模型加载失败: {e}") - raise - - def preprocess(self, image: np.ndarray) -> Tuple[np.ndarray, float, float]: - """ - 图像预处理 - - Args: - image: 输入图像 (BGR格式) - - Returns: - preprocessed_image: 预处理后的图像 - scale_x: X轴缩放比例 - scale_y: Y轴缩放比例 - """ - original_height, original_width = image.shape[:2] - target_width, target_height = self.input_size - - # 计算缩放比例 - scale_x = target_width / original_width - scale_y = target_height / original_height - - # 等比例缩放 - scale = min(scale_x, scale_y) - new_width = int(original_width * scale) - new_height = int(original_height * scale) - - # 缩放图像 - resized_image = cv2.resize(image, (new_width, new_height)) - - # 创建目标尺寸的图像并居中放置 - padded_image = np.full((target_height, target_width, 3), 114, dtype=np.uint8) - - # 计算填充位置 - start_x = (target_width - new_width) // 2 - start_y = (target_height - new_height) // 2 - - padded_image[start_y:start_y + new_height, start_x:start_x + new_width] = resized_image - - # 转换为RGB并归一化 - rgb_image = cv2.cvtColor(padded_image, cv2.COLOR_BGR2RGB) - normalized_image = rgb_image.astype(np.float32) / 255.0 - - # 转换为NCHW格式 - input_tensor = np.transpose(normalized_image, (2, 0, 1)) - input_tensor = np.expand_dims(input_tensor, axis=0) - - return input_tensor, scale, scale - - def postprocess(self, outputs: List[np.ndarray], scale_x: float, scale_y: float, - original_shape: Tuple[int, int]) -> List[dict]: - """ - 后处理检测结果 - - Args: - outputs: 模型输出 - scale_x: X轴缩放比例 - scale_y: Y轴缩放比例 - original_shape: 原始图像尺寸 (height, width) - - Returns: - 检测结果列表 - """ - detections = [] - - if len(outputs) == 0: - return detections - - # 获取输出张量 - output = outputs[0] - - # YOLO11输出格式: [batch, 6, 8400] -> [batch, 8400, 6] - if len(output.shape) == 3: - output = output.transpose(0, 2, 1) - - # 处理每个检测结果 - for detection in output[0]: # 取第一个batch - # 前4个值是边界框坐标,后2个是类别概率 - x_center, y_center, width, height = detection[:4] - class_scores = detection[4:] # 类别概率 [蓝牌概率, 绿牌概率] - - # 获取最高概率的类别 - class_id = np.argmax(class_scores) - confidence = class_scores[class_id] # 使用类别概率作为置信度 - - # 过滤低置信度检测 - if confidence < self.conf_threshold: - continue - - # 转换坐标到原始图像尺寸 - original_height, original_width = original_shape - - # 计算实际缩放比例和偏移 - scale = min(self.input_size[0] / original_width, self.input_size[1] / original_height) - pad_x = (self.input_size[0] - original_width * scale) / 2 - pad_y = (self.input_size[1] - original_height * scale) / 2 - - # 转换坐标 - x_center = (x_center - pad_x) / scale - y_center = (y_center - pad_y) / scale - width = width / scale - height = height / scale - - # 计算边界框 - x1 = int(x_center - width / 2) - y1 = int(y_center - height / 2) - x2 = int(x_center + width / 2) - y2 = int(y_center + height / 2) - - # 确保坐标在图像范围内 - x1 = max(0, min(x1, original_width - 1)) - y1 = max(0, min(y1, original_height - 1)) - x2 = max(0, min(x2, original_width - 1)) - y2 = max(0, min(y2, original_height - 1)) - - # 定义类别名称 - class_names = ['blue_plate', 'green_plate'] # 0: 蓝牌, 1: 绿牌 - class_name = class_names[class_id] if class_id < len(class_names) else 'unknown' - - detections.append({ - 'bbox': [x1, y1, x2, y2], - 'confidence': float(confidence), - 'class_id': int(class_id), - 'class_name': class_name - }) - - # 应用NMS - if detections: - detections = self._apply_nms(detections) - - return detections - - def _apply_nms(self, detections: List[dict]) -> List[dict]: - """ - 应用非极大值抑制 - - Args: - detections: 检测结果列表 - - Returns: - NMS后的检测结果 - """ - if len(detections) == 0: - return detections - - # 提取边界框和置信度 - boxes = np.array([det['bbox'] for det in detections]) - scores = np.array([det['confidence'] for det in detections]) - - # 应用NMS - indices = cv2.dnn.NMSBoxes( - boxes.tolist(), - scores.tolist(), - self.conf_threshold, - self.nms_threshold - ) - - # 返回保留的检测结果 - if len(indices) > 0: - indices = indices.flatten() - return [detections[i] for i in indices] - else: - return [] - - def detect(self, image: np.ndarray) -> List[dict]: - """ - 检测车牌 - - Args: - image: 输入图像 (BGR格式) - - Returns: - 检测结果列表 - """ - try: - # 预处理 - input_tensor, scale_x, scale_y = self.preprocess(image) - - # 推理 - outputs = self.session.run(self.output_names, {self.input_name: input_tensor}) - - # 调试输出 - print(f"模型输出数量: {len(outputs)}") - for i, output in enumerate(outputs): - print(f"输出 {i} 形状: {output.shape}") - print(f"输出 {i} 数据范围: [{output.min():.4f}, {output.max():.4f}]") - - # 后处理 - detections = self.postprocess(outputs, scale_x, scale_y, image.shape[:2]) - print(f"检测到的目标数量: {len(detections)}") - for i, det in enumerate(detections): - print(f"检测 {i}: 类别={det['class_name']}, 置信度={det['confidence']:.3f}") - - return detections - - except Exception as e: - print(f"检测过程出错: {e}") - return [] - - def draw_detections(self, image: np.ndarray, detections: List[dict]) -> np.ndarray: - """ - 在图像上绘制检测结果 - - Args: - image: 输入图像 - detections: 检测结果列表 - - Returns: - 绘制了检测框的图像 - """ - result_image = image.copy() - - for detection in detections: - bbox = detection['bbox'] - confidence = detection['confidence'] - class_id = detection['class_id'] - class_name = detection['class_name'] - - x1, y1, x2, y2 = bbox - - # 根据车牌类型选择颜色 - if class_id == 0: # 蓝牌 - color = (255, 0, 0) # 蓝色 (BGR格式) - plate_type = "Blue Plate" - elif class_id == 1: # 绿牌 - color = (0, 255, 0) # 绿色 (BGR格式) - plate_type = "Green Plate" - else: - color = (0, 255, 255) # 黄色 (BGR格式) - plate_type = "Unknown" - - # 绘制边界框 - cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2) - - # 绘制置信度标签 - label = f"{plate_type}: {confidence:.2f}" - label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)[0] - - # 绘制标签背景 - cv2.rectangle(result_image, - (x1, y1 - label_size[1] - 10), - (x1 + label_size[0], y1), - color, -1) - - # 绘制标签文字 - cv2.putText(result_image, label, - (x1, y1 - 5), - cv2.FONT_HERSHEY_SIMPLEX, 0.6, - (255, 255, 255), 2) - - return result_image - - def crop_plates(self, image: np.ndarray, detections: List[dict]) -> List[np.ndarray]: - """ - 切割车牌图像 - - Args: - image: 原始图像 - detections: 检测结果列表 - - Returns: - 切割后的车牌图像列表 - """ - plate_images = [] - - for detection in detections: - bbox = detection['bbox'] - x1, y1, x2, y2 = bbox - - # 确保坐标有效 - if x2 > x1 and y2 > y1: - # 切割车牌区域 - plate_image = image[y1:y2, x1:x2] - if plate_image.size > 0: - plate_images.append(plate_image) - - return plate_images \ No newline at end of file diff --git a/yolopart/requirements.txt b/yolopart/requirements.txt deleted file mode 100644 index 713ce63..0000000 --- a/yolopart/requirements.txt +++ /dev/null @@ -1,17 +0,0 @@ -# 车牌检测系统依赖包 - -# 核心依赖 -PyQt5>=5.15.0 -opencv-python>=4.5.0 -onnxruntime-gpu>=1.12.0 -numpy>=1.21.0 - -# 可选依赖(车牌识别) -# paddlepaddle>=2.4.0 -# paddleocr>=2.6.0 -# pytesseract>=0.3.10 - -# 开发依赖 -# pytest>=7.0.0 -# black>=22.0.0 -# flake8>=4.0.0 \ No newline at end of file diff --git a/yolopart/ui/__init__.py b/yolopart/ui/__init__.py deleted file mode 100644 index 8807aa1..0000000 --- a/yolopart/ui/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# UI模块初始化文件 \ No newline at end of file diff --git a/yolopart/ui/main_window.py b/yolopart/ui/main_window.py deleted file mode 100644 index 789a139..0000000 --- a/yolopart/ui/main_window.py +++ /dev/null @@ -1,348 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -主界面窗口 -包含视频显示区域、控制按钮和车牌号显示区域 -""" - -import sys -import os -from PyQt5.QtWidgets import ( - QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, - QLabel, QPushButton, QFrame, QTextEdit, QGroupBox, - QCheckBox, QSpinBox, QSlider, QGridLayout -) -from PyQt5.QtCore import Qt, QTimer, pyqtSignal -from PyQt5.QtGui import QFont, QPixmap, QPalette, QImage - -from .video_widget import VideoWidget -from utils.video_capture import VideoCapture -from models.yolo_detector import YOLODetector -from models.plate_recognizer import PlateRecognizerManager - -class MainWindow(QMainWindow): - """主窗口类""" - - def __init__(self): - super().__init__() - self.video_capture = None - self.yolo_detector = None - self.plate_recognizer = PlateRecognizerManager('mock') # 车牌识别管理器 - self.timer = QTimer() - self.use_camera = 1 # 1: 摄像头, 0: 视频文件 - self.detected_plates = [] # 存储切割后的车牌图像数组 - self.current_frame = None # 存储当前帧用于车牌矫正 - - self.init_ui() - self.init_detector() - self.init_video_capture() - self.connect_signals() - - def init_ui(self): - """初始化用户界面""" - self.setWindowTitle("车牌检测系统 - YOLO11s") - self.setGeometry(100, 100, 1200, 800) - - # 创建中央widget - central_widget = QWidget() - self.setCentralWidget(central_widget) - - # 主布局 - main_layout = QHBoxLayout(central_widget) - - # 左侧视频显示区域 - self.create_video_area(main_layout) - - # 右侧控制和信息显示区域 - self.create_control_area(main_layout) - - # 设置布局比例 - main_layout.setStretch(0, 3) # 视频区域占3/4 - main_layout.setStretch(1, 1) # 控制区域占1/4 - - def create_video_area(self, parent_layout): - """创建视频显示区域""" - video_frame = QFrame() - video_frame.setFrameStyle(QFrame.StyledPanel) - video_layout = QVBoxLayout(video_frame) - - # 视频显示widget - self.video_widget = VideoWidget() - video_layout.addWidget(self.video_widget) - - parent_layout.addWidget(video_frame) - - def create_control_area(self, parent_layout): - """创建控制和信息显示区域""" - control_frame = QFrame() - control_frame.setFrameStyle(QFrame.StyledPanel) - control_frame.setMaximumWidth(300) - control_layout = QVBoxLayout(control_frame) - - # 控制按钮组 - self.create_control_buttons(control_layout) - - # 检测信息显示 - self.create_detection_info(control_layout) - - # 车牌号显示区域 - self.create_plate_display(control_layout) - - # 系统状态显示 - self.create_status_display(control_layout) - - parent_layout.addWidget(control_frame) - - def create_control_buttons(self, parent_layout): - """创建控制按钮""" - button_group = QGroupBox("控制面板") - button_layout = QVBoxLayout(button_group) - - # 开始/停止按钮 - self.start_btn = QPushButton("开始检测") - self.start_btn.setMinimumHeight(40) - self.start_btn.clicked.connect(self.toggle_detection) - button_layout.addWidget(self.start_btn) - - # 视频源切换 - self.camera_checkbox = QCheckBox("使用摄像头") - self.camera_checkbox.setChecked(True) - self.camera_checkbox.stateChanged.connect(self.toggle_video_source) - button_layout.addWidget(self.camera_checkbox) - - # 检测开关 - self.detection_checkbox = QCheckBox("启用检测") - self.detection_checkbox.setChecked(True) - button_layout.addWidget(self.detection_checkbox) - - parent_layout.addWidget(button_group) - - def create_detection_info(self, parent_layout): - """创建检测信息显示""" - info_group = QGroupBox("检测信息") - info_layout = QVBoxLayout(info_group) - - # FPS显示 - self.fps_label = QLabel("FPS: 0") - self.fps_label.setFont(QFont("Arial", 12, QFont.Bold)) - info_layout.addWidget(self.fps_label) - - # 检测数量 - self.detection_count_label = QLabel("检测到车牌: 0") - info_layout.addWidget(self.detection_count_label) - - # 模型信息 - self.model_info_label = QLabel("模型: YOLO11s (ONNX)") - info_layout.addWidget(self.model_info_label) - - parent_layout.addWidget(info_group) - - def create_plate_display(self, parent_layout): - """创建车牌号显示区域""" - plate_group = QGroupBox("车牌识别结果") - plate_layout = QVBoxLayout(plate_group) - - # 当前识别的车牌号 - self.current_plate_label = QLabel("当前车牌: 未识别") - self.current_plate_label.setFont(QFont("Arial", 14, QFont.Bold)) - self.current_plate_label.setStyleSheet("color: blue; padding: 10px; border: 1px solid gray;") - plate_layout.addWidget(self.current_plate_label) - - # 矫正后的车牌图像显示 - self.plate_image_label = QLabel("矫正后车牌图像") - self.plate_image_label.setAlignment(Qt.AlignCenter) - self.plate_image_label.setMinimumHeight(100) - self.plate_image_label.setMaximumHeight(150) - self.plate_image_label.setStyleSheet("border: 1px solid gray; background-color: #f0f0f0;") - plate_layout.addWidget(self.plate_image_label) - - # 历史车牌记录 - history_label = QLabel("历史记录:") - plate_layout.addWidget(history_label) - - self.plate_history = QTextEdit() - self.plate_history.setMaximumHeight(150) - self.plate_history.setReadOnly(True) - plate_layout.addWidget(self.plate_history) - - # 预留接口说明 - interface_label = QLabel("注: 车牌识别接口已预留,可接入OCR模型") - interface_label.setStyleSheet("color: gray; font-size: 10px;") - plate_layout.addWidget(interface_label) - - parent_layout.addWidget(plate_group) - - def create_status_display(self, parent_layout): - """创建系统状态显示""" - status_group = QGroupBox("系统状态") - status_layout = QVBoxLayout(status_group) - - self.status_label = QLabel("状态: 就绪") - status_layout.addWidget(self.status_label) - - self.gpu_status_label = QLabel("GPU: 检测中...") - status_layout.addWidget(self.gpu_status_label) - - parent_layout.addWidget(status_group) - - # 添加弹性空间 - parent_layout.addStretch() - - def init_detector(self): - """初始化YOLO检测器""" - try: - model_path = os.path.join(os.path.dirname(__file__), "..", "yolo11sth50.onnx") - self.yolo_detector = YOLODetector(model_path) - self.model_info_label.setText(f"模型: YOLO11s (ONNX) - GPU: {self.yolo_detector.use_gpu}") - self.gpu_status_label.setText(f"GPU: {'启用' if self.yolo_detector.use_gpu else '禁用'}") - except Exception as e: - self.status_label.setText(f"模型加载失败: {str(e)}") - - def init_video_capture(self): - """初始化视频捕获""" - try: - self.video_capture = VideoCapture() - self.status_label.setText("视频捕获初始化成功") - except Exception as e: - self.status_label.setText(f"视频捕获初始化失败: {str(e)}") - - def connect_signals(self): - """连接信号和槽""" - self.timer.timeout.connect(self.update_frame) - - def toggle_detection(self): - """切换检测状态""" - if self.timer.isActive(): - self.stop_detection() - else: - self.start_detection() - - def start_detection(self): - """开始检测""" - if self.video_capture and self.video_capture.start_capture(self.use_camera): - # 根据视频源类型设置定时器间隔 - video_fps = self.video_capture.get_video_fps() - timer_interval = int(1000 / video_fps) # 转换为毫秒 - self.timer.start(timer_interval) - - self.start_btn.setText("停止检测") - source_type = "摄像头" if self.use_camera else f"视频文件({video_fps:.1f}FPS)" - self.status_label.setText(f"检测中... - {source_type}") - else: - self.status_label.setText("启动失败") - - def stop_detection(self): - """停止检测""" - self.timer.stop() - if self.video_capture: - self.video_capture.stop_capture() - self.start_btn.setText("开始检测") - self.status_label.setText("已停止") - - def toggle_video_source(self, state): - """切换视频源""" - self.use_camera = 1 if state == Qt.Checked else 0 - if self.timer.isActive(): - self.stop_detection() - self.start_detection() - - def update_frame(self): - """更新帧""" - if not self.video_capture: - return - - frame, fps = self.video_capture.get_frame() - if frame is None: - return - - # 保存当前帧用于车牌矫正 - self.current_frame = frame.copy() - - # 更新FPS显示 - self.fps_label.setText(f"FPS: {fps:.1f}") - - # 进行检测 - if self.detection_checkbox.isChecked() and self.yolo_detector: - detections = self.yolo_detector.detect(frame) - frame = self.yolo_detector.draw_detections(frame, detections) - - # 切割车牌图像 - if detections: - self.detected_plates = self.yolo_detector.crop_plates(frame, detections) - - # 统计不同类型车牌数量 - blue_count = sum(1 for d in detections if d['class_id'] == 0) - green_count = sum(1 for d in detections if d['class_id'] == 1) - total_count = len(detections) - - self.detection_count_label.setText(f"检测到车牌: {total_count} (蓝牌:{blue_count}, 绿牌:{green_count})") - - # 调用车牌识别接口(预留) - self.recognize_plates(self.detected_plates, detections) - else: - self.detection_count_label.setText("检测到车牌: 0") - - # 显示帧 - self.video_widget.update_frame(frame) - - def recognize_plates(self, plate_images, detections): - """车牌识别接口(预留)""" - # 这里是预留的车牌识别接口 - # 可以接入OCR模型进行车牌号识别 - if plate_images and detections and self.current_frame is not None: - # 获取最新检测到的车牌信息 - latest_detection = detections[-1] # 取最后一个检测结果 - plate_type = "Blue Plate" if latest_detection['class_id'] == 0 else "Green Plate" - confidence = latest_detection['confidence'] - - # 处理蓝色车牌的矫正 - corrected_image = None - if latest_detection['class_id'] == 0: # 蓝色车牌 - try: - bbox = latest_detection['bbox'] - corrected_image = self.plate_recognizer.preprocess_blue_plate( - plate_images[-1], self.current_frame, bbox - ) - self._display_plate_image(corrected_image) - except Exception as e: - print(f"蓝色车牌矫正失败: {e}") - self.plate_image_label.setText("蓝色车牌矫正失败") - elif latest_detection['class_id'] == 1: # 绿色车牌 - # 绿色车牌处理预留 - self.plate_image_label.setText("绿色车牌处理\n(待实现)") - - # 模拟识别结果 - plate_text = f"Mock {plate_type}-{len(plate_images)}" - self.current_plate_label.setText(f"Current Plate: {plate_text} (Confidence: {confidence:.2f})") - - # 添加到历史记录 - import datetime - timestamp = datetime.datetime.now().strftime("%H:%M:%S") - self.plate_history.append(f"[{timestamp}] {plate_text} (Confidence: {confidence:.2f})") - - def _display_plate_image(self, image): - """在界面上显示车牌图像""" - try: - # 将OpenCV图像转换为QPixmap - if len(image.shape) == 3: - height, width, channel = image.shape - bytes_per_line = 3 * width - q_image = QImage(image.data, width, height, bytes_per_line, QImage.Format_RGB888).rgbSwapped() - else: - height, width = image.shape - bytes_per_line = width - q_image = QImage(image.data, width, height, bytes_per_line, QImage.Format_Grayscale8) - - # 缩放图像以适应标签大小 - pixmap = QPixmap.fromImage(q_image) - scaled_pixmap = pixmap.scaled(self.plate_image_label.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation) - - self.plate_image_label.setPixmap(scaled_pixmap) - except Exception as e: - print(f"显示车牌图像失败: {e}") - self.plate_image_label.setText(f"图像显示失败: {str(e)}") - - def closeEvent(self, event): - """窗口关闭事件""" - self.stop_detection() - event.accept() \ No newline at end of file diff --git a/yolopart/ui/video_widget.py b/yolopart/ui/video_widget.py deleted file mode 100644 index 35999a4..0000000 --- a/yolopart/ui/video_widget.py +++ /dev/null @@ -1,59 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -视频显示组件 -用于显示视频帧和检测结果 -""" - -import cv2 -import numpy as np -from PyQt5.QtWidgets import QLabel -from PyQt5.QtCore import Qt -from PyQt5.QtGui import QImage, QPixmap, QPainter, QPen, QFont - -class VideoWidget(QLabel): - """视频显示组件""" - - def __init__(self): - super().__init__() - self.setMinimumSize(640, 480) - self.setStyleSheet("border: 1px solid gray; background-color: black;") - self.setAlignment(Qt.AlignCenter) - self.setText("视频显示区域\n点击'开始检测'开始") - self.setScaledContents(True) - - def update_frame(self, frame): - """更新显示帧""" - if frame is None: - return - - # 转换BGR到RGB - rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - h, w, ch = rgb_frame.shape - bytes_per_line = ch * w - - # 创建QImage - qt_image = QImage(rgb_frame.data, w, h, bytes_per_line, QImage.Format_RGB888) - - # 转换为QPixmap并显示 - pixmap = QPixmap.fromImage(qt_image) - - # 缩放以适应widget大小,保持宽高比 - scaled_pixmap = pixmap.scaled( - self.size(), - Qt.KeepAspectRatio, - Qt.SmoothTransformation - ) - - self.setPixmap(scaled_pixmap) - - def paintEvent(self, event): - """绘制事件""" - super().paintEvent(event) - - # 如果没有图像,显示提示文本 - if not self.pixmap(): - painter = QPainter(self) - painter.setPen(QPen(Qt.white)) - painter.setFont(QFont("Arial", 16)) - painter.drawText(self.rect(), Qt.AlignCenter, "视频显示区域\n点击'开始检测'开始") \ No newline at end of file diff --git a/yolopart/utils/__init__.py b/yolopart/utils/__init__.py deleted file mode 100644 index 538fa97..0000000 --- a/yolopart/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# 工具模块初始化文件 \ No newline at end of file diff --git a/yolopart/utils/video_capture.py b/yolopart/utils/video_capture.py deleted file mode 100644 index 83cb017..0000000 --- a/yolopart/utils/video_capture.py +++ /dev/null @@ -1,280 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -视频捕获管理 -支持摄像头和视频文件的切换和管理 -""" - -import cv2 -import os -import time -import threading -from typing import Optional, Tuple - -class VideoCapture: - """视频捕获管理类""" - - def __init__(self): - """ - 初始化视频捕获管理器 - """ - self.cap = None - self.is_camera = True - self.video_path = None - self.fps_counter = FPSCounter() - self.frame_lock = threading.Lock() - self.current_frame = None - self.is_running = False - self.video_fps = 30.0 # 视频原始帧率 - - # 设置视频文件路径 - self.video_file_path = os.path.join(os.path.dirname(__file__), "..", "video.mp4") - - def start_capture(self, use_camera: int = 1) -> bool: - """ - 开始视频捕获 - - Args: - use_camera: 1使用摄像头,0使用视频文件 - - Returns: - 是否成功启动 - """ - self.stop_capture() - - self.is_camera = bool(use_camera) - - try: - if self.is_camera: - # 使用摄像头 - self.cap = cv2.VideoCapture(0) - if not self.cap.isOpened(): - # 尝试其他摄像头索引 - for i in range(1, 5): - self.cap = cv2.VideoCapture(i) - if self.cap.isOpened(): - break - else: - print("无法打开摄像头") - return False - - # 设置摄像头参数 - self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640) - self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480) - self.cap.set(cv2.CAP_PROP_FPS, 30) - - print("摄像头启动成功") - - else: - # 使用视频文件 - if not os.path.exists(self.video_file_path): - print(f"视频文件不存在: {self.video_file_path}") - return False - - self.cap = cv2.VideoCapture(self.video_file_path) - if not self.cap.isOpened(): - print(f"无法打开视频文件: {self.video_file_path}") - return False - - # 获取视频原始帧率 - self.video_fps = self.cap.get(cv2.CAP_PROP_FPS) - if self.video_fps <= 0: - self.video_fps = 25.0 # 默认帧率 - - print(f"视频文件加载成功: {self.video_file_path}, FPS: {self.video_fps}") - - self.is_running = True - self.fps_counter.reset() - return True - - except Exception as e: - print(f"启动视频捕获失败: {e}") - return False - - def stop_capture(self): - """停止视频捕获""" - self.is_running = False - - if self.cap is not None: - self.cap.release() - self.cap = None - - with self.frame_lock: - self.current_frame = None - - print("视频捕获已停止") - - def get_frame(self) -> Tuple[Optional[cv2.Mat], float]: - """ - 获取当前帧 - - Returns: - (frame, fps): 当前帧和FPS - """ - if not self.is_running or self.cap is None: - return None, 0.0 - - try: - ret, frame = self.cap.read() - - if not ret: - if not self.is_camera: - # 视频文件播放完毕,重新开始(循环播放) - self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) - ret, frame = self.cap.read() - - if not ret: - return None, 0.0 - - # 更新FPS计数器 - fps = self.fps_counter.update() - - # 在帧上绘制FPS信息 - frame_with_fps = self._draw_fps(frame, fps) - - with self.frame_lock: - self.current_frame = frame_with_fps.copy() - - return frame_with_fps, fps - - except Exception as e: - print(f"获取帧失败: {e}") - return None, 0.0 - - def _draw_fps(self, frame: cv2.Mat, fps: float) -> cv2.Mat: - """ - 在帧上绘制FPS信息 - - Args: - frame: 输入帧 - fps: 当前FPS - - Returns: - 绘制了FPS的帧 - """ - result_frame = frame.copy() - - # FPS文本 - fps_text = f"FPS: {fps:.1f}" - - # 文本参数 - font = cv2.FONT_HERSHEY_SIMPLEX - font_scale = 0.7 - color = (0, 255, 0) # 绿色 - thickness = 2 - - # 获取文本尺寸 - text_size = cv2.getTextSize(fps_text, font, font_scale, thickness)[0] - - # 绘制背景矩形 - cv2.rectangle(result_frame, - (10, 10), - (20 + text_size[0], 20 + text_size[1]), - (0, 0, 0), -1) - - # 绘制FPS文本 - cv2.putText(result_frame, fps_text, - (15, 15 + text_size[1]), - font, font_scale, color, thickness) - - return result_frame - - def get_capture_info(self) -> dict: - """ - 获取捕获信息 - - Returns: - 捕获信息字典 - """ - info = { - 'is_running': self.is_running, - 'is_camera': self.is_camera, - 'video_path': self.video_file_path if not self.is_camera else None, - 'fps': self.fps_counter.get_fps(), - 'video_fps': self.video_fps - } - - if self.cap is not None: - try: - info['width'] = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - info['height'] = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) - if not self.is_camera: - info['total_frames'] = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) - info['current_frame'] = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES)) - except: - pass - - return info - - def get_video_fps(self) -> float: - """ - 获取视频帧率 - - Returns: - 视频帧率,摄像头返回30.0,视频文件返回原始帧率 - """ - if self.is_camera: - return 30.0 # 摄像头固定30FPS - else: - return self.video_fps # 视频文件原始帧率 - - def __del__(self): - """析构函数""" - self.stop_capture() - -class FPSCounter: - """FPS计数器""" - - def __init__(self, window_size: int = 30): - """ - 初始化FPS计数器 - - Args: - window_size: 滑动窗口大小 - """ - self.window_size = window_size - self.frame_times = [] - self.last_time = time.time() - - def update(self) -> float: - """ - 更新FPS计数 - - Returns: - 当前FPS - """ - current_time = time.time() - - # 添加当前帧时间 - self.frame_times.append(current_time) - - # 保持窗口大小 - if len(self.frame_times) > self.window_size: - self.frame_times.pop(0) - - # 计算FPS - if len(self.frame_times) >= 2: - time_diff = self.frame_times[-1] - self.frame_times[0] - if time_diff > 0: - fps = (len(self.frame_times) - 1) / time_diff - return fps - - return 0.0 - - def get_fps(self) -> float: - """ - 获取当前FPS - - Returns: - 当前FPS - """ - if len(self.frame_times) >= 2: - time_diff = self.frame_times[-1] - self.frame_times[0] - if time_diff > 0: - return (len(self.frame_times) - 1) / time_diff - return 0.0 - - def reset(self): - """重置计数器""" - self.frame_times.clear() - self.last_time = time.time() \ No newline at end of file diff --git a/yolopart/yolo11sth50.onnx b/yolopart/yolo11s-pose42.pt similarity index 55% rename from yolopart/yolo11sth50.onnx rename to yolopart/yolo11s-pose42.pt index c55db16..45f8ddc 100644 Binary files a/yolopart/yolo11sth50.onnx and b/yolopart/yolo11s-pose42.pt differ