
fix bug and align with fastdeploy

zhouchangda 1 year ago
Parent
Current commit
b975c82667

+ 22 - 18
paddlex/inference/components/llm/erniebot.py

@@ -164,24 +164,28 @@ class ErnieBot(BaseLLM):
 
     def caculate_similar(self, vector, key_list, llm_params=None, sleep_time=0.5):
         """caculate similar with key and doc"""
-        if self.is_vector_store(vector):
-            # XXX: The initialization parameters are hard-coded.
-            if llm_params:
-                api_type = llm_params.get("api_type")
-                access_token = llm_params.get("access_token")
-                ak = llm_params.get("ak")
-                sk = llm_params.get("sk")
-            else:
-                api_type = self.config["api_type"]
-                access_token = self.config.get("access_token")
-                ak = self.config.get("ak")
-                sk = self.config.get("sk")
-            if api_type == "aistudio":
-                embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
-            elif api_type == "qianfan":
-                embeddings = QianfanEmbeddingsEndpoint(qianfan_ak=ak, qianfan_sk=sk)
-            else:
-                raise ValueError(f"Unsupported api_type: {api_type}")
+        if not self.is_vector_store(vector):
+            logging.warning(
+                "The retrieved vector store was not generated by PaddleX; returning it directly"
+            )
+            return vector
+        # XXX: The initialization parameters are hard-coded.
+        if llm_params:
+            api_type = llm_params.get("api_type")
+            access_token = llm_params.get("access_token")
+            ak = llm_params.get("ak")
+            sk = llm_params.get("sk")
+        else:
+            api_type = self.config["api_type"]
+            access_token = self.config.get("access_token")
+            ak = self.config.get("ak")
+            sk = self.config.get("sk")
+        if api_type == "aistudio":
+            embeddings = ErnieEmbeddings(aistudio_access_token=access_token)
+        elif api_type == "qianfan":
+            embeddings = QianfanEmbeddingsEndpoint(qianfan_ak=ak, qianfan_sk=sk)
+        else:
+            raise ValueError(f"Unsupported api_type: {api_type}")
 
         vectorstore = vectorstores.FAISS.deserialize_from_bytes(
             self.decode_vector_store(vector), embeddings
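
The hunk above flattens a nested happy path into an early-return guard. A minimal, self-contained sketch of the same pattern (the function and parameter names are illustrative, not the PaddleX API):

import logging

def deserialize_vector_store(vector, is_vector_store, deserialize):
    # Guard clause: warn and bail out early instead of nesting the
    # happy path one level deeper.
    if not is_vector_store(vector):
        logging.warning(
            "The retrieved vector store was not generated by PaddleX; "
            "returning it directly"
        )
        return vector
    # The main logic now sits at a single indentation level.
    return deserialize(vector)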

+ 2 - 2
paddlex/inference/components/task_related/clas.py

@@ -61,7 +61,7 @@ class Topk(BaseComponent):
                 label_name_list.append(self.class_id_map[i.item()])
         result = {
             "class_ids": clas_id_list,
-            "scores": np.around(score_list, decimals=5).tolist(),
+            "scores": np.around(score_list, decimals=5),
         }
         if label_name_list is not None:
             result["label_names"] = label_name_list
@@ -102,7 +102,7 @@ class MultiLabelThreshOutput(BaseComponent):
                 label_name_list.append(self.class_id_map[i.item()])
         result = {
             "class_ids": clas_id_list,
-            "scores": np.around(score_list, decimals=5).tolist(),
+            "scores": np.around(score_list, decimals=5),
         }
         if label_name_list is not None:
             result["label_names"] = label_name_list

+ 2 - 2
paddlex/inference/components/task_related/table_rec.py

@@ -129,7 +129,7 @@ class TableLabelDecode(BaseComponent):
                     bbox = self._bbox_decode(
                         bbox, padding_size[batch_idx], ori_img_size[batch_idx]
                     )
-                    bbox_list.append(bbox.tolist())
+                    bbox_list.append(bbox.astype(int))
                 structure_list.append(text)
                 score_list.append(structure_probs[batch_idx, idx])
             structure_batch_list.append(structure_list)
@@ -163,7 +163,7 @@ class TableLabelDecode(BaseComponent):
                 bbox = gt_bbox_list[batch_idx][idx]
                 if bbox.sum() != 0:
                     bbox = self._bbox_decode(bbox, shape_list[batch_idx])
-                    bbox_list.append(bbox.tolist())
+                    bbox_list.append(bbox.astype(int))
             structure_batch_list.append(structure_list)
             bbox_batch_list.append(bbox_list)
         return bbox_batch_list, structure_batch_list

+ 7 - 4
paddlex/inference/components/task_related/text_det.py

@@ -282,7 +282,7 @@ class DBPostProcess(BaseComponent):
             box[:, 1] = np.clip(
                 np.round(box[:, 1] / height * dest_height), 0, dest_height
             )
-            boxes.append(box.tolist())
+            boxes.append(box)
             scores.append(score)
         return boxes, scores
 
@@ -337,7 +337,10 @@ class DBPostProcess(BaseComponent):
         distance = poly.area * unclip_ratio / poly.length
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
-        expanded = np.array(offset.Execute(distance))
+        try:
+            expanded = np.array(offset.Execute(distance))
+        except ValueError:
+            expanded = np.array(offset.Execute(distance)[0])
         return expanded
 
     def get_mini_boxes(self, contour):
@@ -854,7 +857,7 @@ class CropByPolys(BaseComponent):
         if len(img.shape) == 2:
             img = np.stack((img,) * 3, axis=-1)
         img_crop, image = rectifier.run(img, new_points_list, mode="homography")
-        return img_crop[0]
+        return np.array(img_crop[0], dtype=np.uint8)
 
 
 class SortBoxes(BaseComponent):
@@ -889,4 +892,4 @@ class SortBoxes(BaseComponent):
                     _boxes[j + 1] = tmp
                 else:
                     break
-        return {"dt_polys": [box.tolist() for box in _boxes]}
+        return {"dt_polys": _boxes}

+ 5 - 0
paddlex/inference/pipelines/base.py

@@ -28,6 +28,11 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
         super().__init__()
         self._predictor_kwargs = {} if predictor_kwargs is None else predictor_kwargs
 
+    def set_predictor(self, *args, **kwargs):
+        raise NotImplementedError(
+            "The method `set_predictor` has not been implemented yet."
+        )
+
     # alias the __call__() to predict()
     def __call__(self, *args, **kwargs):
         yield from self.predict(*args, **kwargs)
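
The new set_predictor stub gives every pipeline a uniform entry point that fails loudly until a subclass overrides it. A minimal sketch of the contract, with illustrative class names and print statements standing in for the real predictor calls:

class BasePipelineSketch:
    """Illustrative stand-in for paddlex BasePipeline."""

    def set_predictor(self, *args, **kwargs):
        raise NotImplementedError(
            "The method `set_predictor` has not been implemented yet."
        )


class TableRecSketch(BasePipelineSketch):
    def set_predictor(self, layout_batch_size=None, table_batch_size=None):
        # Forward only the knobs the caller actually set; a None value
        # leaves the corresponding model untouched.
        if layout_batch_size:
            print(f"layout batch_size -> {layout_batch_size}")
        if table_batch_size:
            print(f"table batch_size -> {table_batch_size}")


TableRecSketch().set_predictor(layout_batch_size=4)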

+ 67 - 41
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -21,12 +21,7 @@ from copy import deepcopy
 from ...components import *
 from ..ocr import OCRPipeline
 from ....utils import logging
-from ...results import (
-    TableResult,
-    LayoutStructureResult,
-    VisualInfoResult,
-    ChatOCRResult,
-)
+from ...results import *
 from ...components.llm import ErnieBot
 from ...utils.io import ImageReader, PDFReader
 from ..table_recognition import TableRecPipeline
@@ -119,9 +114,6 @@ class PPChatOCRPipeline(TableRecPipeline):
         self.img_reader = ReadImage()
         self.visual_info = None
         self.vector = None
-        self._set_predictor(
-            oricls_batch_size, uvdoc_batch_size, curve_batch_size, device=device
-        )
 
     def _build_predictor(self):
         super()._build_predictor()
@@ -156,9 +148,29 @@ class PPChatOCRPipeline(TableRecPipeline):
                 batch_size=self.uvdoc_batch_size, device=self.device
             )
 
-    def _set_predictor(
-        self, curve_batch_size, oricls_batch_size, uvdoc_batch_size, device
+    def set_predictor(
+        self,
+        layout_batch_size=None,
+        text_det_batch_size=None,
+        text_rec_batch_size=None,
+        table_batch_size=None,
+        curve_batch_size=None,
+        oricls_batch_size=None,
+        uvdoc_batch_size=None,
+        device=None,
     ):
+        if text_det_batch_size and text_det_batch_size > 1:
+            logging.warning(
+                f"text det model only support batch_size=1 now,the setting of text_det_batch_size={text_det_batch_size} will not using! "
+            )
+        if layout_batch_size:
+            self.layout_predictor.set_predictor(batch_size=layout_batch_size)
+        if text_rec_batch_size:
+            self.ocr_pipeline.text_rec_model.set_predictor(
+                batch_size=text_rec_batch_size
+            )
+        if table_batch_size:
+            self.table_predictor.set_predictor(batch_size=table_batch_size)
         if self.curve_pipeline and curve_batch_size:
             self.curve_pipeline.text_det_model.set_predictor(
                 batch_size=curve_batch_size, device=device
@@ -189,12 +201,23 @@ class PPChatOCRPipeline(TableRecPipeline):
         return visual_result, visual_info
 
     def get_visual_result(self, inputs, **kwargs):
+        layout_batch_size = kwargs.get("layout_batch_size")
+        text_det_batch_size = kwargs.get("text_det_batch_size")
+        text_rec_batch_size = kwargs.get("text_rec_batch_size")
+        table_batch_size = kwargs.get("table_batch_size")
         curve_batch_size = kwargs.get("curve_batch_size")
         oricls_batch_size = kwargs.get("oricls_batch_size")
         uvdoc_batch_size = kwargs.get("uvdoc_batch_size")
         device = kwargs.get("device")
-        self._set_predictor(
-            curve_batch_size, oricls_batch_size, uvdoc_batch_size, device
+        self.set_predictor(
+            layout_batch_size,
+            text_det_batch_size,
+            text_rec_batch_size,
+            table_batch_size,
+            curve_batch_size,
+            oricls_batch_size,
+            uvdoc_batch_size,
+            device,
         )
         # get oricls and uvdoc results
         img_info_list = list(self.img_reader(inputs))[0]
@@ -229,13 +252,13 @@ class PPChatOCRPipeline(TableRecPipeline):
             single_img_res["input_path"] = layout_pred["input_path"]
             single_img_res["layout_result"] = layout_pred
             single_img = img_info["img"]
+            table_subs = []
+            curve_subs = []
+            structure_res = []
+            ocr_res_with_layout = []
             if len(layout_pred["boxes"]) > 0:
                 subs_of_img = list(self._crop_by_boxes(layout_pred))
-                # get cropped images with label "table"
-                table_subs = []
-                curve_subs = []
-                structure_res = []
-                ocr_res_with_layout = []
+                # get cropped images
                 for sub in subs_of_img:
                     box = sub["box"]
                     xmin, ymin, xmax, ymax = [int(i) for i in box]
@@ -284,7 +307,8 @@ class PPChatOCRPipeline(TableRecPipeline):
 
             all_curve_res = get_ocr_res(curve_pipeline, curve_subs)
             single_img_res["curve_result"] = all_curve_res
-
+            if isinstance(all_curve_res, dict):
+                all_curve_res = [all_curve_res]
             for sub, curve_res in zip(curve_subs, all_curve_res):
                 structure_res.append(
                     {
@@ -325,7 +349,7 @@ class PPChatOCRPipeline(TableRecPipeline):
             single_img_res["table_ocr_result"] = all_table_ocr_res
             single_img_res["structure_result"] = structure_res
 
-            yield ChatOCRResult(single_img_res)
+            yield VisualResult(single_img_res)
 
     def decode_visual_result(self, visual_result):
         ocr_text = []
@@ -375,7 +399,7 @@ class PPChatOCRPipeline(TableRecPipeline):
             logging.warning("Do not use ErnieBot, will not get vector text.")
             get_vector_flag = False
         if not any([visual_info, self.visual_info]):
-            return {"vector": None}
+            return VectorResult({"vector": None})
 
         if visual_info:
             # use for serving or local
@@ -406,7 +430,7 @@ class PPChatOCRPipeline(TableRecPipeline):
         else:
             text_result = str(ocr_text)
 
-        return {"vector": text_result}
+        return VectorResult({"vector": text_result})
 
     def get_retrieval_text(
         self,
@@ -419,7 +443,7 @@ class PPChatOCRPipeline(TableRecPipeline):
     ):
 
         if not any([visual_info, vector, self.visual_info, self.vector]):
-            return {"retrieval": None}
+            return RetrievalResult({"retrieval": None})
 
         key_list = format_key(key_list)
 
@@ -450,7 +474,7 @@ class PPChatOCRPipeline(TableRecPipeline):
                 vector=_vector, key_list=key_list, sleep_time=llm_request_interval
             )
 
-        return {"retrieval": retrieval}
+        return RetrievalResult({"retrieval": retrieval})
 
     def chat(
         self,
@@ -473,7 +497,9 @@ class PPChatOCRPipeline(TableRecPipeline):
         if not any(
             [vector, visual_info, retrieval_result, self.visual_info, self.vector]
         ):
-            return {"chat_res": "请先完成图像解析再开始再对话", "prompt": ""}
+            return ChatResult(
+                {"chat_res": "请先完成图像解析再开始再对话", "prompt": ""}
+            )
         key_list = format_key(key_list)
         # first get from table, then get from text in table, last get from all ocr
         if visual_info:
@@ -486,20 +512,6 @@ class PPChatOCRPipeline(TableRecPipeline):
         ocr_text = _visual_info["ocr_text"]
         html_list = _visual_info["table_html"]
         table_text_list = _visual_info["table_text"]
-        if retrieval_result:
-            ocr_text = retrieval_result
-        elif use_vector and any([visual_info, vector]):
-            # for serving or local
-            ocr_text = self.get_retrieval_text(
-                key_list=key_list,
-                visual_info=visual_info,
-                vector=vector,
-                llm_name=llm_name,
-                llm_params=llm_params,
-            )
-        else:
-            # for local
-            ocr_text = self.get_retrieval_text(key_list=key_list)
 
         prompt_res = {"ocr_prompt": "str", "table_prompt": [], "html_prompt": []}
 
@@ -530,6 +542,20 @@ class PPChatOCRPipeline(TableRecPipeline):
                         final_results[key] = value
         if len(key_list) > 0:
             logging.info("get result from ocr")
+            if retrieval_result:
+                ocr_text = retrieval_result
+            elif use_vector and any([visual_info, vector]):
+                # for serving or local
+                ocr_text = self.get_retrieval_text(
+                    key_list=key_list,
+                    visual_info=visual_info,
+                    vector=vector,
+                    llm_name=llm_name,
+                    llm_params=llm_params,
+                )
+            else:
+                # for local
+                ocr_text = self.get_retrieval_text(key_list=key_list)
             prompt = self.get_prompt_for_ocr(
                 ocr_text,
                 key_list,
@@ -545,9 +571,9 @@ class PPChatOCRPipeline(TableRecPipeline):
         if not res and not final_results:
             final_results = self.llm_api.ERROR_MASSAGE
         if save_prompt:
-            return {"chat_res": final_results, "prompt": prompt_res}
+            return ChatResult({"chat_res": final_results, "prompt": prompt_res})
         else:
-            return {"chat_res": final_results, "prompt": ""}
+            return ChatResult({"chat_res": final_results, "prompt": ""})
 
     def get_llm_result(self, prompt):
         """get llm result and decode to dict"""

+ 1 - 2
paddlex/inference/pipelines/ppchatocrv3/utils.py

@@ -51,8 +51,7 @@ def get_uvdoc_results(inputs, predictor):
     img_list = [img_info["img"] for img_info in inputs]
     for input, pred in zip(inputs, predictor(img_list)):
         results.append(pred)
-        img = np.array(pred["doctr_img"], dtype=np.uint8)
-        input["img"] = img
+        input["img"] = pred["doctr_img"]
     return results
 
 

+ 3 - 2
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -49,7 +49,6 @@ class TableRecPipeline(BasePipeline):
         self.predictor_kwargs = predictor_kwargs
         super().__init__(predictor_kwargs=predictor_kwargs)
         self._build_predictor()
-        # self.set_predictor(layout_batch_size, text_det_batch_size,text_rec_batch_size, table_batch_size)
 
     def _build_predictor(
         self,
@@ -85,7 +84,9 @@ class TableRecPipeline(BasePipeline):
         if layout_batch_size:
             self.layout_predictor.set_predictor(batch_size=layout_batch_size)
         if text_rec_batch_size:
-            self.ocr_pipeline.rec_model.set_predictor(batch_size=text_rec_batch_size)
+            self.ocr_pipeline.text_rec_model.set_predictor(
+                batch_size=text_rec_batch_size
+            )
         if table_batch_size:
             self.table_predictor.set_predictor(batch_size=table_batch_size)
 

+ 64 - 1
paddlex/inference/results/chat_ocr.py

@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from pathlib import Path
 from .base import BaseResult
+from .utils.mixin import Base64Mixin
 
 
 class LayoutStructureResult(BaseResult):
@@ -27,7 +29,68 @@ class VisualInfoResult(BaseResult):
     pass
 
 
-class ChatOCRResult(BaseResult):
+class VisualResult(BaseResult):
+    """VisualInfoResult"""
+
+    def save_to_html(self, save_path):
+        if not save_path.lower().endswith(("html")):
+            input_path = self["input_path"]
+            save_path = Path(save_path) / f"{Path(input_path).stem}"
+        else:
+            save_path = Path(save_path).stem
+        for table_result in self["table_result"]:
+            table_result.save_to_html(save_path)
+
+    def save_to_xlsx(self, save_path):
+        if not save_path.lower().endswith(("xlsx")):
+            input_path = self["input_path"]
+            save_path = Path(save_path) / f"{Path(input_path).stem}"
+        else:
+            save_path = Path(save_path).stem
+        for table_result in self["table_result"]:
+            table_result.save_to_xlsx(save_path)
+
+    def save_to_img(self, save_path):
+        if not save_path.lower().endswith((".jpg", ".png")):
+            input_path = self["input_path"]
+            save_path = Path(save_path) / f"{Path(input_path).stem}"
+        else:
+            save_path = Path(save_path).stem
+
+        oricls_save_path = f"{save_path}_oricls.jpg"
+        oricls_result = self["oricls_result"]
+        oricls_result.save_to_img(oricls_save_path)
+        uvdoc_save_path = f"{save_path}_uvdoc.jpg"
+        uvdoc_result = self["uvdoc_result"]
+        uvdoc_result.save_to_img(uvdoc_save_path)
+        curve_save_path = f"{save_path}_curve.jpg"
+        for curve_result in self["curve_result"]:
+            curve_result.save_to_img(curve_save_path)
+        layout_save_path = f"{save_path}_layout.jpg"
+        layout_result = self["layout_result"]
+        layout_result.save_to_img(layout_save_path)
+        ocr_save_path = f"{save_path}_ocr.jpg"
+        table_save_path = f"{save_path}_table.jpg"
+        ocr_result = self["ocr_result"]
+        ocr_result.save_to_img(ocr_save_path)
+        for table_result in self["table_result"]:
+            table_result.save_to_img(table_save_path)
+
+
+class VectorResult(BaseResult, Base64Mixin):
+    """VisualInfoResult"""
+
+    def _to_base64(self):
+        return self["vector"]
+
+
+class RetrievalResult(BaseResult):
+    """VisualInfoResult"""
+
+    pass
+
+
+class ChatResult(BaseResult):
     """VisualInfoResult"""
 
     pass

+ 23 - 0
paddlex/inference/results/utils/mixin.py

@@ -27,6 +27,7 @@ from ...utils.io import (
     CSVWriter,
     HtmlWriter,
     XlsxWriter,
+    TextWriter,
 )
 
 
@@ -105,6 +106,28 @@ class JsonMixin:
         )
 
 
+class Base64Mixin:
+    def __init__(self, *args, **kwargs):
+        self._base64_writer = TextWriter(*args, **kwargs)
+        self._show_func_register()(self.save_to_base64)
+
+    @abstractmethod
+    def _to_base64(self):
+        raise NotImplementedError
+
+    @property
+    def base64(self):
+        return self._to_base64()
+
+    def save_to_base64(self, save_path, *args, **kwargs):
+        if not str(save_path).lower().endswith((".b64")):
+            fp = Path(self["input_path"])
+            save_path = Path(save_path) / f"{fp.stem}{fp.suffix}"
+        _save_list_data(
+            self._base64_writer.write, save_path, self.base64, *args, **kwargs
+        )
+
+
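
Base64Mixin follows the same pattern as the other mixins in this file: a writer plus an abstract hook, where subclasses implement _to_base64 and inherit the base64 property and save_to_base64. A runnable sketch of the contract, with the writer plumbing stubbed out and an illustrative encoding:

import base64
from abc import abstractmethod


class Base64MixinSketch:
    """Stand-in for the mixin above; writer registration omitted."""

    @abstractmethod
    def _to_base64(self):
        raise NotImplementedError

    @property
    def base64(self):
        return self._to_base64()


class VectorResultSketch(dict, Base64MixinSketch):
    def _to_base64(self):
        # The real VectorResult returns self["vector"] directly; the
        # base64 encoding of its string form here is illustrative only.
        return base64.b64encode(str(self["vector"]).encode()).decode()


res = VectorResultSketch({"vector": [0.1, 0.2]})
print(res.base64)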
 class ImgMixin:
     def __init__(self, backend="pillow", *args, **kwargs):
         self._img_writer = ImageWriter(backend=backend, *args, **kwargs)

+ 4 - 4
paddlex/pipelines/PP-ChatOCRv3-doc.yaml

@@ -3,12 +3,12 @@ Global:
   input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/contract.pdf
   #input: https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/report.png
 Pipeline:
-  layout_model: PicoDet_layout_1x
+  layout_model: RT-DETR-H_layout_3cls
   table_model: SLANet_plus
-  text_det_model: PP-OCRv4_mobile_det
-  text_rec_model: PP-OCRv4_mobile_rec
+  text_det_model: PP-OCRv4_server_det
+  text_rec_model: PP-OCRv4_server_rec
   uvdoc_model: UVDoc
-  curve_model: PP-OCRv4_mobile_seal_det
+  curve_model: PP-OCRv4_server_seal_det
   oricls_model: PP-LCNet_x1_0_doc_ori
   llm_name: "ernie-3.5"
   llm_params:
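
The config now defaults to the heavier server-grade models (RT-DETR-H 3-class layout, PP-OCRv4 server det/rec, the server seal detector). Assuming PaddleX's standard create_pipeline entry point, the updated yaml is picked up by name; a hedged usage sketch:

from paddlex import create_pipeline

# Pipeline name matches the yaml above; model weights are resolved by
# PaddleX at load time.
pipeline = create_pipeline(pipeline="PP-ChatOCRv3-doc")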