فهرست منبع

support don't build models & bugfix

gaotingquan 1 سال پیش
والد
کامیت
1c2a6c8022
1 فایل تغییر یافته به همراه 54 افزوده شده و 48 حذف شده
  1. 54 48
      paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

+ 54 - 48
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -60,31 +60,33 @@ class PPChatOCRPipeline(_TableRecPipeline):
         recovery=True,
         device=None,
         predictor_kwargs=None,
+        _build_models=True,
     ):
         super().__init__(
             predictor_kwargs=predictor_kwargs,
         )
-        self._build_predictor(
-            layout_model=layout_model,
-            text_det_model=text_det_model,
-            text_rec_model=text_rec_model,
-            table_model=table_model,
-            doc_image_ori_cls_model=doc_image_ori_cls_model,
-            doc_image_unwarp_model=doc_image_unwarp_model,
-            seal_text_det_model=seal_text_det_model,
-            llm_name=llm_name,
-            llm_params=llm_params,
-        )
-        self.set_predictor(
-            layout_batch_size=layout_batch_size,
-            text_det_batch_size=text_det_batch_size,
-            text_rec_batch_size=text_rec_batch_size,
-            table_batch_size=table_batch_size,
-            doc_image_ori_cls_batch_size=doc_image_ori_cls_batch_size,
-            doc_image_unwarp_batch_size=doc_image_unwarp_batch_size,
-            seal_text_det_batch_size=seal_text_det_batch_size,
-            device=device,
-        )
+        if _build_models:
+            self._build_predictor(
+                layout_model=layout_model,
+                text_det_model=text_det_model,
+                text_rec_model=text_rec_model,
+                table_model=table_model,
+                doc_image_ori_cls_model=doc_image_ori_cls_model,
+                doc_image_unwarp_model=doc_image_unwarp_model,
+                seal_text_det_model=seal_text_det_model,
+                llm_name=llm_name,
+                llm_params=llm_params,
+            )
+            self.set_predictor(
+                layout_batch_size=layout_batch_size,
+                text_det_batch_size=text_det_batch_size,
+                text_rec_batch_size=text_rec_batch_size,
+                table_batch_size=table_batch_size,
+                doc_image_ori_cls_batch_size=doc_image_ori_cls_batch_size,
+                doc_image_unwarp_batch_size=doc_image_unwarp_batch_size,
+                seal_text_det_batch_size=seal_text_det_batch_size,
+                device=device,
+            )
 
         # get base prompt from yaml info
         if task_prompt_yaml:
@@ -127,13 +129,13 @@ class PPChatOCRPipeline(_TableRecPipeline):
         else:
             self.curve_pipeline = None
         if doc_image_ori_cls_model:
-            self.oricls_predictor = self._create(doc_image_ori_cls_model)
+            self.doc_image_ori_cls_predictor = self._create(doc_image_ori_cls_model)
         else:
-            self.oricls_predictor = None
+            self.doc_image_ori_cls_predictor = None
         if doc_image_unwarp_model:
-            self.uvdoc_predictor = self._create(doc_image_unwarp_model)
+            self.doc_image_unwarp_predictor = self._create(doc_image_unwarp_model)
         else:
