浏览代码

update paddlex inference

1. fix json numpy.array
2. fix when no text detection in ocr
3. fix key to result
gaotingquan 1 年之前
父节点
当前提交
cdb11c53ad
共有 3 个文件被更改,包括 35 次插入13 次删除
  1. 11 10
      paddlex/inference/pipelines/ocr.py
  2. 18 2
      paddlex/inference/results/base.py
  3. 6 1
      paddlex/inference/results/ocr.py

+ 11 - 10
paddlex/inference/pipelines/ocr.py

@@ -36,14 +36,15 @@ class OCRPipeline(BasePipeline):
                 single_img_res = det_res["result"]
                 single_img_res["rec_text"] = []
                 single_img_res["rec_score"] = []
-                all_subs_of_img = list(self._crop_by_polys(single_img_res))
-                for batch_rec_res in self._rec_predict(all_subs_of_img):
-                    for rec_res in batch_rec_res:
-                        single_img_res["rec_text"].append(rec_res["result"]["rec_text"])
-                        single_img_res["rec_score"].append(
-                            rec_res["result"]["rec_score"]
-                        )
-                # TODO(gaotingquan): using "ocr_res" or new a component or dict only?
-                batch_ocr_res.append({"ocr_res": OCRResult(single_img_res)})
-                # batch_ocr_res.append(OCRResult(single_img_res))
+                if len(single_img_res["dt_polys"]) > 0:
+                    all_subs_of_img = list(self._crop_by_polys(single_img_res))
+                    for batch_rec_res in self._rec_predict(all_subs_of_img):
+                        for rec_res in batch_rec_res:
+                            single_img_res["rec_text"].append(
+                                rec_res["result"]["rec_text"]
+                            )
+                            single_img_res["rec_score"].append(
+                                rec_res["result"]["rec_score"]
+                            )
+                batch_ocr_res.append({"result": OCRResult(single_img_res)})
         yield batch_ocr_res

+ 18 - 2
paddlex/inference/results/base.py

@@ -14,15 +14,24 @@
 
 from abc import abstractmethod
 from pathlib import Path
+import numpy as np
 import json
 
 from ...utils import logging
 from ..utils.io import JsonWriter, ImageReader, ImageWriter
 
 
+class NumpyEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.ndarray):
+            return obj.tolist()
+        return super(NumpyEncoder, self).default(obj)
+
+
 class BaseResult(dict):
     def __init__(self, data):
         super().__init__(data)
+        self._check_res()
         self._json_writer = JsonWriter()
         self._img_reader = ImageReader(backend="opencv")
         self._img_writer = ImageWriter(backend="opencv")
@@ -30,7 +39,9 @@ class BaseResult(dict):
     def save_to_json(self, save_path, indent=4, ensure_ascii=False):
         if not save_path.endswith(".json"):
             save_path = Path(save_path) / f"{Path(self['img_path']).stem}.json"
-        self._json_writer.write(save_path, self, indent=4, ensure_ascii=False)
+        self._json_writer.write(
+            save_path, self, indent=4, ensure_ascii=False, cls=NumpyEncoder
+        )
 
     def save_to_img(self, save_path):
         if not save_path.lower().endswith((".jpg", ".png")):
@@ -43,9 +54,14 @@ class BaseResult(dict):
     def print(self, json_format=True, indent=4, ensure_ascii=False):
         str_ = self
         if json_format:
-            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
+            str_ = json.dumps(
+                str_, indent=indent, ensure_ascii=ensure_ascii, cls=NumpyEncoder
+            )
         logging.info(str_)
 
+    def _check_res(self):
+        pass
+
     @abstractmethod
     def _get_res_img(self):
         raise NotImplementedError

+ 6 - 1
paddlex/inference/results/ocr.py

@@ -19,13 +19,18 @@ import cv2
 import PIL
 from PIL import Image, ImageDraw, ImageFont
 
-from .base import BaseResult
+from ...utils import logging
 from ...utils.fonts import PINGFANG_FONT_FILE_PATH
 from ..utils.io import ImageReader
+from .base import BaseResult
 
 
 class OCRResult(BaseResult):
 
+    def _check_res(self):
+        if len(self["dt_polys"]) == 0:
+            logging.warning("No text detected!")
+
     def _get_res_img(
         self,
         drop_score=0.5,