diff --git a/MANIFEST.in b/MANIFEST.in index c0f9d1b..44a0c59 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ recursive-include ddddocr common.onnx recursive-include ddddocr common_old.onnx -recursive-include ddddocr common_det.onnx \ No newline at end of file +recursive-include ddddocr common_det.onnx +recursive-include ddddocr common_rot.onnx \ No newline at end of file diff --git a/README.md b/README.md index 60c5ad5..1ac313b 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ DdddOcr、最简依赖的理念,尽量减少用户的配置和使用成本,
- + ## 目录 - [赞助合作商](#赞助合作商) @@ -112,6 +112,7 @@ ddddocr │ │── __init__.py 主代码库文件 │ │── common.onnx 新ocr模型 │ │── common_det.onnx 目标检测模型 +│ │── common_rot.onnx 旋转图片模型 │ │── common_old.onnx 老ocr模型 │ │── logo.png │ │── README.md @@ -258,7 +259,7 @@ cv2.imwrite("result.jpg", im) res = det.slide_match(target_bytes, background_bytes) print(res) - ``` +``` 由于滑块图可能存在透明边框的问题,导致计算结果不一定准确,需要自行估算滑块图透明边框的宽度用于修正得出的bbox *提示:如果滑块无过多背景部分,则可以添加simple_target参数, 通常为jpg或者bmp格式的图片* @@ -275,7 +276,7 @@ cv2.imwrite("result.jpg", im) res = slide.slide_match(target_bytes, background_bytes, simple_target=True) print(res) - ``` +``` **a.算法2** @@ -303,7 +304,7 @@ cv2.imwrite("result.jpg", im) res = slide.slide_comparison(target_bytes, background_bytes) print(res) - ``` +``` ##### Ⅳ. OCR概率输出 @@ -365,6 +366,40 @@ print(res) ``` +##### Ⅵ. 旋转图片 + +返回图片需要旋转多少度才是正的图片, 可以保存调试图片 + +```python +import ddddocr +import time + +rot = ddddocr.DdddOcr(rot=True) +with open("test.jpg", 'rb') as f: + image = f.read() + +runs = 100 +start = time.time() +for _ in range(runs): + degree = rot.rotate(image, save_rot=False) +end = time.time() +total_time = end - start +qps = runs / total_time +print(f"ONNX Runtime (CPU) - QPS: {qps:.2f}, Avg Latency: {total_time / runs * 1000:.2f} ms") + +# ONNX Runtime (CPU) - QPS: 34.34, Avg Latency: 29.12 ms # 包含图片预处理和后处理(保存) +# ONNX Runtime (CPU) - QPS: 38.39, Avg Latency: 26.05 ms # 包含图片预处理、不含后处理(保存) +``` + +**参考例图** + +包括且不限于以下图片 + + + + + + ### 版本控制 该项目使用Git进行版本管理。您可以在repository参看当前可用版本。 @@ -388,7 +423,7 @@ print(res) ### 作者 sml2h3@gamil.com - +
*好友数过多不一定通过,有问题可以在issue进行交流*
@@ -421,4 +456,3 @@ sml2h3@gamil.com
[](https://star-history.com/#sml2h3/ddddocr&Date)
-
diff --git a/ddddocr/__init__.py b/ddddocr/__init__.py
index 69b8409..beb724b 100644
--- a/ddddocr/__init__.py
+++ b/ddddocr/__init__.py
@@ -11,10 +11,11 @@
from PIL import Image, ImageChops
import numpy as np
import cv2
-
+import math
onnxruntime.set_default_logger_severity(3)
+
def base64_to_image(img_base64):
img_data = base64.b64decode(img_base64)
return Image.open(io.BytesIO(img_data))
@@ -39,7 +40,7 @@ def png_rgba_black_preprocess(img: Image):
class DdddOcr(object):
- def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta: bool = False,
+ def __init__(self, ocr: bool = True, det: bool = False, rot: bool = False, old: bool = False, beta: bool = False,
use_gpu: bool = False,
device_id: int = 0, show_ad=True, import_onnx_path: str = "", charsets_path: str = ""):
if show_ad:
@@ -54,11 +55,12 @@ def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta:
self.__word = False
self.__resize = []
self.__charset_range = []
- self.__valid_charset_range_index = [] # 指定字符对应的有效索引
+ self.__valid_charset_range_index = [] # 指定字符对应的有效索引
self.__channel = 1
if import_onnx_path != "":
det = False
ocr = False
+ rot = False
self.__graph_path = import_onnx_path
with open(charsets_path, 'r', encoding="utf-8") as f:
info = json.loads(f.read())
@@ -67,9 +69,14 @@ def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta:
self.__resize = info['image']
self.__channel = info['channel']
self.use_import_onnx = True
-
+ if rot:
+ ocr = False
+ det = False
+ self.__graph_path = os.path.join(os.path.dirname(__file__), 'common_rot.onnx')
+ self.__charset = []
if det:
ocr = False
+ rot = False
self.__graph_path = os.path.join(os.path.dirname(__file__), 'common_det.onnx')
self.__charset = []
if ocr:
@@ -2401,6 +2408,7 @@ def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta:
"窭", "铌",
"友", "唉", "怫", "荘"]
self.det = det
+ self.rot = rot
if use_gpu:
self.__providers = [
('CUDAExecutionProvider', {
@@ -2415,7 +2423,7 @@ def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, beta:
self.__providers = [
'CPUExecutionProvider',
]
- if ocr or det or self.use_import_onnx:
+ if rot or ocr or det or self.use_import_onnx:
self.__ort_session = onnxruntime.InferenceSession(self.__graph_path, providers=self.__providers)
def preproc(self, img, input_size, swap=(2, 0, 1)):
@@ -2605,7 +2613,6 @@ def set_ranges(self, charset_range):
pass
self.__valid_charset_range_index = valid_charset_range_index
-
def classification(self, img, png_fix: bool = False, probability=False):
if self.det:
raise TypeError("当前识别类型为目标检测")
@@ -2681,7 +2688,7 @@ def classification(self, img, png_fix: bool = False, probability=False):
valid_charset_range_index = self.__valid_charset_range_index
probability_result = []
for item in ort_outs_probability:
- probability_result.append([item[i] for i in valid_charset_range_index ])
+ probability_result.append([item[i] for i in valid_charset_range_index])
result['probability'] = probability_result
return result
else:
@@ -2723,12 +2730,67 @@ def classification(self, img, png_fix: bool = False, probability=False):
def detection(self, img_bytes: bytes = None, img_base64: str = None):
if not self.det:
- raise TypeError("当前识别类型为文字识别")
+ raise TypeError("当前识别类型为文字识别或图片旋转")
if not img_bytes:
img_bytes = base64.b64decode(img_base64)
result = self.get_bbox(img_bytes)
return result
+ def rotate(self, img_bytes: bytes = None, img_base64: str = None, save_rot: bool = False):
+ """
+ 返回图片应该旋转的角度
+ :param img_bytes: open(test.jpg).read()
+ :param img_base64:
+ :param save_rot: 为True则保存旋转以后的图片到项目路径,默认False
+ :return:
+ """
+ # 由于弟弟库维护人员过多,rotate没有调用和修改其他方法(如preproc)
+ # 如果我要调用preproc,那么就得修改preproc
+ # 可能导致其他兼容问题,所以全部写在rotate
+
+ if not self.rot:
+ raise TypeError("当前识别类型为文字识别或目标检测")
+ if not img_bytes:
+ img_bytes = base64.b64decode(img_base64)
+ image = cv2.imdecode(np.frombuffer(img_bytes, np.uint16), cv2.IMREAD_COLOR)
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ src_size, src_w = image.shape[:2]
+ assert src_size == src_w
+ # 中心裁剪
+ output_size = int(src_size / math.sqrt(2.0))
+ img = image[(image.shape[0] - output_size) // 2: (image.shape[0] + output_size) // 2,
+ (image.shape[1] - output_size) // 2: (image.shape[1] + output_size) // 2]
+ # 归一化
+ img = np.array(img).astype(np.float32) / 255.0
+
+ # 缩放
+ input_size = (224,224)
+ img = cv2.resize(img, input_size, interpolation=cv2.INTER_LANCZOS4)
+
+ # 标准化
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
+ img = (img - mean) / std
+
+ # 转换为C, H, W
+ im = img.transpose((2, 0, 1))
+
+ ort_inputs = {self.__ort_session.get_inputs()[0].name: im[None, :, :, :]}
+ cls_num = self.__ort_session.get_outputs()[0].shape[1]
+ output = self.__ort_session.run(None, ort_inputs)
+ angle = output[0].argmax(1).item() / cls_num
+ degree = angle * 360
+ if save_rot:
+ center = (src_size // 2, src_w // 2)
+ M = cv2.getRotationMatrix2D(center, -degree, 1.0)
+ rotated = cv2.warpAffine(
+ image, M, (src_w, src_size), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=(255, 255, 255)
+ )
+ # 保存旋转后的图像
+ rotated = cv2.cvtColor(rotated, cv2.COLOR_RGB2BGR)
+ cv2.imwrite("debug.jpg", rotated)
+ return degree
+
def get_target(self, img_bytes: bytes = None):
image = Image.open(io.BytesIO(img_bytes))
w, h = image.size
diff --git a/ddddocr/common_rot.onnx b/ddddocr/common_rot.onnx
new file mode 100644
index 0000000..ffa351f
Binary files /dev/null and b/ddddocr/common_rot.onnx differ
diff --git a/images/debug.jpg b/images/debug.jpg
new file mode 100644
index 0000000..99306a7
Binary files /dev/null and b/images/debug.jpg differ
diff --git a/images/test.jpg b/images/test.jpg
new file mode 100644
index 0000000..aae7724
Binary files /dev/null and b/images/test.jpg differ