-            self.uvdoc_predictor = None
+            self.doc_image_unwarp_predictor = None
 
         self.img_reader = ReadImage(format="RGB")
         self.llm_api = create_llm_api(
@@ -169,19 +171,23 @@ class PPChatOCRPipeline(_TableRecPipeline):
             self.curve_pipeline.text_det_model.set_predictor(
                 batch_size=seal_text_det_batch_size
             )
-        if self.oricls_predictor and doc_image_ori_cls_batch_size:
-            self.oricls_predictor.set_predictor(batch_size=doc_image_ori_cls_batch_size)
-        if self.uvdoc_predictor and doc_image_unwarp_batch_size:
-            self.uvdoc_predictor.set_predictor(batch_size=doc_image_unwarp_batch_size)
+        if self.doc_image_ori_cls_predictor and doc_image_ori_cls_batch_size:
+            self.doc_image_ori_cls_predictor.set_predictor(
+                batch_size=doc_image_ori_cls_batch_size
+            )
+        if self.doc_image_unwarp_predictor and doc_image_unwarp_batch_size:
+            self.doc_image_unwarp_predictor.set_predictor(
+                batch_size=doc_image_unwarp_batch_size
+            )
 
         if device:
             if self.curve_pipeline:
                 self.curve_pipeline.set_predictor(device=device)
-            if self.oricls_predictor:
-                self.oricls_predictor.set_predictor(device=device)
-            if self.uvdoc_predictor:
-                self.uvdoc_predictor.set_predictor(device=device)
-            self.layout_batch_size.set_predictor(device=device)
+            if self.doc_image_ori_cls_predictor:
+                self.doc_image_ori_cls_predictor.set_predictor(device=device)
+            if self.doc_image_unwarp_predictor:
+                self.doc_image_unwarp_predictor.set_predictor(device=device)
+            self.layout_predictor.set_predictor(device=device)
             self.ocr_pipeline.set_predictor(device=device)
 
     def predict(self, *args, **kwargs):
@@ -200,10 +206,6 @@ class PPChatOCRPipeline(_TableRecPipeline):
         **kwargs,
     ):
         self.set_predictor(**kwargs)
-        if self.uvdoc_predictor and uvdoc_batch_size:
-            self.uvdoc_predictor.set_predictor(
-                batch_size=uvdoc_batch_size, device=device
-            )
 
         visual_info = {"ocr_text": [], "table_html": [], "table_text": []}
         # get all visual result
@@ -237,14 +239,18 @@ class PPChatOCRPipeline(_TableRecPipeline):
         use_seal_text_det_model=True,
         recovery=True,
     ):
-        # get oricls and uvdoc results
+        # get oricls and unwarp results
         img_info_list = list(self.img_reader(inputs))[0]
         oricls_results = []
-        if self.oricls_predictor and use_doc_image_ori_cls_model:
-            oricls_results = get_oriclas_results(img_info_list, self.oricls_predictor)
-        uvdoc_results = []
-        if self.uvdoc_predictor and use_doc_image_unwarp_model:
-            uvdoc_results = get_uvdoc_results(img_info_list, self.uvdoc_predictor)
+        if self.doc_image_ori_cls_predictor and use_doc_image_ori_cls_model:
+            oricls_results = get_oriclas_results(
+                img_info_list, self.doc_image_ori_cls_predictor
+            )
+        unwarp_results = []
+        if self.doc_image_unwarp_predictor and use_doc_image_unwarp_model:
+            unwarp_results = get_unwarp_results(
+                img_info_list, self.doc_image_unwarp_predictor
+            )
         img_list = [img_info["img"] for img_info in img_info_list]
         for idx, (img_info, layout_pred) in enumerate(
             zip(img_info_list, self.layout_predictor(img_list))
@@ -257,14 +263,14 @@ class PPChatOCRPipeline(_TableRecPipeline):
                 "table_result": StructureTableResult([]),
                 "structure_result": [],
                 "oricls_result": TopkResult({}),
-                "uvdoc_result": DocTrResult({}),
+                "unwarp_result": DocTrResult({}),
                 "curve_result": [],
             }
-            # update oricls and uvdoc result
+            # update oricls and unwarp results
             if oricls_results:
                 single_img_res["oricls_result"] = oricls_results[idx]
-            if uvdoc_results:
-                single_img_res["uvdoc_result"] = uvdoc_results[idx]
+            if unwarp_results:
+                single_img_res["unwarp_result"] = unwarp_results[idx]
             # update layout result
             single_img_res["input_path"] = layout_pred["input_path"]
             single_img_res["layout_result"] = layout_pred