91 lines
3.0 KiB
Python
91 lines
3.0 KiB
Python
import numpy as np
|
||
from paddleocr import TextRecognition
|
||
import cv2
|
||
|
||
class OCRProcessor:
|
||
def __init__(self):
|
||
self.model = TextRecognition(model_name="PP-OCRv5_server_rec")
|
||
# 定义允许的字符集合(不包含空白字符)
|
||
self.allowed_chars = [
|
||
# 中文省份简称
|
||
'京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
|
||
'苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
|
||
'桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁', '新',
|
||
# 字母 A-Z
|
||
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
|
||
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
|
||
# 数字 0-9
|
||
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'
|
||
]
|
||
print("OCR模型初始化完成(占位)")
|
||
|
||
def predict(self, image_array):
|
||
# 保持原有模型调用方式
|
||
output = self.model.predict(input=image_array)
|
||
# 结构化输出结果
|
||
results = output[0]["rec_text"]
|
||
placeholder_result = results.split(',')
|
||
return placeholder_result
|
||
|
||
def filter_allowed_chars(self, text):
|
||
"""只保留允许的字符"""
|
||
filtered_text = ""
|
||
for char in text:
|
||
if char in self.allowed_chars:
|
||
filtered_text += char
|
||
return filtered_text
|
||
|
||
# 保留原有函数接口
|
||
_processor = OCRProcessor()
|
||
|
||
def LPRNinitialize_model():
|
||
return _processor
|
||
|
||
def LPRNmodel_predict(image_array):
|
||
"""
|
||
OCR车牌号识别接口函数
|
||
|
||
参数:
|
||
image_array: numpy数组格式的车牌图像,已经过矫正处理
|
||
|
||
返回:
|
||
list: 包含最多8个字符的列表,代表车牌号的每个字符
|
||
例如: ['京', 'A', '1', '2', '3', '4', '5', ''] (蓝牌7位+占位符)
|
||
['京', 'A', 'D', '1', '2', '3', '4', '5'] (绿牌8位)
|
||
"""
|
||
# 获取原始预测结果
|
||
raw_result = _processor.predict(image_array)
|
||
|
||
# 将结果合并为字符串(如果是列表的话)
|
||
if isinstance(raw_result, list):
|
||
result_str = ''.join(raw_result)
|
||
else:
|
||
result_str = str(raw_result)
|
||
|
||
# 过滤掉'·'和'-'字符
|
||
filtered_str = result_str.replace('·', '')
|
||
filtered_str = filtered_str.replace('-', '')
|
||
|
||
# 只保留允许的字符
|
||
filtered_str = _processor.filter_allowed_chars(filtered_str)
|
||
|
||
# 转换为字符列表
|
||
char_list = list(filtered_str)
|
||
|
||
# 确保返回至少7个字符,最多8个字符
|
||
if len(char_list) < 7:
|
||
# 如果识别结果少于7个字符,用'0'补齐到7位
|
||
char_list.extend(['0'] * (7 - len(char_list)))
|
||
elif len(char_list) > 8:
|
||
# 如果识别结果多于8个字符,截取前8个
|
||
char_list = char_list[:8]
|
||
|
||
# 如果是7位,补齐到8位以保持接口一致性(第8位用空字符或占位符)
|
||
if len(char_list) == 7:
|
||
char_list.append('') # 添加空字符作为第8位占位符
|
||
|
||
return char_list
|
||
|
||
|
||
|