Răsfoiți Sursa

refactor Result

gaotingquan 1 an în urmă
părinte
comite
c64e67363d
37 a modificat fișierele cu 439 adăugiri și 382 ștergeri
  1. 5 5
      paddlex/inference/components/task_related/det.py
  2. 4 4
      paddlex/inference/components/task_related/text_det.py
  3. 53 30
      paddlex/inference/components/transforms/image/common.py
  4. 17 10
      paddlex/inference/components/transforms/ts/common.py
  5. 1 1
      paddlex/inference/models/anomaly_detection.py
  6. 1 1
      paddlex/inference/models/formula_recognition.py
  7. 1 1
      paddlex/inference/models/general_recognition.py
  8. 1 1
      paddlex/inference/models/image_classification.py
  9. 1 1
      paddlex/inference/models/image_unwarping.py
  10. 1 1
      paddlex/inference/models/instance_segmentation.py
  11. 1 1
      paddlex/inference/models/multilabel_classification.py
  12. 1 1
      paddlex/inference/models/object_detection.py
  13. 1 1
      paddlex/inference/models/semantic_segmentation.py
  14. 1 1
      paddlex/inference/models/table_recognition.py
  15. 1 1
      paddlex/inference/models/text_detection.py
  16. 1 1
      paddlex/inference/models/text_recognition.py
  17. 3 1
      paddlex/inference/models/ts_ad.py
  18. 4 1
      paddlex/inference/models/ts_cls.py
  19. 3 1
      paddlex/inference/models/ts_fc.py
  20. 4 4
      paddlex/inference/pipelines/table_recognition/table_recognition.py
  21. 27 54
      paddlex/inference/results/base.py
  22. 9 18
      paddlex/inference/results/clas.py
  23. 5 19
      paddlex/inference/results/det.py
  24. 11 28
      paddlex/inference/results/instance_seg.py
  25. 32 38
      paddlex/inference/results/ocr.py
  26. 8 11
      paddlex/inference/results/seg.py
  27. 14 66
      paddlex/inference/results/table_rec.py
  28. 9 8
      paddlex/inference/results/text_det.py
  29. 6 7
      paddlex/inference/results/text_rec.py
  30. 10 36
      paddlex/inference/results/ts.py
  31. 176 0
      paddlex/inference/results/utils/mixin.py
  32. 0 1
      paddlex/inference/results/warp.py
  33. 2 2
      paddlex/inference/utils/io/__init__.py
  34. 8 9
      paddlex/inference/utils/io/readers.py
  35. 9 9
      paddlex/inference/utils/io/writers.py
  36. 2 3
      paddlex/paddlex_cli.py
  37. 6 5
      paddlex/utils/func_register.py

+ 5 - 5
paddlex/inference/components/task_related/det.py

@@ -203,7 +203,7 @@ class WarpAffine(BaseComponent):
 class DetPostProcess(BaseComponent):
     """Save Result Transform"""
 
-    INPUT_KEYS = ["img_path", "boxes", "img_size"]
+    INPUT_KEYS = ["input_path", "boxes", "img_size"]
     OUTPUT_KEYS = ["boxes"]
     DEAULT_INPUTS = {"boxes": "boxes", "img_size": "ori_img_size"}
     DEAULT_OUTPUTS = {"boxes": "boxes"}
@@ -227,18 +227,18 @@ class CropByBoxes(BaseComponent):
     """Crop Image by Box"""
 
     YIELD_BATCH = False
-    INPUT_KEYS = ["img_path", "boxes"]
+    INPUT_KEYS = ["input_path", "boxes"]
     OUTPUT_KEYS = ["img", "box", "label"]
-    DEAULT_INPUTS = {"img_path": "img_path", "boxes": "boxes"}
+    DEAULT_INPUTS = {"input_path": "input_path", "boxes": "boxes"}
     DEAULT_OUTPUTS = {"img": "img", "box": "box", "label": "label"}
 
     def __init__(self):
         super().__init__()
         self._reader = ImageReader(backend="opencv")
 
-    def apply(self, img_path, boxes):
+    def apply(self, input_path, boxes):
         output_list = []
-        img = self._reader.read(img_path)
+        img = self._reader.read(input_path)
         for bbox in boxes:
             label_id = bbox["cls_id"]
             box = bbox["coordinate"]

+ 4 - 4
paddlex/inference/components/task_related/text_det.py

@@ -419,9 +419,9 @@ class DBPostProcess(BaseComponent):
 class CropByPolys(BaseComponent):
     """Crop Image by Polys"""
 
-    INPUT_KEYS = ["img_path", "dt_polys"]
+    INPUT_KEYS = ["input_path", "dt_polys"]
     OUTPUT_KEYS = ["img"]
-    DEAULT_INPUTS = {"img_path": "img_path", "dt_polys": "dt_polys"}
+    DEAULT_INPUTS = {"input_path": "input_path", "dt_polys": "dt_polys"}
     DEAULT_OUTPUTS = {"img": "img"}
 
     def __init__(self, det_box_type="quad"):
@@ -429,9 +429,9 @@ class CropByPolys(BaseComponent):
         self.det_box_type = det_box_type
         self._reader = ImageReader(backend="opencv")
 
-    def apply(self, img_path, dt_polys):
+    def apply(self, input_path, dt_polys):
         """apply"""
-        img = self._reader.read(img_path)
+        img = self._reader.read(input_path)
 
         if self.det_box_type == "quad":
             dt_boxes = np.array(dt_polys)

+ 53 - 30
paddlex/inference/components/transforms/image/common.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-
+import tempfile
 from pathlib import Path
 from copy import deepcopy
 
@@ -21,7 +21,7 @@ import numpy as np
 import cv2
 
 from .....utils.cache import CACHE_DIR
-from ....utils.io import ImageReader, ImageWriter
+from ....utils.io import ImageReader, ImageWriter, PDFReader
 from ...utils.mixin import BatchSizeMixin
 from ...base import BaseComponent
 from ..read_data import _BaseRead
@@ -60,7 +60,7 @@ class ReadImage(_BaseRead):
     DEAULT_INPUTS = {"img": "img"}
     DEAULT_OUTPUTS = {
         "img": "img",
-        "img_path": "img_path",
+        "input_path": "input_path",
         "img_size": "img_size",
         "ori_img": "ori_img",
         "ori_img_size": "ori_img_size",
@@ -72,7 +72,7 @@ class ReadImage(_BaseRead):
         "GRAY": cv2.IMREAD_GRAYSCALE,
     }
 
