浏览代码

simplify & fix json format

gaotingquan 1 年之前
父节点
当前提交
c609c19e26

+ 5 - 7
paddlex/inference/results/base.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from types import MappingProxyType
+import inspect
 
 from ...utils.func_register import FuncRegister
 from ..utils.io import ImageReader, ImageWriter
@@ -22,15 +22,14 @@ from .utils.mixin import JsonMixin, ImgMixin, StrMixin
 class BaseResult(dict, StrMixin, JsonMixin):
     def __init__(self, data):
         super().__init__(data)
-        self._show_func_map = {}
-        self._show_func_register = FuncRegister(self._show_func_map)
+        self._show_funcs = []
         StrMixin.__init__(self)
         JsonMixin.__init__(self)
 
     def save_all(self, save_path):
-        for key in self._show_func_map:
-            func = self._show_func_map[key]
-            if "save" in key:
+        for func in self._show_funcs:
+            signature = inspect.signature(func)
+            if "save_path" in signature.parameters:
                 func(save_path=save_path)
             else:
                 func()
@@ -42,4 +41,3 @@ class CVResult(BaseResult, ImgMixin):
         ImgMixin.__init__(self, "pillow")
         self._img_reader = ImageReader(backend="pillow")
         self._img_writer = ImageWriter(backend="pillow")
-        self._show_func_register("save_to_img")(self.save_to_img)

+ 4 - 2
paddlex/inference/results/chat_ocr.py

@@ -33,8 +33,10 @@ class VisualInfoResult(BaseResult):
 class VisualResult(BaseResult):
     """VisualInfoResult"""
 
-    def _to_str(self):
-        return str({"layout_parsing_result": self["layout_parsing_result"]})
+    def _to_str(self, _, *args, **kwargs):
+        return super()._to_str(
+            {"layout_parsing_result": self["layout_parsing_result"]}, *args, **kwargs
+        )
 
     def save_to_html(self, save_path):
         if not save_path.lower().endswith(("html")):

+ 2 - 3
paddlex/inference/results/formula_rec.py

@@ -22,9 +22,8 @@ from .base import CVResult
 class FormulaRecResult(CVResult):
     _HARD_FLAG = False
 
-    def _to_str(self):
-        rec_formula_str = ", ".join([str(formula) for formula in self["rec_formula"]])
-        return str(self).replace("\\\\", "\\")
+    def _to_str(self, *args, **kwargs):
+        return super()._to_str(*args, **kwargs).replace("\\\\", "\\")
 
     def _to_img(
         self,

+ 4 - 4
paddlex/inference/results/instance_seg.py

@@ -146,7 +146,7 @@ class InstanceSegResult(CVResult):
 
         return image
 
-    def _to_str(self):
-        str_ = copy.deepcopy(self)
-        str_["masks"] = "..."
-        return str(str_)
+    def _to_str(self, _, *args, **kwargs):
+        data = copy.deepcopy(self)
+        data["masks"] = "..."
+        return super()._to_str(data, *args, **kwargs)

+ 4 - 4
paddlex/inference/results/seg.py

@@ -66,7 +66,7 @@ class SegResult(CVResult):
             color_map[: len(custom_color)] = custom_color
         return color_map
 
-    def _to_str(self):
-        str_ = copy.deepcopy(self)
-        str_["pred"] = "..."
-        return str(str_)
+    def _to_str(self, _, *args, **kwargs):
+        data = copy.deepcopy(self)
+        data["pred"] = "..."
+        return super()._to_str(data, *args, **kwargs)

+ 0 - 1
paddlex/inference/results/table_rec.py

@@ -65,7 +65,6 @@ class StructureTableResult(TableRecResult, HtmlMixin, XlsxMixin):
     def __init__(self, data):
         super().__init__(data)
         HtmlMixin.__init__(self)
-        self._show_func_register("save_to_html")(self.save_to_html)
         XlsxMixin.__init__(self)
 
     def _to_html(self):

+ 14 - 11
paddlex/inference/results/utils/mixin.py

@@ -54,20 +54,23 @@ class StrMixin:
     def str(self):
         return self._to_str()
 
-    def _to_str(self):
-        return str(self)
+    def _to_str(self, data, json_format=False, indent=4, ensure_ascii=False):
+        if json_format:
+            return json.dumps(data.json, indent=indent, ensure_ascii=ensure_ascii)
+        else:
+            return str(data)
 
     def print(self, json_format=False, indent=4, ensure_ascii=False):
-        str_ = self._to_str()
-        if json_format:
-            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
+        str_ = self._to_str(
+            self, json_format=json_format, indent=indent, ensure_ascii=ensure_ascii
+        )
         logging.info(str_)
 
 
 class JsonMixin:
     def __init__(self):
         self._json_writer = JsonWriter()
-        self._show_func_register()(self.save_to_json)
+        self._show_funcs.append(self.save_to_json)
 
     def _to_json(self):
         def _format_data(obj):
@@ -109,7 +112,7 @@ class JsonMixin:
 class Base64Mixin:
     def __init__(self, *args, **kwargs):
         self._base64_writer = TextWriter(*args, **kwargs)
-        self._show_func_register()(self.save_to_base64)
+        self._show_funcs.append(self.save_to_base64)
 
     @abstractmethod
     def _to_base64(self):
@@ -131,7 +134,7 @@ class Base64Mixin:
 class ImgMixin:
     def __init__(self, backend="pillow", *args, **kwargs):
         self._img_writer = ImageWriter(backend=backend, *args, **kwargs)
-        self._show_func_register()(self.save_to_img)
+        self._show_funcs.append(self.save_to_img)
 
     @abstractmethod
     def _to_img(self):
@@ -155,7 +158,7 @@ class ImgMixin:
 class CSVMixin:
     def __init__(self, backend="pandas", *args, **kwargs):
         self._csv_writer = CSVWriter(backend=backend, *args, **kwargs)
-        self._show_func_register()(self.save_to_csv)
+        self._show_funcs.append(self.save_to_csv)
 
     @abstractmethod
     def _to_csv(self):
@@ -172,7 +175,7 @@ class CSVMixin:
 class HtmlMixin:
     def __init__(self, *args, **kwargs):
         self._html_writer = HtmlWriter(*args, **kwargs)
-        self._show_func_register()(self.save_to_html)
+        self._show_funcs.append(self.save_to_html)
 
     @property
     def html(self):
@@ -190,7 +193,7 @@ class HtmlMixin:
 class XlsxMixin:
     def __init__(self, *args, **kwargs):
         self._xlsx_writer = XlsxWriter(*args, **kwargs)
-        self._show_func_register()(self.save_to_xlsx)
+        self._show_funcs.append(self.save_to_xlsx)
 
     def _to_xlsx(self):
         return self["html"]

+ 4 - 4
paddlex/inference/results/warp.py

@@ -25,7 +25,7 @@ class DocTrResult(CVResult):
     def _to_img(self):
         return np.array(self["doctr_img"])
 
-    def _to_str(self, json_format=True, indent=4, ensure_ascii=False):
-        str_ = copy.deepcopy(self)
-        str_.pop("doctr_img")
-        return str_
+    def _to_str(self, _, *args, **kwargs):
+        data = copy.deepcopy(self)
+        data.pop("doctr_img")
+        return super()._to_str(data, *args, **kwargs)