zhangyubo0722 1 year ago
parent
commit
80b872a921

+ 6 - 1
paddlex/inference/components/task_related/__init__.py

@@ -20,7 +20,12 @@ from .text_det import (
     SortBoxes,
     CropByPolys,
 )
-from .text_rec import OCRReisizeNormImg, CTCLabelDecode
+from .text_rec import (
+    OCRReisizeNormImg,
+    LaTeXOCRReisizeNormImg,
+    CTCLabelDecode,
+    LaTeXOCRDecode,
+)
 from .table_rec import TableLabelDecode
 from .det import DetPostProcess, CropByBoxes, DetPad, WarpAffine
 from .instance_seg import InstanceSegPostProcess

+ 158 - 238
paddlex/inference/components/task_related/text_rec.py

@@ -31,9 +31,9 @@ from ..base import BaseComponent
 
 __all__ = [
     "OCRReisizeNormImg",
-    # "LaTeXOCRReisizeNormImg",
+    "LaTeXOCRReisizeNormImg",
     "CTCLabelDecode",
-    # "LaTeXOCRDecode",
+    "LaTeXOCRDecode",
 ]
 
 
@@ -81,111 +81,106 @@ class OCRReisizeNormImg(BaseComponent):
         return {"img": img}
 
 
-# class LaTeXOCRReisizeNormImg(BaseComponent):
-#     """for ocr image resize and normalization"""
-
-#     def __init__(self, rec_image_shape=[3, 48, 320]):
-#         super().__init__()
-#         self.rec_image_shape = rec_image_shape
-
-#     def pad_(self, img, divable=32):
-#         threshold = 128
-#         data = np.array(img.convert("LA"))
-#         if data[..., -1].var() == 0:
-#             data = (data[..., 0]).astype(np.uint8)
-#         else:
-#             data = (255 - data[..., -1]).astype(np.uint8)
-#         data = (data - data.min()) / (data.max() - data.min()) * 255
-#         if data.mean() > threshold:
-#             # To invert the text to white
-#             gray = 255 * (data < threshold).astype(np.uint8)
-#         else:
-#             gray = 255 * (data > threshold).astype(np.uint8)
-#             data = 255 - data
-
-#         coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
-#         a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
-#         rect = data[b : b + h, a : a + w]
-#         im = Image.fromarray(rect).convert("L")
-#         dims = []
-#         for x in [w, h]:
-#             div, mod = divmod(x, divable)
-#             dims.append(divable * (div + (1 if mod > 0 else 0)))
-#         padded = Image.new("L", dims, 255)
-#         padded.paste(im, (0, 0, im.size[0], im.size[1]))
-#         return padded
-
-#     def minmax_size_(
-#         self,
-#         img,
-#         max_dimensions,
-#         min_dimensions,
-#     ):
-#         if max_dimensions is not None:
-#             ratios = [a / b for a, b in zip(img.size, max_dimensions)]
-#             if any([r > 1 for r in ratios]):
-#                 size = np.array(img.size) // max(ratios)
-#                 img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
-#         if min_dimensions is not None:
-#             # hypothesis: there is a dim in img smaller than min_dimensions, and return a proper dim >= min_dimensions
-#             padded_size = [
-#                 max(img_dim, min_dim)
-#                 for img_dim, min_dim in zip(img.size, min_dimensions)
-#             ]
-#             if padded_size != list(img.size):  # assert hypothesis
-#                 padded_im = Image.new("L", padded_size, 255)
-#                 padded_im.paste(img, img.getbbox())
-#                 img = padded_im
-#         return img
-
-#     def norm_img_latexocr(self, img):
-#         # CAN only predict gray scale image
-#         shape = (1, 1, 3)
-#         mean = [0.7931, 0.7931, 0.7931]
-#         std = [0.1738, 0.1738, 0.1738]
-#         scale = np.float32(1.0 / 255.0)
-#         min_dimensions = [32, 32]
-#         max_dimensions = [672, 192]
-#         mean = np.array(mean).reshape(shape).astype("float32")
-#         std = np.array(std).reshape(shape).astype("float32")
-
-#         im_h, im_w = img.shape[:2]
-#         if (
-#             min_dimensions[0] <= im_w <= max_dimensions[0]
-#             and min_dimensions[1] <= im_h <= max_dimensions[1]
-#         ):
-#             pass
-#         else:
-#             img = Image.fromarray(np.uint8(img))
-#             img = self.minmax_size_(self.pad_(img), max_dimensions, min_dimensions)
-#             img = np.array(img)
-#             im_h, im_w = img.shape[:2]
-#             img = np.dstack([img, img, img])
-#         img = (img.astype("float32") * scale - mean) / std
-#         img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
-#         divide_h = math.ceil(im_h / 16) * 16
-#         divide_w = math.ceil(im_w / 16) * 16
-#         img = np.pad(
-#             img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
-#         )
-#         img = img[:, :, np.newaxis].transpose(2, 0, 1)
-#         img = img.astype("float32")
-#         return img
-
-#     def apply(self, data):
-#         """apply"""
-#         data[K.IMAGE] = self.norm_img_latexocr(data[K.IMAGE])
-#         return data
-
-#     @classmethod
-#     def get_input_keys(cls):
-#         """get input keys"""
-#         return [K.IMAGE, K.ORI_IM_SIZE]
-
-#     @classmethod
-#     def get_output_keys(cls):
-#         """get output keys"""
-#         return [K.IMAGE]
+class LaTeXOCRReisizeNormImg(BaseComponent):
+    """for ocr image resize and normalization"""
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = "img"
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img"}
+
+    def __init__(self, rec_image_shape=[3, 48, 320]):
+        super().__init__()
+        self.rec_image_shape = rec_image_shape
+
+    def pad_(self, img, divable=32):
+        threshold = 128
+        data = np.array(img.convert("LA"))
+        if data[..., -1].var() == 0:
+            data = (data[..., 0]).astype(np.uint8)
+        else:
+            data = (255 - data[..., -1]).astype(np.uint8)
+        data = (data - data.min()) / (data.max() - data.min()) * 255
+        if data.mean() > threshold:
+            # To invert the text to white
+            gray = 255 * (data < threshold).astype(np.uint8)
+        else:
+            gray = 255 * (data > threshold).astype(np.uint8)
+            data = 255 - data
+
+        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
+        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
+        rect = data[b : b + h, a : a + w]
+        im = Image.fromarray(rect).convert("L")
+        dims = []
+        for x in [w, h]:
+            div, mod = divmod(x, divable)
+            dims.append(divable * (div + (1 if mod > 0 else 0)))
+        padded = Image.new("L", dims, 255)
+        padded.paste(im, (0, 0, im.size[0], im.size[1]))
+        return padded
+
+    def minmax_size_(
+        self,
+        img,
+        max_dimensions,
+        min_dimensions,
+    ):
+        if max_dimensions is not None:
+            ratios = [a / b for a, b in zip(img.size, max_dimensions)]
+            if any([r > 1 for r in ratios]):
+                size = np.array(img.size) // max(ratios)
+                img = img.resize(tuple(size.astype(int)), Image.BILINEAR)
+        if min_dimensions is not None:
+            # hypothesis: there is a dim in img smaller than min_dimensions, and return a proper dim >= min_dimensions
+            padded_size = [
+                max(img_dim, min_dim)
+                for img_dim, min_dim in zip(img.size, min_dimensions)
+            ]
+            if padded_size != list(img.size):  # assert hypothesis
+                padded_im = Image.new("L", padded_size, 255)
+                padded_im.paste(img, img.getbbox())
+                img = padded_im
+        return img
+
+    def norm_img_latexocr(self, img):
+        # the recognition model only accepts grayscale input
+        shape = (1, 1, 3)
+        mean = [0.7931, 0.7931, 0.7931]
+        std = [0.1738, 0.1738, 0.1738]
+        scale = np.float32(1.0 / 255.0)
+        min_dimensions = [32, 32]
+        max_dimensions = [672, 192]
+        mean = np.array(mean).reshape(shape).astype("float32")
+        std = np.array(std).reshape(shape).astype("float32")
+
+        im_h, im_w = img.shape[:2]
+        if (
+            min_dimensions[0] <= im_w <= max_dimensions[0]
+            and min_dimensions[1] <= im_h <= max_dimensions[1]
+        ):
+            pass
+        else:
+            img = Image.fromarray(np.uint8(img))
+            img = self.minmax_size_(self.pad_(img), max_dimensions, min_dimensions)
+            img = np.array(img)
+            im_h, im_w = img.shape[:2]
+            img = np.dstack([img, img, img])
+        img = (img.astype("float32") * scale - mean) / std
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        divide_h = math.ceil(im_h / 16) * 16
+        divide_w = math.ceil(im_w / 16) * 16
+        img = np.pad(
+            img, ((0, divide_h - im_h), (0, divide_w - im_w)), constant_values=(1, 1)
+        )
+        img = img[:, :, np.newaxis].transpose(2, 0, 1)
+        img = img.astype("float32")
+        return img
+
+    def apply(self, img):
+        """apply"""
+        img = self.norm_img_latexocr(img)
+        return {"img": img}
 
 
 class BaseRecLabelDecode(BaseComponent):
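As a sanity check on the geometry above (`pad_` rounds the cropped text box up to multiples of `divable`, `minmax_size_` clamps to 672x192, and `norm_img_latexocr` pads the final array to multiples of 16), here is a minimal standalone sketch of the round-up arithmetic; the helper name `round_up` is ours, not part of the diff:

import math

def round_up(x, div):
    # Same arithmetic as pad_: smallest multiple of `div` that is >= x.
    d, m = divmod(x, div)
    return div * (d + (1 if m > 0 else 0))

assert round_up(100, 32) == 128          # pad_ canvas for a 100-px-wide crop
assert round_up(64, 32) == 64            # already aligned, left unchanged
assert math.ceil(100 / 16) * 16 == 112   # norm_img_latexocr's divide-by-16 pad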
@@ -301,134 +296,59 @@ class CTCLabelDecode(BaseRecLabelDecode):
         return character_list
 
 
-# class LaTeXOCRDecode(object):
-#     """Convert between latex-symbol and symbol-index"""
-
-#     def __init__(self, post_process_cfg=None, **kwargs):
-#         assert post_process_cfg["name"] == "LaTeXOCRDecode"
-
-#         super(LaTeXOCRDecode, self).__init__()
-#         character_list = post_process_cfg["character_dict"]
-#         temp_path = tempfile.gettempdir()
-#         rec_char_dict_path = os.path.join(temp_path, "latexocr_tokenizer.json")
-#         try:
-#             with open(rec_char_dict_path, "w") as f:
-#                 json.dump(character_list, f)
-#         except Exception as e:
-#             print(f"创建 latexocr_tokenizer.json 文件失败, 原因{str(e)}")
-#         self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)
-
-#     def post_process(self, s):
-#         text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
-#         letter = "[a-zA-Z]"
-#         noletter = "[\W_^\d]"
-#         names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
-#         s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
-#         news = s
-#         while True:
-#             s = news
-#             news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
-#             news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
-#             news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
-#             if news == s:
-#                 break
-#         return s
-
-#     def decode(self, tokens):
-#         if len(tokens.shape) == 1:
-#             tokens = tokens[None, :]
-
-#         dec = [self.tokenizer.decode(tok) for tok in tokens]
-#         dec_str_list = [
-#             "".join(detok.split(" "))
-#             .replace("Ġ", " ")
-#             .replace("[EOS]", "")
-#             .replace("[BOS]", "")
-#             .replace("[PAD]", "")
-#             .strip()
-#             for detok in dec
-#         ]
-#         return [str(self.post_process(dec_str)) for dec_str in dec_str_list]
-
-#     def __call__(self, data):
-#         preds = data[K.REC_PROBS]
-#         text = self.decode(preds)
-#         data[K.REC_TEXT] = text[0]
-#         return data
-
-
-# class SaveTextRecResults(BaseComponent):
-#     """SaveTextRecResults"""
-
-#     _TEXT_REC_RES_SUFFIX = "_text_rec"
-#     _FILE_EXT = ".txt"
-
-#     def __init__(self, save_dir):
-#         super().__init__()
-#         self.save_dir = save_dir
-#         # We use python backend to save text object
-#         self._writer = TextWriter(backend="python")
-
-#     def apply(self, data):
-#         """apply"""
-#         ori_path = data[K.IM_PATH]
-#         file_name = os.path.basename(ori_path)
-#         file_name = self._replace_ext(file_name, self._FILE_EXT)
-#         text_rec_res_save_path = os.path.join(self.save_dir, file_name)
-#         rec_res = ""
-#         for text, score in zip(data[K.REC_TEXT], data[K.REC_SCORE]):
-#             line = text + "\t" + str(score) + "\n"
-#             rec_res += line
-#         text_rec_res_save_path = self._add_suffix(
-#             text_rec_res_save_path, self._TEXT_REC_RES_SUFFIX
-#         )
-#         self._write_txt(text_rec_res_save_path, rec_res)
-#         return data
-
-#     @classmethod
-#     def get_input_keys(cls):
-#         """get_input_keys"""
-#         return [K.IM_PATH, K.REC_TEXT, K.REC_SCORE]
-
-#     @classmethod
-#     def get_output_keys(cls):
-#         """get_output_keys"""
-#         return []
-
-#     def _write_txt(self, path, txt_str):
-#         """_write_txt"""
-#         if os.path.exists(path):
-#             logging.warning(f"{path} already exists. Overwriting it.")
-#         self._writer.write(path, txt_str)
-
-#     @staticmethod
-#     def _add_suffix(path, suffix):
-#         """_add_suffix"""
-#         stem, ext = os.path.splitext(path)
-#         return stem + suffix + ext
-
-#     @staticmethod
-#     def _replace_ext(path, new_ext):
-#         """_replace_ext"""
-#         stem, _ = os.path.splitext(path)
-#         return stem + new_ext
-
-
-# class PrintResult(BaseComponent):
-#     """Print Result Transform"""
-
-#     def apply(self, data):
-#         """apply"""
-#         logging.info("The prediction result is:")
-#         logging.info(data[K.REC_TEXT])
-#         return data
-
-#     @classmethod
-#     def get_input_keys(cls):
-#         """get input keys"""
-#         return [K.REC_TEXT]
-
-#     @classmethod
-#     def get_output_keys(cls):
-#         """get output keys"""
-#         return []
+class LaTeXOCRDecode(BaseComponent):
+    """Convert between latex-symbol and symbol-index"""
+
+    INPUT_KEYS = ["pred"]
+    OUTPUT_KEYS = ["rec_text"]
+    DEAULT_INPUTS = {"pred": "pred"}
+    DEAULT_OUTPUTS = {"rec_text": "rec_text"}
+
+    def __init__(self, character_list=None):
+        super().__init__()
+        temp_path = tempfile.gettempdir()
+        rec_char_dict_path = os.path.join(temp_path, "latexocr_tokenizer.json")
+        try:
+            with open(rec_char_dict_path, "w") as f:
+                json.dump(character_list, f)
+        except Exception as e:
+            print(f"创建 latexocr_tokenizer.json 文件失败, 原因{str(e)}")
+        self.tokenizer = TokenizerFast.from_file(rec_char_dict_path)
+
+    def post_process(self, s):
+        text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
+        letter = "[a-zA-Z]"
+        noletter = r"[\W_^\d]"
+        names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
+        s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
+        news = s
+        while True:
+            s = news
+            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
+            news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
+            news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
+            if news == s:
+                break
+        return s
+
+    def decode(self, tokens):
+        if len(tokens.shape) == 1:
+            tokens = tokens[None, :]
+
+        dec = [self.tokenizer.decode(tok) for tok in tokens]
+        dec_str_list = [
+            "".join(detok.split(" "))
+            .replace("Ġ", " ")
+            .replace("[EOS]", "")
+            .replace("[BOS]", "")
+            .replace("[PAD]", "")
+            .strip()
+            for detok in dec
+        ]
+        return [str(self.post_process(dec_str)) for dec_str in dec_str_list]
+
+    def apply(self, pred):
+        preds = np.array(pred[0])
+        text = self.decode(preds)
+        return {"rec_text": text[0]}

+ 1 - 0
paddlex/inference/models/__init__.py

@@ -31,6 +31,7 @@ from .ts_cls import TSClsPredictor
 from .image_unwarping import WarpPredictor
 from .multilabel_classification import MLClasPredictor
 from .anomaly_detection import UadPredictor
+from .formula_recognition import LaTeXOCRPredictor
 
 
 def _create_hp_predictor(

+ 55 - 0
paddlex/inference/models/formula_recognition.py

@@ -0,0 +1,55 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...modules.formula_recognition.model_list import MODELS
+from ..components import *
+from ..results import TextRecResult
+from .base import BasicPredictor
+
+
+class LaTeXOCRPredictor(BasicPredictor):
+
+    entities = MODELS
+
+    def _build_components(self):
+        self._add_component(
+            [
+                ReadImage(format="RGB"),
+                LaTeXOCRReisizeNormImg(),
+            ]
+        )
+
+        predictor = ImagePredictor(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+        self._add_component(predictor)
+
+        op = self.build_postprocess(**self.config["PostProcess"])
+        self._add_component(op)
+
+    def build_postprocess(self, **kwargs):
+        if kwargs.get("name") == "LaTeXOCRDecode":
+            return LaTeXOCRDecode(
+                character_list=kwargs.get("character_dict"),
+            )
+        else:
+            raise Exception(f"Unsupported post-process op: {kwargs.get('name')}")
+
+    def _pack_res(self, single):
+        keys = ["img_path", "rec_text"]
+        return TextRecResult({key: single[key] for key in keys})
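For context, a minimal usage sketch of the new predictor. It assumes the public `create_model` factory dispatches "LaTeX_OCR_rec" to `LaTeXOCRPredictor` through the `entities` registry, and that the returned `TextRecResult` is dict-like; the input filename is a placeholder:

from paddlex import create_model

# Assumes create_model resolves "LaTeX_OCR_rec" (registered in
# formula_recognition/model_list.py below) to LaTeXOCRPredictor.
model = create_model("LaTeX_OCR_rec")

for res in model.predict("formula_demo.png"):  # placeholder input image
    print(res["rec_text"])  # decoded LaTeX string, e.g. \frac{a}{b}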

+ 17 - 0
paddlex/modules/formula_recognition/model_list.py

@@ -0,0 +1,17 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+MODELS = [
+    "LaTeX_OCR_rec",
+]

+ 0 - 1
paddlex/modules/text_recognition/model_list.py

@@ -17,5 +17,4 @@ MODELS = [
     "PP-OCRv4_server_rec",
     "ch_SVTRv2_rec",
     "ch_RepSVTR_rec",
-    "LaTeX_OCR_rec",
 ]

+ 2 - 1
paddlex/pipelines/OCR.yaml

@@ -8,5 +8,6 @@ Global:
 Pipeline:
   det_model: PP-OCRv4_mobile_det
   rec_model: PP-OCRv4_mobile_rec
-
+  batch_size: 1
+  device: "gpu"
 ######################################## Support ########################################
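This and the pipeline configs below all gain explicit `batch_size` and `device` keys. A hedged sketch of driving the OCR pipeline from Python, assuming this branch exposes a `create_pipeline` factory that resolves the pipeline name to its YAML and honors these keys; the input image is a placeholder:

from paddlex import create_pipeline

# Assumes "OCR" resolves to paddlex/pipelines/OCR.yaml, whose Pipeline
# section now pins batch_size: 1 and device: "gpu".
pipeline = create_pipeline(pipeline="OCR")

for res in pipeline.predict("ocr_demo.png"):  # placeholder input image
    res.print()  # assumes the result object exposes a print() helper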

+ 11 - 0
paddlex/pipelines/anomaly_detection.yaml

@@ -0,0 +1,11 @@
+Global:
+  pipeline_name: anomaly_detection
+  input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/uad_grid.png
+  
+######################################## Setting ########################################
+# Please select the model from the `Support` section below
+
+Pipeline:
+  model: STFPM
+  batch_size: 1
+  device: "gpu"

+ 2 - 2
paddlex/pipelines/image_classification.yaml

@@ -7,8 +7,8 @@ Global:
 
 Pipeline:
   model: PP-LCNet_x0_5
-  device: cpu
-
+  batch_size: 1
+  device: "gpu:0"
 ######################################## Support ########################################
 NOTE:
   device: 

+ 2 - 1
paddlex/pipelines/instance_segmentation.yaml

@@ -7,7 +7,8 @@ Global:
 
 Pipeline:
   model: Mask-RT-DETR-S
-
+  batch_size: 1
+  device: "gpu"
 ######################################## Support ########################################
 NOTE:
   device: 

+ 2 - 1
paddlex/pipelines/object_detection.yaml

@@ -7,7 +7,8 @@ Global:
 
 Pipeline:
   model: PicoDet-S
-
+  batch_size: 1
+  device: "gpu"
 ######################################## Support ########################################
 NOTE:
   device: 

+ 2 - 1
paddlex/pipelines/semantic_segmentation.yaml

@@ -7,7 +7,8 @@ Global:
 
 Pipeline:
   model: PP-LiteSeg-T
-
+  batch_size: 1
+  device: "gpu"
 ######################################## Support ########################################
 NOTE:
   device: 

+ 11 - 0
paddlex/pipelines/small_object_detection.yaml

@@ -0,0 +1,11 @@
+Global:
+  pipeline_name: small_object_detection
+  input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/small_object_detection.jpg
+  
+######################################## Setting ########################################
+# Please select the model from the `Support` section below
+
+Pipeline:
+  model: PP-YOLOE_plus_SOD-L
+  batch_size: 1
+  device: "gpu"

+ 9 - 1
paddlex/repo_apis/PaddleClas_api/cls/runner.py

@@ -195,8 +195,16 @@ def _extract_eval_metrics(stdout: str) -> dict:
         r"\[Eval\]\[Epoch 0\]\[Avg\].*recall1: (_dp), recall5: (_dp), mAP: (_dp)".replace(
             "_dp", _DP
         ),
+        r"\[Eval\]\[Epoch 0\]\[Avg\].*MultiLabelMAP\(integral\): (_dp)".replace(
+            "_dp", _DP
+        ),
+    ]
+    keys = [
+        ["val.top1"],
+        ["val.top1", "val.top5"],
+        ["recall1", "recall5", "mAP"],
+        ["MultiLabelMAP"],
     ]
-    keys = [["val.top1"], ["val.top1", "val.top5"], ["recall1", "recall5", "mAP"]]
 
     metric_dict = dict()
     for pattern, key in zip(patterns, keys):
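A quick check that the new multilabel pattern matches the log format it targets. The sketch assumes `_DP` is a decimal-number regex (its real definition lives elsewhere in runner.py and is not shown in this hunk), and the log line is hypothetical:

import re

_DP = r"[-+]?[0-9]*\.?[0-9]+"  # assumed stand-in for the real _DP pattern

pattern = r"\[Eval\]\[Epoch 0\]\[Avg\].*MultiLabelMAP\(integral\): (_dp)".replace(
    "_dp", _DP
)
keys = ["MultiLabelMAP"]

# Hypothetical PaddleClas eval log line in the targeted format.
line = "[Eval][Epoch 0][Avg]CELoss: 0.102, MultiLabelMAP(integral): 0.912"

match = re.search(pattern, line)
metric_dict = {k: float(v) for k, v in zip(keys, match.groups())}
print(metric_dict)  # {'MultiLabelMAP': 0.912}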