-    SUFFIX = ["jpg", "png", "jpeg", "JPEG", "JPG", "bmp"]
+    SUFFIX = ["jpg", "png", "jpeg", "JPEG", "JPG", "bmp", "PDF", "pdf"]
 
     def __init__(self, batch_size=1, format="BGR"):
         """
@@ -85,39 +85,47 @@ class ReadImage(_BaseRead):
         super().__init__(batch_size)
         self.format = format
         flags = self._FLAGS_DICT[self.format]
-        self._reader = ImageReader(backend="opencv", flags=flags)
+        self._img_reader = ImageReader(backend="opencv", flags=flags)
+        self._pdf_reader = PDFReader()
         self._writer = ImageWriter(backend="opencv")
 
     def apply(self, img):
         """apply"""
         if not isinstance(img, str):
-            img_path = (Path(CACHE_DIR) / "predict_input" / "tmp_img.jpg").as_posix()
-            self._writer.write(img_path, img)
-            yield [
-                {
-                    "img_path": img_path,
-                    "img": img,
-                    "img_size": [img.shape[1], img.shape[0]],
-                    "ori_img": deepcopy(img),
-                    "ori_img_size": deepcopy([img.shape[1], img.shape[0]]),
-                }
-            ]
+            with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as temp_file:
+                img_path = Path(temp_file.name)
+                self._writer.write(img_path, img)
+                yield [
+                    {
+                        "input_path": img_path,
+                        "img": img,
+                        "img_size": [img.shape[1], img.shape[0]],
+                        "ori_img": deepcopy(img),
+                        "ori_img_size": deepcopy([img.shape[1], img.shape[0]]),
+                    }
+                ]
         else:
-            img_path = img
-            img_path = self._download_from_url(img_path)
-            file_list = self._get_files_list(img_path)
+            file_path = img
+            file_path = self._download_from_url(file_path)
+            file_list = self._get_files_list(file_path)
             batch = []
-            for img_path in file_list:
-                img = self._read_img(img_path)
-                batch.append(img)
+            for file_path in file_list:
+                img = self._read_img(file_path)
+                batch.extend(img)
                 if len(batch) >= self.batch_size:
                     yield batch
                     batch = []
             if len(batch) > 0:
                 yield batch
 
+    def _read(self, file_path):
+        if file_path:
+            return self._read_pdf(file_path)
+        else:
+            return self._read_img(file_path)
+
     def _read_img(self, img_path):
-        blob = self._reader.read(img_path)
+        blob = self._img_reader.read(img_path)
         if blob is None:
             raise Exception("Image read Error")
 
@@ -126,13 +134,28 @@ class ReadImage(_BaseRead):
                 raise RuntimeError("Array is not 3-dimensional.")
             # BGR to RGB
             blob = blob[..., ::-1]
-        return {
-            "img_path": img_path,
-            "img": blob,
-            "img_size": [blob.shape[1], blob.shape[0]],
-            "ori_img": deepcopy(blob),
-            "ori_img_size": deepcopy([blob.shape[1], blob.shape[0]]),
-        }
+        return [
+            {
+                "input_path": img_path,
+                "img": blob,
+                "img_size": [blob.shape[1], blob.shape[0]],
+                "ori_img": deepcopy(blob),
+                "ori_img_size": deepcopy([blob.shape[1], blob.shape[0]]),
+            }
+        ]
+
+    def _read_pdf(self, pdf_path):
+        img_list = self._pdf_reader.read(pdf_path)
+        return [
+            {
+                "input_path": pdf_path,
+                "img": img,
+                "img_size": [img.shape[1], img.shape[0]],
+                "ori_img": deepcopy(img),
+                "ori_img_size": deepcopy([img.shape[1], img.shape[0]]),
+            }
+            for img in img_list
+        ]
 
 
 class GetImageInfo(BaseComponent):

+ 17 - 10
paddlex/inference/components/transforms/ts/common.py

@@ -14,14 +14,15 @@
 
 from pathlib import Path
 from copy import deepcopy
+import tempfile
 import joblib
 import numpy as np
 import pandas as pd
 
 from .....utils.download import download
 from .....utils.cache import CACHE_DIR
-from ....utils.io.readers import TSReader
-from ....utils.io.writers import TSWriter
+from ....utils.io.readers import CSVReader
+from ....utils.io.writers import CSVWriter
 from ...base import BaseComponent
 from ..read_data import _BaseRead
 from .funcs import load_from_dataframe, time_feature
@@ -45,22 +46,24 @@ __all__ = [
 class ReadTS(_BaseRead):
 
     INPUT_KEYS = ["ts"]
-    OUTPUT_KEYS = ["ts_path", "ts", "ori_ts"]
+    OUTPUT_KEYS = ["input_path", "ts", "ori_ts"]
     DEAULT_INPUTS = {"ts": "ts"}
-    DEAULT_OUTPUTS = {"ts_path": "ts_path", "ts": "ts", "ori_ts": "ori_ts"}
+    DEAULT_OUTPUTS = {"input_path": "input_path", "ts": "ts", "ori_ts": "ori_ts"}
 
     SUFFIX = ["csv"]
 
     def __init__(self, batch_size=1):
         super().__init__(batch_size)
-        self._reader = TSReader(backend="pandas")
-        self._writer = TSWriter(backend="pandas")
+        self._reader = CSVReader(backend="pandas")
+        self._writer = CSVWriter(backend="pandas")
 
     def apply(self, ts):
         if not isinstance(ts, str):
-            ts_path = (Path(CACHE_DIR) / "predict_input" / "tmp_ts.csv").as_posix()
-            self._writer.write(ts_path, ts)
-            return {"ts_path": ts_path, "ts": ts, "ori_ts": deepcopy(ts)}
+            with tempfile.NamedTemporaryFile(suffix=".csv", delete=True) as temp_file:
+                input_path = Path(temp_file.name)
+                ts_path = input_path.as_posix()
+                self._writer.write(ts_path, ts)
+                yield {"input_path": input_path, "ts": ts, "ori_ts": deepcopy(ts)}
 
         ts_path = ts
         ts_path = self._download_from_url(ts_path)
@@ -69,7 +72,11 @@ class ReadTS(_BaseRead):
         for ts_path in file_list:
             ts_data = self._reader.read(ts_path)
             batch.append(
-                {"ts_path": ts_path, "ts": ts_data, "ori_ts": deepcopy(ts_data)}
+                {
+                    "input_path": Path(ts_path).name,
+                    "ts": ts_data,
+                    "ori_ts": deepcopy(ts_data),
+                }
             )
             if len(batch) >= self.batch_size:
                 yield batch

+ 1 - 1
paddlex/inference/models/anomaly_detection.py

@@ -83,5 +83,5 @@ class UadPredictor(BasicPredictor):
         return Normalize(mean=mean, std=std)
 
     def _pack_res(self, single):
-        keys = ["img_path", "pred"]
+        keys = ["input_path", "pred"]
         return SegResult({key: single[key] for key in keys})

+ 1 - 1
paddlex/inference/models/formula_recognition.py

@@ -51,5 +51,5 @@ class LaTeXOCRPredictor(BasicPredictor):
             raise Exception()
 
     def _pack_res(self, single):
-        keys = ["img_path", "rec_text"]
+        keys = ["input_path", "rec_text"]
         return TextRecResult({key: single[key] for key in keys})

+ 1 - 1
paddlex/inference/models/general_recognition.py

@@ -95,5 +95,5 @@ class ShiTuRecPredictor(BasicPredictor):
         return NormalizeFeatures()
 
     def _pack_res(self, data):
-        keys = ["img_path", "rec_feature"]
+        keys = ["input_path", "rec_feature"]
         return BaseResult({key: data[key] for key in keys})

+ 1 - 1
paddlex/inference/models/image_classification.py

@@ -95,7 +95,7 @@ class ClasPredictor(BasicPredictor):
         return MultiLabelThreshOutput(threshold=float(threshold), class_ids=label_list)
 
     def _pack_res(self, single):
-        keys = ["img_path", "class_ids", "scores"]
+        keys = ["input_path", "class_ids", "scores"]
         if "label_names" in single:
             keys.append("label_names")
         return TopkResult({key: single[key] for key in keys})

+ 1 - 1
paddlex/inference/models/image_unwarping.py

@@ -39,5 +39,5 @@ class WarpPredictor(BasicPredictor):
         self._add_component([predictor, DocTrPostProcess()])
 
     def _pack_res(self, single):
-        keys = ["img_path", "doctr_img"]
+        keys = ["input_path", "doctr_img"]
         return DocTrResult({key: single[key] for key in keys})

+ 1 - 1
paddlex/inference/models/instance_segmentation.py

@@ -62,5 +62,5 @@ class InstanceSegPredictor(DetPredictor):
         self._add_component([predictor, postprecss])
 
     def _pack_res(self, single):
-        keys = ["img_path", "boxes", "masks"]
+        keys = ["input_path", "boxes", "masks"]
         return InstanceSegResult({key: single[key] for key in keys})

+ 1 - 1
paddlex/inference/models/multilabel_classification.py

@@ -27,7 +27,7 @@ class MLClasPredictor(ClasPredictor):
     entities = [*MODELS]
 
     def _pack_res(self, single):
-        keys = ["img_path", "class_ids", "scores"]
+        keys = ["input_path", "class_ids", "scores"]
         if "label_names" in single:
             keys.append("label_names")
         return MLClassResult({key: single[key] for key in keys})

+ 1 - 1
paddlex/inference/models/object_detection.py

@@ -117,5 +117,5 @@ class DetPredictor(BasicPredictor):
         return WarpAffine(input_h=input_h, input_w=input_w, keep_res=keep_res)
 
     def _pack_res(self, single):
-        keys = ["img_path", "boxes"]
+        keys = ["input_path", "boxes"]
         return DetResult({key: single[key] for key in keys})

+ 1 - 1
paddlex/inference/models/semantic_segmentation.py

@@ -82,5 +82,5 @@ class SegPredictor(BasicPredictor):
         return Normalize(mean=mean, std=std)
 
     def _pack_res(self, single):
-        keys = ["img_path", "pred"]
+        keys = ["input_path", "pred"]
         return SegResult({key: single[key] for key in keys})

+ 1 - 1
paddlex/inference/models/table_recognition.py

@@ -99,5 +99,5 @@ class TablePredictor(BasicPredictor):
         return None
 
     def _pack_res(self, single):
-        keys = ["img_path", "bbox", "structure"]
+        keys = ["input_path", "bbox", "structure"]
         return TableRecResult({key: single[key] for key in keys})

+ 1 - 1
paddlex/inference/models/text_detection.py

@@ -97,5 +97,5 @@ class TextDetPredictor(BasicPredictor):
         return None
 
     def _pack_res(self, single):
-        keys = ["img_path", "dt_polys", "dt_scores"]
+        keys = ["input_path", "dt_polys", "dt_scores"]
         return TextDetResult({key: single[key] for key in keys})

+ 1 - 1
paddlex/inference/models/text_recognition.py

@@ -74,5 +74,5 @@ class TextRecPredictor(BasicPredictor):
         return None
 
     def _pack_res(self, single):
-        keys = ["img_path", "rec_text", "rec_score"]
+        keys = ["input_path", "rec_text", "rec_score"]
         return TextRecResult({key: single[key] for key in keys})

+ 3 - 1
paddlex/inference/models/ts_ad.py

@@ -63,4 +63,6 @@ class TSAdPredictor(BasicPredictor):
         )
 
     def _pack_res(self, single):
-        return TSAdResult({"ts_path": single["ts_path"], "anomaly": single["anomaly"]})
+        return TSAdResult(
+            {"input_path": single["input_path"], "anomaly": single["anomaly"]}
+        )

+ 4 - 1
paddlex/inference/models/ts_cls.py

@@ -50,5 +50,8 @@ class TSClsPredictor(BasicPredictor):
 
     def _pack_res(self, single):
         return TSClsResult(
-            {"ts_path": single["ts_path"], "classification": single["classification"]}
+            {
+                "input_path": single["input_path"],
+                "classification": single["classification"],
+            }
         )

+ 3 - 1
paddlex/inference/models/ts_fc.py

@@ -68,4 +68,6 @@ class TSFcPredictor(BasicPredictor):
             )
 
     def _pack_res(self, single):
-        return TSFcResult({"ts_path": single["ts_path"], "forecast": single["pred"]})
+        return TSFcResult(
+            {"input_path": single["input_path"], "forecast": single["pred"]}
+        )

+ 4 - 4
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -70,13 +70,13 @@ class TableRecPipeline(BasePipeline):
             self.layout_predictor(x), self.ocr_pipeline(x)
         ):
             single_img_res = {
-                "img_path": "",
+                "input_path": "",
                 "layout_result": {},
                 "ocr_result": {},
                 "table_result": [],
             }
             # update layout result
-            single_img_res["img_path"] = layout_pred["img_path"]
+            single_img_res["input_path"] = layout_pred["input_path"]
             single_img_res["layout_result"] = layout_pred
             subs_of_img = list(self._crop_by_boxes(layout_pred))
             # get cropped images with label "table"
@@ -102,7 +102,7 @@ class TableRecPipeline(BasePipeline):
         rec_text_list = []
         score_list = []
         unmatched_ocr_res = {"dt_polys": [], "rec_text": [], "rec_score": []}
-        unmatched_ocr_res["img_path"] = ocr_res["img_path"]
+        unmatched_ocr_res["input_path"] = ocr_res["input_path"]
         for i, text_box in enumerate(ocr_res["dt_polys"]):
             text_box_area = convert_4point2rect(text_box)
             if is_inside(text_box_area, box):
@@ -138,7 +138,7 @@ class TableRecPipeline(BasePipeline):
             table_res_list.append(
                 StructureTableResult(
                     {
-                        "img_path": input_img["img_path"],
+                        "input_path": input_img["input_path"],
                         "layout_bbox": [int(x) for x in input_img["box"]],
                         "bbox": ori_bbox_list,
                         "img_idx": table_index,

+ 27 - 54
paddlex/inference/results/base.py

@@ -12,61 +12,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from abc import abstractmethod
-from pathlib import Path
-import numpy as np
-import json
+from types import MappingProxyType
 
-from copy import deepcopy
-from ...utils import logging
-from ..utils.io import JsonWriter, ImageReader, ImageWriter
+from ...utils.func_register import FuncRegister
+from ..utils.io import ImageReader, ImageWriter
+from .utils.mixin import JsonMixin, ImgMixin, StrMixin
 
 
-def format_data(obj):
-    if isinstance(obj, np.float32):
-        return float(obj)
-    if isinstance(obj, np.ndarray):
-        return [format_data(item) for item in obj.tolist()]
-    elif isinstance(obj, dict):
-        return type(obj)({k: format_data(v) for k, v in obj.items()})
-    elif isinstance(obj, (list, tuple)):
-        return [format_data(i) for i in obj]
-    else:
-        return obj
-
-
-class BaseResult(dict):
+class BaseResult(dict, StrMixin, JsonMixin):
     def __init__(self, data):
-        super().__init__(format_data(deepcopy(data)))
-        self._check_res()
-        self._json_writer = JsonWriter()
-        self._img_reader = ImageReader(backend="opencv")
-        self._img_writer = ImageWriter(backend="opencv")
-
-    def save_to_json(self, save_path, indent=4, ensure_ascii=False):
-        if not save_path.endswith(".json"):
-            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.json"
-        self._json_writer.write(save_path, self, indent=4, ensure_ascii=False)
-
-    def save_to_img(self, save_path):
-        if not save_path.lower().endswith((".jpg", ".png")):
-            save_path = Path(save_path) / f"{Path(self['img_path']).stem}.jpg"
-        else:
-            save_path = Path(save_path)
-        res_img = self._get_res_img()
-        if res_img is not None:
-            self._img_writer.write(save_path.as_posix(), res_img)
-            logging.info(f"The result has been saved in {save_path}.")
-
-    def print(self, json_format=True, indent=4, ensure_ascii=False):
-        str_ = self
-        if json_format:
-            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
-        logging.info(str_)
-
-    def _check_res(self):
-        pass
-
-    @abstractmethod
-    def _get_res_img(self):
-        raise NotImplementedError
+        super().__init__(data)
+        self._show_func_map = {}
+        self._show_func_register = FuncRegister(self._show_func_map)
+        StrMixin.__init__(self)
+        JsonMixin.__init__(self)
+
+    def save_all(self, save_path):
+        for key in self._show_func_map:
+            func = self._show_func_map[key]
+            if "save" in key:
+                func(save_path=save_path)
+            else:
+                func()
+
+
+class CVResult(BaseResult, ImgMixin):
+    def __init__(self, data):
+        super().__init__(data)
+        ImgMixin.__init__(self, "pillow")
+        self._img_reader = ImageReader(backend="pillow")
+        self._img_writer = ImageWriter(backend="pillow")
+        self._show_func_register("save_to_img")(self.save_to_img)

+ 9 - 18
paddlex/inference/results/clas.py

@@ -14,27 +14,23 @@
 
 
 import PIL
-from PIL import ImageDraw, ImageFont, Image
+from PIL import Image, ImageDraw, ImageFont
 import numpy as np
+import cv2
 
 from ...utils.fonts import PINGFANG_FONT_FILE_PATH
 from ..utils.color_map import get_colormap
-from .base import BaseResult
+from .base import CVResult
 
 
-class TopkResult(BaseResult):
-    def __init__(self, data):
-        super().__init__(data)
-        self._img_reader.set_backend("pillow")
-        self._img_writer.set_backend("pillow")
+class TopkResult(CVResult):
 
-    def _get_res_img(self):
+    def _to_img(self):
         """Draw label on image"""
         labels = self.get("label_names", self["class_ids"])
         label_str = f"{labels[0]} {self['scores'][0]:.2f}"
 
-        image = self._img_reader.read(self["img_path"])
-        image = image.convert("RGB")
+        image = self._img_reader.read(self["input_path"])
         image_size = image.size
         draw = ImageDraw.Draw(image)
         min_font_size = int(image_size[0] * 0.02)
@@ -87,15 +83,10 @@ class TopkResult(BaseResult):
 
 
 class MLClassResult(TopkResult):
-
-    def __init__(self, data):
-        super().__init__(data)
-        self._img_reader.set_backend("pillow")
-        self._img_writer.set_backend("pillow")
-
-    def _get_res_img(self):
+    def _to_img(self):
         """Draw label on image"""
-        image = self._img_reader.read(self["img_path"])
+        image = self._img_reader.read(self["input_path"])
+        image = image.convert("RGB")
         label_names = self["label_names"]
         scores = self["scores"]
         image = image.convert("RGB")

+ 5 - 19
paddlex/inference/results/det.py

@@ -13,17 +13,13 @@
 # limitations under the License.
 
 import os
-
-import numpy as np
-import math
+import cv2
 import PIL
 from PIL import Image, ImageDraw, ImageFont
 
-from ...utils import logging
 from ...utils.fonts import PINGFANG_FONT_FILE_PATH
-from ..utils.io import ImageWriter, ImageReader
 from ..utils.color_map import get_colormap, font_colormap
-from .base import BaseResult
+from .base import CVResult
 
 
 def draw_box(img, boxes):
@@ -77,22 +73,12 @@ def draw_box(img, boxes):
     return img
 
 
-class DetResult(BaseResult):
+class DetResult(CVResult):
     """Save Result Transform"""
 
-    def __init__(self, data):
-        super().__init__(data)
-        # We use pillow backend to save both numpy arrays and PIL Image objects
-        self._img_reader.set_backend("pillow")
-        self._img_writer.set_backend("pillow")
-
-    def _get_res_img(self):
+    def _to_img(self):
         """apply"""
         boxes = self["boxes"]
-        img_path = self["img_path"]
-        file_name = os.path.basename(img_path)
-
-        image = self._img_reader.read(img_path)
+        image = self._img_reader.read(self["input_path"])
         image = draw_box(image, boxes)
-
         return image

+ 11 - 28
paddlex/inference/results/instance_seg.py

@@ -13,20 +13,13 @@
 # limitations under the License.
 
 import os
-
+import cv2
 import numpy as np
-import math
 import copy
-import json
-import cv2
-import PIL
-from PIL import Image, ImageDraw, ImageFont
-
-from ...utils import logging
-from ...utils.fonts import PINGFANG_FONT_FILE_PATH
-from ..utils.io import ImageWriter, ImageReader
-from ..utils.color_map import get_colormap, font_colormap
-from .base import BaseResult
+from PIL import Image
+
+from ..utils.color_map import get_colormap
+from .base import CVResult
 from .det import draw_box
 
 
@@ -136,20 +129,12 @@ def draw_mask(im, boxes, np_masks, img_size):
     return Image.fromarray(im.astype("uint8"))
 
 
-class InstanceSegResult(BaseResult):
+class InstanceSegResult(CVResult):
     """Save Result Transform"""
 
-    def __init__(self, data):
-        super().__init__(data)
-        # We use pillow backend to save both numpy arrays and PIL Image objects
-        self._img_reader.set_backend("pillow")
-        self._img_writer.set_backend("pillow")
-
-    def _get_res_img(self):
+    def _to_img(self):
         """apply"""
-        img_path = self["img_path"]
-        file_name = os.path.basename(img_path)
-        image = self._img_reader.read(img_path)
+        image = self._img_reader.read(self["input_path"])
         ori_img_size = list(image.size)[::-1]
         boxes = self["boxes"]
         masks = self["masks"]
@@ -161,9 +146,7 @@ class InstanceSegResult(BaseResult):
 
         return image
 
-    def print(self, json_format=True, indent=4, ensure_ascii=False):
+    def _to_str(self):
         str_ = copy.deepcopy(self)
-        del str_["masks"]
-        if json_format:
-            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
-        logging.info(str_)
+        str_["masks"] = "..."
+        return str(str_)

+ 32 - 38
paddlex/inference/results/ocr.py

@@ -19,17 +19,11 @@ import cv2
 import PIL
 from PIL import Image, ImageDraw, ImageFont
 
-from ...utils import logging
 from ...utils.fonts import PINGFANG_FONT_FILE_PATH
-from ..utils.io import ImageReader
-from .base import BaseResult
+from .base import CVResult
 
 
-class OCRResult(BaseResult):
-
-    def _check_res(self):
-        if len(self["dt_polys"]) == 0:
-            logging.warning("No text detected!")
+class OCRResult(CVResult):
 
     def get_minarea_rect(self, points):
         bounding_box = cv2.minAreaRect(points)
@@ -55,17 +49,17 @@ class OCRResult(BaseResult):
 
         return box
 
-    def _get_res_img(
+    def _to_img(
         self,
-        drop_score=0.5,
-        font_path=PINGFANG_FONT_FILE_PATH,
     ):
         """draw ocr result"""
+        # TODO(gaotingquan): mv to postprocess
+        drop_score = 0.5
+
         boxes = self["dt_polys"]
         txts = self["rec_text"]
         scores = self["rec_score"]
-        img = self._img_reader.read(self["img_path"])
-        image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+        image = self._img_reader.read(self["input_path"])
         h, w = image.height, image.width
         img_left = image.copy()
         img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
@@ -74,37 +68,37 @@ class OCRResult(BaseResult):
         if txts is None or len(txts) != len(boxes):
             txts = [None] * len(boxes)
         for idx, (box, txt) in enumerate(zip(boxes, txts)):
-            try:
-                if scores is not None and scores[idx] < drop_score:
-                    continue
-                color = (
-                    random.randint(0, 255),
-                    random.randint(0, 255),
-                    random.randint(0, 255),
-                )
-                box = np.array(box)
-                if len(box) > 4:
-                    pts = [(x, y) for x, y in box.tolist()]
-                    draw_left.polygon(pts, outline=color, width=8)
-                    box = self.get_minarea_rect(box)
-                    height = int(0.5 * (max(box[:, 1]) - min(box[:, 1])))
-                    box[:2, 1] = np.mean(box[:, 1])
-                    box[2:, 1] = np.mean(box[:, 1]) + min(20, height)
-                draw_left.polygon(box, fill=color)
-                img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
-                pts = np.array(box, np.int32).reshape((-1, 1, 2))
-                cv2.polylines(img_right_text, [pts], True, color, 1)
-                img_right = cv2.bitwise_and(img_right, img_right_text)
-            except:
+            if scores is not None and scores[idx] < drop_score:
                 continue
+            color = (
+                random.randint(0, 255),
+                random.randint(0, 255),
+                random.randint(0, 255),
+            )
+            box = np.array(box)
+            if len(box) > 4:
+                pts = [(x, y) for x, y in box.tolist()]
+                draw_left.polygon(pts, outline=color, width=8)
+                box = self.get_minarea_rect(box)
+                height = int(0.5 * (max(box[:, 1]) - min(box[:, 1])))
+                box[:2, 1] = np.mean(box[:, 1])
+                box[2:, 1] = np.mean(box[:, 1]) + min(20, height)
+            draw_left.polygon(box, fill=color)
+            img_right_text = draw_box_txt_fine(
+                (w, h), box, txt, PINGFANG_FONT_FILE_PATH
+            )
+            pts = np.array(box, np.int32).reshape((-1, 1, 2))
+            cv2.polylines(img_right_text, [pts], True, color, 1)
+            img_right = cv2.bitwise_and(img_right, img_right_text)
+
         img_left = Image.blend(image, img_left, 0.5)
         img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
         img_show.paste(img_left, (0, 0, w, h))
         img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
-        return cv2.cvtColor(np.array(img_show), cv2.COLOR_RGB2BGR)
+        return img_show
 
 
-def draw_box_txt_fine(img_size, box, txt, font_path=PINGFANG_FONT_FILE_PATH):
+def draw_box_txt_fine(img_size, box, txt, font_path):
     """draw box text"""
     box_height = int(
         math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
@@ -145,7 +139,7 @@ def draw_box_txt_fine(img_size, box, txt, font_path=PINGFANG_FONT_FILE_PATH):
     return img_right_text
 
 
-def create_font(txt, sz, font_path=PINGFANG_FONT_FILE_PATH):
+def create_font(txt, sz, font_path):
     """create font"""
     font_size = int(sz[1] * 0.8)
     font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

+ 8 - 11
paddlex/inference/results/seg.py

@@ -19,21 +19,20 @@ import copy
 import json
 
 from ...utils import logging
-from .base import BaseResult
+from .utils.mixin import ImgMixin
+from .base import CVResult
 
 
-class SegResult(BaseResult):
+class SegResult(CVResult):
     """Save Result Transform"""
 
     def __init__(self, data):
         super().__init__(data)
-        self.data = data
-        # We use pillow backend to save both numpy arrays and PIL Image objects
         self._img_writer.set_backend("pillow", format_="PNG")
 
-    def _get_res_img(self):
+    def _to_img(self):
         """apply"""
-        seg_map = self.data["pred"]
+        seg_map = self["pred"]
         pc_map = self.get_pseudo_color_map(seg_map[0])
         return pc_map
 
@@ -67,9 +66,7 @@ class SegResult(BaseResult):
             color_map[: len(custom_color)] = custom_color
         return color_map
 
-    def print(self, json_format=True, indent=4, ensure_ascii=False):
+    def _to_str(self):
         str_ = copy.deepcopy(self)
-        del str_["pred"]
-        if json_format:
-            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
-        logging.info(str_)
+        str_["pred"] = "..."
+        return str(str_)

+ 14 - 66
paddlex/inference/results/table_rec.py

@@ -16,20 +16,23 @@ import cv2
 import numpy as np
 from pathlib import Path
 
-from .base import BaseResult
-from ...utils import logging
-from ..utils.io import HtmlWriter, XlsxWriter
+from .utils.mixin import HtmlMixin, XlsxMixin
+from .base import BaseResult, CVResult
 
 
-class TableRecResult(BaseResult):
+class TableRecResult(CVResult, HtmlMixin):
     """SaveTableResults"""
 
     def __init__(self, data):
         super().__init__(data)
-        self._img_writer.set_backend("pillow")
+        HtmlMixin.__init__(self)
+        self._show_func_register("save_to_html")(self.save_to_html)
 
-    def _get_res_img(self):
-        image = self._img_reader.read(self["img_path"])
+    def _to_html(self):
+        return self["html"]
+
+    def _to_img(self):
+        image = self._img_reader.read(self["input_path"])
         bbox_res = self["bbox"]
         if len(bbox_res) > 0 and len(bbox_res[0]) == 4:
             vis_img = self.draw_rectangle(image, bbox_res)
@@ -54,76 +57,21 @@ class TableRecResult(BaseResult):
         return image
 
 
-class StructureTableResult(TableRecResult):
+class StructureTableResult(TableRecResult, XlsxMixin):
     """StructureTableResult"""
 
     def __init__(self, data):
-        """__init__"""
         super().__init__(data)
-        self._img_writer.set_backend("pillow")
-        self._html_writer = HtmlWriter()
-        self._xlsx_writer = XlsxWriter()
-
-    def save_to_html(self, save_path):
-        """save_to_html"""
-        img_idx = self["img_idx"]
-        if not save_path.endswith(".html"):
-            if img_idx > 0:
-                save_path = (
-                    Path(save_path) / f"{Path(self['img_path']).stem}_{img_idx}.html"
-                )
-            else:
-                save_path = Path(save_path) / f"{Path(self['img_path']).stem}.html"
-        elif img_idx > 0:
-            save_path = Path(save_path).stem / f"_{img_idx}.html"
-        self._html_writer.write(save_path.as_posix(), self["html"])
-        logging.info(f"The result has been saved in {save_path}.")
-
-    def save_to_excel(self, save_path):
-        """save_to_excel"""
-        img_idx = self["img_idx"]
-        if not save_path.endswith(".xlsx"):
-            if img_idx > 0:
-                save_path = (
-                    Path(save_path) / f"{Path(self['img_path']).stem}_{img_idx}.xlsx"
-                )
-            else:
-                save_path = Path(save_path) / f"{Path(self['img_path']).stem}.xlsx"
-        elif img_idx > 0:
-            save_path = Path(save_path).stem / f"_{img_idx}.xlsx"
-        self._xlsx_writer.write(save_path.as_posix(), self["html"])
-        logging.info(f"The result has been saved in {save_path}.")
-
-    def save_to_img(self, save_path):
-        img_idx = self["img_idx"]
-        if not save_path.endswith((".jpg", ".png")):
-            if img_idx > 0:
-                save_path = (
-                    Path(save_path) / f"{Path(self['img_path']).stem}_{img_idx}.jpg"
-                )
-            else:
-                save_path = Path(save_path) / f"{Path(self['img_path']).stem}.jpg"
-        elif img_idx > 0:
-            save_path = Path(save_path).stem / f"_{img_idx}.jpg"
-        else:
-            save_path = Path(save_path)
-        res_img = self._get_res_img()
-        if res_img is not None:
-            self._img_writer.write(save_path.as_posix(), res_img)
-            logging.info(f"The result has been saved in {save_path}.")
+        XlsxMixin.__init__(self)
 
 
 class TableResult(BaseResult):
     """TableResult"""
 
-    def __init__(self, data):
-        """__init__"""
-        super().__init__(data)
-
     def save_to_img(self, save_path):
         if not save_path.lower().endswith((".jpg", ".png")):
-            img_path = self["img_path"]
-            save_path = Path(save_path) / f"{Path(img_path).stem}"
+            input_path = self["input_path"]
+            save_path = Path(save_path) / f"{Path(input_path).stem}"
         else:
             save_path = Path(save_path).stem
         layout_save_path = f"{save_path}_layout.jpg"

+ 9 - 8
paddlex/inference/results/text_det.py

@@ -15,18 +15,19 @@
 import numpy as np
 import cv2
 
-from ..utils.io import ImageReader
-from .base import BaseResult
+from .base import CVResult
 
 
-class TextDetResult(BaseResult):
+class TextDetResult(CVResult):
+    def __init__(self, data):
+        super().__init__(data)
+        self._img_reader.set_backend("opencv")
 
-    def _get_res_img(self):
+    def _to_img(self):
         """draw rectangle"""
         boxes = self["dt_polys"]
-        img = self._img_reader.read(self["img_path"])
-        res_img = img.copy()
+        image = self._img_reader.read(self["input_path"])
         for box in boxes:
             box = np.reshape(np.array(box).astype(int), [-1, 1, 2]).astype(np.int64)
-            cv2.polylines(res_img, [box], True, (0, 0, 255), 2)
-        return res_img
+            cv2.polylines(image, [box], True, (0, 0, 255), 2)
+        return image

+ 6 - 7
paddlex/inference/results/text_rec.py

@@ -12,12 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .base import BaseResult
+from ...utils import logging
+from .base import CVResult
 
 
-class TextRecResult(BaseResult):
-    def __init__(self, data):
-        super().__init__(data)
-
-    def _get_res_img(self, save_path):
-        raise Exception("Don't support to save Text Rec result to img!")
+class TextRecResult(CVResult):
+    def _to_img(self):
+        logging.warning("TextRecResult don't support save to img!")
+        return None

+ 10 - 36
paddlex/inference/results/ts.py

@@ -12,52 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pathlib import Path
-import numpy as np
-import pandas as pd
-
-from ...utils import logging
-from ..utils.io import TSWriter
+from .utils.mixin import JsonMixin, CSVMixin
 from .base import BaseResult
 
 
-class TSFcResult(BaseResult):
-
+class _BaseTSResult(BaseResult, CSVMixin):
     def __init__(self, data):
         super().__init__(data)
-        self._writer = TSWriter(backend="pandas")
+        CSVMixin.__init__(self)
 
-    def save_to_csv(self, save_path):
-        """write ts"""
-        if not save_path.endswith(".csv"):
-            save_path = Path(save_path) / f"{Path(self['ts_path']).stem}.csv"
-        self._writer.write(save_path, self["forecast"])
-        logging.info(f"The result has been saved in {save_path}.")
 
+class TSFcResult(_BaseTSResult):
+    def _to_csv(self, save_path):
+        return self["forecast"]
 
-class TSClsResult(BaseResult):
-
-    def __init__(self, data):
-        super().__init__(data)
-        self._writer = TSWriter(backend="pandas")
 
+class TSClsResult(_BaseTSResult):
     def save_to_csv(self, save_path):
-        """write ts"""
-        if not save_path.endswith(".csv"):
-            save_path = Path(save_path) / f"{Path(self['ts_path']).stem}.csv"
-        self._writer.write(save_path, self["classification"])
-        logging.info(f"The result has been saved in {save_path}.")
-
-
-class TSAdResult(BaseResult):
+        return self["classification"]
 
-    def __init__(self, data):
-        super().__init__(data)
-        self._writer = TSWriter(backend="pandas")
 
+class TSAdResult(_BaseTSResult):
     def save_to_csv(self, save_path):
-        """write ts"""
-        if not save_path.endswith(".csv"):
-            save_path = Path(save_path) / f"{Path(self['ts_path']).stem}.csv"
-        self._writer.write(save_path, self["anomaly"])
-        logging.info(f"The result has been saved in {save_path}.")
+        return self["anomaly"]

+ 176 - 0
paddlex/inference/results/utils/mixin.py

@@ -0,0 +1,176 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import abstractmethod
+import json
+from pathlib import Path
+import numpy as np
+from PIL import Image
+
+from ....utils import logging
+from ...utils.io import (
+    JsonWriter,
+    ImageReader,
+    ImageWriter,
+    CSVWriter,
+    HtmlWriter,
+    XlsxWriter,
+)
+
+
+def _save_list_data(save_func, save_path, data, *args, **kwargs):
+    save_path = Path(save_path)
+    if data is None:
+        return
+    if isinstance(data, list):
+        for idx, single in enumerate(data):
+            save_func(
+                (
+                    save_path.parent / f"{save_path.stem}_{idx}.{save_path.suffix}"
+                ).as_posix(),
+                single,
+                *args,
+                **kwargs,
+            )
+    save_func(save_path.as_posix(), data, *args, **kwargs)
+    logging.info(f"The result has been saved in {save_path}.")
+
+
+class StrMixin:
+    def __init__(self):
+        self._show_func_register()(self.print)
+
+    @property
+    def str(self):
+        return self._to_str()
+
+    def _to_str(self):
+        return str(self)
+
+    def print(self, json_format=False, indent=4, ensure_ascii=False):
+        str_ = self._to_str()
+        if json_format:
+            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
+        logging.info(str_)
+
+
+class JsonMixin:
+    def __init__(self):
+        self._json_writer = JsonWriter()
+        self._show_func_register()(self.save_to_json)
+
+    def _to_json(self):
+        def _format_data(obj):
+            if isinstance(obj, np.float32):
+                return float(obj)
+            if isinstance(obj, np.ndarray):
+                return [_format_data(item) for item in obj.tolist()]
+            elif isinstance(obj, dict):
+                return type(obj)({k: _format_data(v) for k, v in obj.items()})
+            elif isinstance(obj, (list, tuple)):
+                return [_format_data(i) for i in obj]
+            else:
+                return obj
+
+        return _format_data(self)
+
+    @property
+    def json(self):
+        return self._to_json()
+
+    def save_to_json(self, save_path, indent=4, ensure_ascii=False, *args, **kwargs):
+        if not save_path.endswith(".json"):
+            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.json"
+        _save_list_data(
+            self._json_writer.write,
+            save_path,
+            self.json,
+            indent=indent,
+            ensure_ascii=ensure_ascii,
+            *args,
+            **kwargs,
+        )
+
+
+class ImgMixin:
+    def __init__(self, backend="pillow", *args, **kwargs):
+        self._img_writer = ImageWriter(backend=backend, *args, **kwargs)
+        self._show_func_register()(self.save_to_img)
+
+    @abstractmethod
+    def _to_img(self):
+        raise NotImplementedError
+
+    @property
+    def img(self):
+        image = self._to_img()
+        # The img must be a PIL.Image obj
+        if isinstance(image, np.ndarray):
+            return Image.fromarray(image)
+        return image
+
+    def save_to_img(self, save_path, *args, **kwargs):
+        if not save_path.lower().endswith((".jpg", ".png")):
+            fp = Path(self["input_path"])
+            save_path = Path(save_path) / f"{fp.stem}.{fp.suffix}"
+        _save_list_data(self._img_writer.write, save_path, self.img, *args, **kwargs)
+
+
+class CSVMixin:
+    def __init__(self, backend="pandas", *args, **kwargs):
+        self._csv_writer = CSVWriter(backend=backend, *args, **kwargs)
+        self._show_func_register()(self.save_to_csv)
+
+    @abstractmethod
+    def _to_csv(self):
+        raise NotImplementedError
+
+    def save_to_csv(self, save_path, *args, **kwargs):
+        if not save_path.endswith(".csv"):
+            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.csv"
+        _save_list_data(
+            self._csv_writer.write, save_path, self._to_csv(), *args, **kwargs
+        )
+
+
+class HtmlMixin:
+    def __init__(self, *args, **kwargs):
+        self._html_writer = HtmlWriter(*args, **kwargs)
+        self._show_func_register()(self.save_to_html)
+
+    @property
+    def html(self):
+        return self._to_html()
+
+    def _to_html(self):
+        return self["html"]
+
+    def save_to_html(self, save_path, *args, **kwargs):
+        if not save_path.endswith(".html"):
+            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.html"
+        _save_list_data(self._html_writer.write, save_path, self.html, *args, **kwargs)
+
+
+class XlsxMixin:
+    def __init__(self, *args, **kwargs):
+        self._xlsx_writer = XlsxWriter(*args, **kwargs)
+        self._show_func_register()(self.save_to_xlsx)
+
+    def _to_xlsx(self):
+        return self["html"]
+
+    def save_to_xlsx(self, save_path, *args, **kwargs):
+        if not save_path.endswith(".xlsx"):
+            save_path = Path(save_path) / f"{Path(self['input_path']).stem}.xlsx"
+        _save_list_data(self._xlsx_writer.write, save_path, self.html, *args, **kwargs)

+ 0 - 1
paddlex/inference/results/warp.py

@@ -23,7 +23,6 @@ from .base import BaseResult
 class DocTrResult(BaseResult):
     def __init__(self, data):
         super().__init__(data)
-        # We use opencv backend to save both numpy arrays
         self._img_writer.set_backend("opencv")
 
     def _get_res_img(self):

+ 2 - 2
paddlex/inference/utils/io/__init__.py

@@ -13,13 +13,13 @@
 # limitations under the License.
 
 
-from .readers import ReaderType, ImageReader, VideoReader, TSReader, PDFReader
+from .readers import ReaderType, ImageReader, VideoReader, CSVReader, PDFReader
 from .writers import (
     WriterType,
     ImageWriter,
     TextWriter,
     JsonWriter,
-    TSWriter,
+    CSVWriter,
     HtmlWriter,
     XlsxWriter,
 )

+ 8 - 9
paddlex/inference/utils/io/readers.py

@@ -21,7 +21,7 @@ from PIL import Image, ImageOps
 import pandas as pd
 import numpy as np
 
-__all__ = ["ReaderType", "ImageReader", "VideoReader", "TSReader", "PDFReader"]
+__all__ = ["ReaderType", "ImageReader", "VideoReader", "CSVReader", "PDFReader"]
 
 
 class ReaderType(enum.Enum):
@@ -203,7 +203,6 @@ class PDFReaderBackend(_BaseReaderBackend):
 
     def __init__(self, rotate=0, zoom_x=2.0, zoom_y=2.0):
         super().__init__()
-        print(rotate)
         self.mat = fitz.Matrix(zoom_x, zoom_y).prerotate(rotate)
 
     def read_file(self, in_path):
@@ -279,8 +278,8 @@ class OpenCVVideoReaderBackend(_VideoReaderBackend):
             self._cap = None
 
 
-class TSReader(_BaseReader):
-    """TSReader"""
+class CSVReader(_BaseReader):
+    """CSVReader"""
 
     def __init__(self, backend="pandas", **bk_args):
         super().__init__(backend=backend, **bk_args)
@@ -293,7 +292,7 @@ class TSReader(_BaseReader):
     def _init_backend(self, bk_type, bk_args):
         """init backend"""
         if bk_type == "pandas":
-            return PandasTSReaderBackend(**bk_args)
+            return PandasCSVReaderBackend(**bk_args)
         else:
             raise ValueError("Unsupported backend type")
 
@@ -302,14 +301,14 @@ class TSReader(_BaseReader):
         return ReaderType.TS
 
 
-class _TSReaderBackend(_BaseReaderBackend):
-    """_TSReaderBackend"""
+class _CSVReaderBackend(_BaseReaderBackend):
+    """_CSVReaderBackend"""
 
     pass
 
 
-class PandasTSReaderBackend(_TSReaderBackend):
-    """PandasTSReaderBackend"""
+class PandasCSVReaderBackend(_CSVReaderBackend):
+    """PandasCSVReaderBackend"""
 
     def __init__(self):
         super().__init__()

+ 9 - 9
paddlex/inference/utils/io/writers.py

@@ -30,7 +30,7 @@ __all__ = [
     "ImageWriter",
     "TextWriter",
     "JsonWriter",
-    "TSWriter",
+    "CSVWriter",
     "HtmlWriter",
     "XlsxWriter",
 ]
@@ -45,7 +45,7 @@ class WriterType(enum.Enum):
     JSON = 4
     HTML = 5
     XLSX = 6
-    TS = 7
+    CSV = 7
 
 
 class _BaseWriter(object):
@@ -300,8 +300,8 @@ class UJsonWriterBackend(_BaseJsonWriterBackend):
         raise NotImplementedError
 
 
-class TSWriter(_BaseWriter):
-    """TSWriter"""
+class CSVWriter(_BaseWriter):
+    """CSVWriter"""
 
     def __init__(self, backend="pandas", **bk_args):
         super().__init__(backend=backend, **bk_args)
@@ -313,22 +313,22 @@ class TSWriter(_BaseWriter):
     def _init_backend(self, bk_type, bk_args):
         """init backend"""
         if bk_type == "pandas":
-            return PandasTSWriterBackend(**bk_args)
+            return PandasCSVWriterBackend(**bk_args)
         else:
             raise ValueError("Unsupported backend type")
 
     def get_type(self):
         """get type"""
-        return WriterType.TS
+        return WriterType.CSV
 
 
-class _TSWriterBackend(_BaseWriterBackend):
-    """_TSWriterBackend"""
+class _CSVWriterBackend(_BaseWriterBackend):
+    """_CSVWriterBackend"""
 
     pass
 
 
-class PandasTSWriterBackend(_TSWriterBackend):
+class PandasCSVWriterBackend(_CSVWriterBackend):
     """PILImageWriterBackend"""
 
     def __init__(self):

+ 2 - 3
paddlex/paddlex_cli.py

@@ -97,9 +97,8 @@ def pipeline_predict(pipeline, input, device=None, save_dir=None):
     result = pipeline(input, device=device)
     for res in result:
         res.print(json_format=False)
-        # TODO(gaotingquan): support to save all
-        # if save_dir:
-        #     i["result"].save()
+        if save_dir:
+            res.save_all(save_path=save_dir)
 
 
 # for CLI

+ 6 - 5
paddlex/utils/func_register.py

@@ -22,18 +22,19 @@ class FuncRegister(object):
         assert isinstance(register_map, dict)
         self._register_map = register_map
 
-    def __call__(self, key):
+    def __call__(self, key=None):
         """register the decoratored func as key in dict"""
 
         def decorator(func):
-            self._register_map[key] = func
+            actual_key = key if key is not None else func.__name__
+            self._register_map[actual_key] = func
             logging.debug(
-                f"The func ({func.__name__}) has been registered as key ({key})."
+                f"The func ({func.__name__}) has been registered as key ({actual_key})."
             )
 
             @wraps(func)
-            def wrapper(self, *args, **kwargs):
-                return func(self, *args, **kwargs)
+            def wrapper(*args, **kwargs):
+                return func(*args, **kwargs)
 
             return wrapper