gaotingquan 1 yıl önce
ebeveyn
işleme
e6564bd8fc

+ 12 - 12
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -65,9 +65,9 @@ class PPChatOCRPipeline(TableRecPipeline):
         self.text_det_model = text_det_model
         self.text_rec_model = text_rec_model
         self.table_model = table_model
-        self.oricls_model = oricls_model
-        self.uvdoc_model = uvdoc_model
-        self.curve_model = curve_model
+        self.doc_image_ori_cls_model = doc_image_ori_cls_model
+        self.doc_image_unwarp_model = doc_image_unwarp_model
+        self.seal_text_det_model = seal_text_det_model
         self.llm_name = llm_name
         self.llm_params = llm_params
         self.task_prompt_yaml = task_prompt_yaml
@@ -118,9 +118,9 @@ class PPChatOCRPipeline(TableRecPipeline):
 
     def _build_predictor(self):
         super()._build_predictor()
-        if self.curve_model:
+        if self.seal_text_det_model:
             self.curve_pipeline = OCRPipeline(
-                text_det_model=self.curve_model,
+                text_det_model=self.seal_text_det_model,
                 text_rec_model=self.text_rec_model,
                 text_det_batch_size=self.text_det_batch_size,
                 text_rec_batch_size=self.text_rec_batch_size,
@@ -128,12 +128,12 @@ class PPChatOCRPipeline(TableRecPipeline):
             )
         else:
             self.curve_pipeline = None
-        if self.oricls_model:
-            self.oricls_predictor = self._create_model(self.oricls_model)
+        if self.doc_image_ori_cls_model:
+            self.oricls_predictor = self._create_model(self.doc_image_ori_cls_model)
         else:
             self.oricls_predictor = None
-        if self.uvdoc_model:
-            self.uvdoc_predictor = self._create_model(self.uvdoc_model)
+        if self.doc_image_unwarp_model:
+            self.uvdoc_predictor = self._create_model(self.doc_image_unwarp_model)
         else:
             self.uvdoc_predictor = None
         if self.curve_pipeline and self.curve_batch_size:
@@ -224,10 +224,10 @@ class PPChatOCRPipeline(TableRecPipeline):
         # get oricls and uvdoc results
         img_info_list = list(self.img_reader(inputs))[0]
         oricls_results = []
-        if self.oricls_predictor and kwargs.get("use_oricls_model", True):
+        if self.oricls_predictor and kwargs.get("use_doc_image_ori_cls_model", True):
             oricls_results = get_oriclas_results(img_info_list, self.oricls_predictor)
         uvdoc_results = []
-        if self.uvdoc_predictor and kwargs.get("use_uvdoc_model", True):
+        if self.uvdoc_predictor and kwargs.get("use_doc_image_unwarp_model", True):
             uvdoc_results = get_uvdoc_results(img_info_list, self.uvdoc_predictor)
         img_list = [img_info["img"] for img_info in img_info_list]
         for idx, (img_info, layout_pred) in enumerate(
@@ -303,7 +303,7 @@ class PPChatOCRPipeline(TableRecPipeline):
                         single_img[ymin:ymax, xmin:xmax, :] = 255
 
             curve_pipeline = self.ocr_pipeline
-            if self.curve_pipeline and kwargs.get("use_curve_model", True):
+            if self.curve_pipeline and kwargs.get("use_seal_text_det_model", True):
                 curve_pipeline = self.curve_pipeline
 
             all_curve_res = get_ocr_res(curve_pipeline, curve_subs)