diff --git a/main.py b/main.py index 167de6a..1600611 100644 --- a/main.py +++ b/main.py @@ -9,11 +9,11 @@ from PyQt5.QtCore import QTimer, Qt, pyqtSignal, QThread from PyQt5.QtGui import QImage, QPixmap, QFont, QPainter, QPen, QColor import os from yolopart.detector import LicensePlateYOLO -from OCR_part.ocr_interface import ocr_predict -from OCR_part.ocr_interface import initialize_ocr_model -#不使用CRNN所以注释掉 -#from CRNN_part.crnn_interface import crnn_predict -#from CRNN_part.crnn_interface import initialize_crnn_model +#from OCR_part.ocr_interface import ocr_predict +#from OCR_part.ocr_interface import initialize_ocr_model +# 使用CRNN进行车牌字符识别 +from CRNN_part.crnn_interface import crnn_predict +from CRNN_part.crnn_interface import initialize_crnn_model class CameraThread(QThread): """摄像头线程类""" @@ -164,7 +164,8 @@ class MainWindow(QMainWindow): self.init_camera() # 初始化OCR/CRNN模型(具体用哪个模块识别车牌号就写在这儿) - initialize_ocr_model() + #initialize_ocr_model() + initialize_crnn_model() def init_ui(self): @@ -391,7 +392,8 @@ class MainWindow(QMainWindow): try: # 使用OCR接口进行识别 # 可以根据需要切换为CRNN: crnn_predict(corrected_image) - result = ocr_predict(corrected_image) + #result = ocr_predict(corrected_image) + result = crnn_predict(corrected_image) # 将字符列表转换为字符串 if isinstance(result, list) and len(result) >= 7: