diff --git a/MANIFEST.in b/MANIFEST.in index c0f9d1b..44a0c59 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ recursive-include ddddocr common.onnx recursive-include ddddocr common_old.onnx -recursive-include ddddocr common_det.onnx \ No newline at end of file +recursive-include ddddocr common_det.onnx +recursive-include ddddocr common_rot.onnx \ No newline at end of file diff --git a/README.md b/README.md index 60c5ad5..1ac313b 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ DdddOcr、最简依赖的理念,尽量减少用户的配置和使用成本,

- + ## 目录 - [赞助合作商](#赞助合作商) @@ -112,6 +112,7 @@ ddddocr │ │── __init__.py 主代码库文件 │ │── common.onnx 新ocr模型 │ │── common_det.onnx 目标检测模型 +│ │── common_rot.onnx 旋转图片模型 │ │── common_old.onnx 老ocr模型 │ │── logo.png │ │── README.md @@ -258,7 +259,7 @@ cv2.imwrite("result.jpg", im) res = det.slide_match(target_bytes, background_bytes) print(res) - ``` +``` 由于滑块图可能存在透明边框的问题,导致计算结果不一定准确,需要自行估算滑块图透明边框的宽度用于修正得出的bbox *提示:如果滑块无过多背景部分,则可以添加simple_target参数, 通常为jpg或者bmp格式的图片* @@ -275,7 +276,7 @@ cv2.imwrite("result.jpg", im) res = slide.slide_match(target_bytes, background_bytes, simple_target=True) print(res) - ``` +``` **a.算法2** @@ -303,7 +304,7 @@ cv2.imwrite("result.jpg", im) res = slide.slide_comparison(target_bytes, background_bytes) print(res) - ``` +``` ##### Ⅳ. OCR概率输出 @@ -365,6 +366,40 @@ print(res) ``` +##### Ⅵ. 旋转图片 + +返回图片需要旋转多少度才是正的图片, 可以保存调试图片 + +```python +import ddddocr +import time + +rot = ddddocr.DdddOcr(rot=True) +with open("test.jpg", 'rb') as f: + image = f.read() + +runs = 100 +start = time.time() +for _ in range(runs): + degree = rot.rotate(image, save_rot=False) +end = time.time() +total_time = end - start +qps = runs / total_time +print(f"ONNX Runtime (CPU) - QPS: {qps:.2f}, Avg Latency: {total_time / runs * 1000:.2f} ms") + +# ONNX Runtime (CPU) - QPS: 34.34, Avg Latency: 29.12 ms # 包含图片预处理和后处理(保存) +# ONNX Runtime (CPU) - QPS: 38.39, Avg Latency: 26.05 ms # 包含图片预处理、不含后处理(保存) +``` + +**参考例图** + +包括且不限于以下图片 + +![原图](images/test.jpg) + +![旋转后](images/debug.jpg) + + ### 版本控制 该项目使用Git进行版本管理。您可以在repository参看当前可用版本。 @@ -388,7 +423,7 @@ print(res) ### 作者 sml2h3@gamil.com - + wechat *好友数过多不一定通过,有问题可以在issue进行交流* @@ -421,4 +456,3 @@ sml2h3@gamil.com [![Star History Chart](https://api.star-history.com/svg?repos=sml2h3/ddddocr&type=Date)](https://star-history.com/#sml2h3/ddddocr&Date) - diff --git a/ddddocr/__init__.py b/ddddocr/__init__.py index 69b8409..beb724b 100644 --- a/ddddocr/__init__.py +++ b/ddddocr/__init__.py @@ -11,10 +11,11 @@ from PIL import Image, 
ImageChops import numpy as np import cv2 - +import math onnxruntime.set_default_logger_severity(3) + def base64_to_image(img_base64): img_data = base64.b64decode(img_base64) return Image.open(io.BytesIO(img_data)) @@ -39,7 +40,7 @@ def png_rgba_black_preprocess(img: Image): class DdddOcr(object): - def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta: bool = False, + def __init__(self, ocr: bool = True, det: bool = False, rot: bool = False, old: bool = False, beta: bool = False, use_gpu: bool = False, device_id: int = 0, show_ad=True, import_onnx_path: str = "", charsets_path: str = ""): if show_ad: @@ -54,11 +55,12 @@ def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta: self.__word = False self.__resize = [] self.__charset_range = [] - self.__valid_charset_range_index = [] # 指定字符对应的有效索引 + self.__valid_charset_range_index = [] # 指定字符对应的有效索引 self.__channel = 1 if import_onnx_path != "": det = False ocr = False + rot = False self.__graph_path = import_onnx_path with open(charsets_path, 'r', encoding="utf-8") as f: info = json.loads(f.read()) @@ -67,9 +69,14 @@ def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta: self.__resize = info['image'] self.__channel = info['channel'] self.use_import_onnx = True - + if rot: + ocr = False + det = False + self.__graph_path = os.path.join(os.path.dirname(__file__), 'common_rot.onnx') + self.__charset = [] if det: ocr = False + rot = False self.__graph_path = os.path.join(os.path.dirname(__file__), 'common_det.onnx') self.__charset = [] if ocr: @@ -2401,6 +2408,7 @@ def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta: "窭", "铌", "友", "唉", "怫", "荘"] self.det = det + self.rot = rot if use_gpu: self.__providers = [ ('CUDAExecutionProvider', { @@ -2415,7 +2423,7 @@ def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta: self.__providers = [ 'CPUExecutionProvider', ] - if ocr or det or 
def detection(self, img_bytes: bytes = None, img_base64: str = None):
    """Run object detection on an image and return the result of ``get_bbox``.

    :param img_bytes: encoded image bytes; used when non-empty
    :param img_base64: base64-encoded image, used only when ``img_bytes``
        is empty or missing
    :return: whatever ``self.get_bbox`` returns for the image bytes
    :raises TypeError: if this instance was not created with ``det=True``,
        or if neither ``img_bytes`` nor ``img_base64`` is supplied
    """
    if not self.det:
        raise TypeError("当前识别类型为文字识别或图片旋转")
    if not img_bytes:
        if not img_base64:
            # Fail with a clear message instead of the opaque TypeError that
            # base64.b64decode(None) would raise.
            raise TypeError("either img_bytes or img_base64 must be provided")
        img_bytes = base64.b64decode(img_base64)
    return self.get_bbox(img_bytes)


def rotate(self, img_bytes: bytes = None, img_base64: str = None, save_rot: bool = False):
    """Return the angle, in degrees, predicted by the rotation model.

    The image is centre-cropped, scaled to 224x224, normalised and fed to
    the ONNX session; the arg-max class index is mapped linearly onto
    ``[0, 360)`` degrees.

    :param img_bytes: encoded image bytes, e.g. ``open("test.jpg", "rb").read()``
    :param img_base64: base64-encoded image, used only when ``img_bytes``
        is empty or missing
    :param save_rot: when True, also write the input image rotated by
        ``-degree`` to ``debug.jpg`` (relative path, i.e. the current
        working directory)
    :return: the predicted rotation in degrees as a float
    :raises TypeError: if this instance was not created with ``rot=True``,
        or if neither ``img_bytes`` nor ``img_base64`` is supplied
    :raises ValueError: if the data cannot be decoded as an image or the
        decoded image is not square
    """
    # All preprocessing is deliberately kept inside this method: reusing
    # shared helpers such as preproc would require changing them and could
    # break their other callers.
    if not self.rot:
        raise TypeError("当前识别类型为文字识别或目标检测")
    if not img_bytes:
        if not img_base64:
            raise TypeError("either img_bytes or img_base64 must be provided")
        img_bytes = base64.b64decode(img_base64)

    # An encoded image is a byte stream: uint8 is what cv2.imdecode expects,
    # and unlike uint16 it does not reject odd-length buffers.
    image = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        raise ValueError("input data could not be decoded as an image")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    src_h, src_w = image.shape[:2]
    if src_h != src_w:
        # Explicit check instead of `assert`, which is stripped under -O.
        raise ValueError(
            f"rotate expects a square image, got {src_w}x{src_h}"
        )

    # Centre crop to a square whose side is the original side / sqrt(2).
    crop = int(src_h / math.sqrt(2.0))
    top, bottom = (src_h - crop) // 2, (src_h + crop) // 2
    left, right = (src_w - crop) // 2, (src_w + crop) // 2
    img = image[top:bottom, left:right]

    # Scale pixel values to [0, 1], then resize to the network input size.
    img = np.array(img).astype(np.float32) / 255.0
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LANCZOS4)

    # Per-channel standardisation (values match the commonly used ImageNet
    # statistics -- presumably what the model was trained with; confirm).
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img = (img - mean) / std

    # H, W, C -> C, H, W, plus a leading batch axis for the session.
    im = img.transpose((2, 0, 1))
    ort_inputs = {self.__ort_session.get_inputs()[0].name: im[None, :, :, :]}
    logits = self.__ort_session.run(None, ort_inputs)[0]

    # Take the class count from the actual output: the session's declared
    # output shape may be symbolic/None when the model has dynamic axes.
    cls_num = logits.shape[1]
    angle = logits.argmax(1).item() / cls_num
    degree = angle * 360

    if save_rot:
        # OpenCV expects the centre as (x, y).
        center = (src_w // 2, src_h // 2)
        matrix = cv2.getRotationMatrix2D(center, -degree, 1.0)
        rotated = cv2.warpAffine(
            image, matrix, (src_w, src_h),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(255, 255, 255),
        )
        # Back to BGR, the channel order cv2.imwrite expects.
        cv2.imwrite("debug.jpg", cv2.cvtColor(rotated, cv2.COLOR_RGB2BGR))
    return degree