diff --git a/.idea/License_plate_recognition.iml b/.idea/License_plate_recognition.iml index 2328911..fb56de3 100644 --- a/.idea/License_plate_recognition.iml +++ b/.idea/License_plate_recognition.iml @@ -2,7 +2,7 @@ - + diff --git a/.idea/misc.xml b/.idea/misc.xml index 9475f0b..fb9fc56 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,5 +3,5 @@ - + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml index 8306744..288b36b 100644 --- a/.idea/vcs.xml +++ b/.idea/vcs.xml @@ -1,7 +1,7 @@ - + \ No newline at end of file diff --git a/CRNN_part/crnn_interface.py b/CRNN_part/crnn_interface.py index 594595f..bfa9218 100644 --- a/CRNN_part/crnn_interface.py +++ b/CRNN_part/crnn_interface.py @@ -207,7 +207,7 @@ class LicensePlatePreprocessor: print(f"图像预处理失败: {e}") return None -def initialize_crnn_model(): +def LPRNinitialize_model(): """ 初始化CRNN模型 @@ -274,7 +274,7 @@ def initialize_crnn_model(): traceback.print_exc() return False -def crnn_predict(image_array): +def LPRNmodel_predict(image_array): """ CRNN车牌号识别接口函数 diff --git a/OCR_part/ocr_interface.py b/OCR_part/ocr_interface.py index 294e568..2d36093 100644 --- a/OCR_part/ocr_interface.py +++ b/OCR_part/ocr_interface.py @@ -1,36 +1,49 @@ import numpy as np +from paddleocr import TextRecognition +import cv2 + +class OCRProcessor: + def __init__(self): + self.model = TextRecognition(model_name="PP-OCRv5_server_rec") + print("OCR模型初始化完成(占位)") + + def predict(self, image_array): + # 保持原有模型调用方式 + output = self.model.predict(input=image_array) + # 结构化输出结果 + results = output[0]["rec_text"] + placeholder_result = results.split(',') + return placeholder_result + +# 保留原有函数接口 +_processor = OCRProcessor() + +def LPRNinitialize_model(): + return _processor + +def LPRNmodel_predict(image_array): + # 获取原始预测结果 + raw_result = _processor.predict(image_array) + + # 将结果合并为字符串(如果是列表的话) + if isinstance(raw_result, list): + result_str = ''.join(raw_result) + else: + result_str = str(raw_result) + + # 过滤掉'·'字符 + filtered_str = result_str.replace('·', '') + + # 转换为字符列表 + char_list = list(filtered_str) + + # 确保返回长度为7的列表 + if len(char_list) >= 7: + # 如果长度大于等于7,取前7个字符 + return char_list[:7] + else: + # 如果长度小于7,用空字符串补齐到7位 + return char_list + [''] * (7 - len(char_list)) -def initialize_ocr_model(): - """ - 初始化OCR模型 - - 返回: - bool: 初始化是否成功 - """ - # OCR模型初始化代码 - # 例如: 加载预训练模型、设置参数等 - - print("OCR模型初始化完成(占位)") - return True -def ocr_predict(image_array): - """ - OCR车牌号识别接口函数 - - 参数: - image_array: numpy数组格式的车牌图像,已经过矫正处理 - - 返回: - list: 包含7个字符的列表,代表车牌号的每个字符 - 例如: ['京', 'A', '1', '2', '3', '4', '5'] - """ - # 这是OCR部分的占位函数 - # 实际实现时,这里应该包含: - # 1. 图像预处理 - # 2. OCR模型推理 - # 3. 后处理和字符识别 - - # 临时返回占位结果 - placeholder_result = ['待', '识', '别', '0', '0', '0', '0'] - return placeholder_result diff --git a/main.py b/main.py index 1600611..2eee603 100644 --- a/main.py +++ b/main.py @@ -9,11 +9,11 @@ from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread from PyQt5.QtGui import QImage, QPixmap, QFont, QPainter, QPen, QColor import os from yolopart.detector import LicensePlateYOLO -#from OCR_part.ocr_interface import ocr_predict -#from OCR_part.ocr_interface import initialize_ocr_model -# 使用CRNN进行车牌字符识别 -from CRNN_part.crnn_interface import crnn_predict -from CRNN_part.crnn_interface import initialize_crnn_model +from OCR_part.ocr_interface import LPRNmodel_predict +from OCR_part.ocr_interface import LPRNinitialize_model +# 使用CRNN进行车牌字符识别(可选)同时也要修改第395,396行 +# from CRNN_part.crnn_interface import LPRNmodel_predict +# from CRNN_part.crnn_interface import LPRNinitialize_model class CameraThread(QThread): """摄像头线程类""" @@ -163,9 +163,8 @@ class MainWindow(QMainWindow): self.init_detector() self.init_camera() - # 初始化OCR/CRNN模型(具体用哪个模块识别车牌号就写在这儿) - #initialize_ocr_model() - initialize_crnn_model() + # 初始化OCR/CRNN模型(函数名改成一样的了,所以不要修改这里了,想用哪个模块直接导入) + LPRNinitialize_model() def init_ui(self): @@ -390,10 +389,9 @@ class MainWindow(QMainWindow): return "识别失败" try: - # 使用OCR接口进行识别 - # 可以根据需要切换为CRNN: crnn_predict(corrected_image) - #result = ocr_predict(corrected_image) - result = crnn_predict(corrected_image) + # 预测函数(来自模块) + # 函数名改成一样的了,所以不要修改这里了,想用哪个模块直接导入 + result = LPRNmodel_predict(corrected_image) # 将字符列表转换为字符串 if isinstance(result, list) and len(result) >= 7: diff --git a/requirements.txt b/requirements.txt index 57cb5ac..f5e0e68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,11 @@ PyQt5>=5.15.0 # 图像处理 Pillow>=8.0.0 +#paddleocr +python -m pip install paddlepaddle-gpu==3.0.0 -i https://www.paddlepaddle.org.cn/packages/stable/cu118/ +python -m pip install "paddleocr[all]" + + # 可选:如果需要GPU加速 # torch>=1.9.0 # torchvision>=0.10.0