Эх сурвалжийг харах

use common ReadImage for PP-ChatOCRv3

zhouchangda 1 жил өмнө
parent
commit
d13053e430

+ 2 - 2
paddlex/inference/components/transforms/image/common.py

@@ -111,7 +111,7 @@ class ReadImage(_BaseRead):
             file_list = self._get_files_list(file_path)
             batch = []
             for file_path in file_list:
-                img = self._read_img(file_path)
+                img = self._read(file_path)
                 batch.extend(img)
                 if len(batch) >= self.batch_size:
                     yield batch
@@ -127,7 +127,7 @@ class ReadImage(_BaseRead):
             )
 
     def _read(self, file_path):
-        if file_path:
+        if str(file_path).lower().endswith(".pdf"):
             return self._read_pdf(file_path)
         else:
             return self._read_img(file_path)

+ 9 - 29
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -117,7 +117,6 @@ class PPChatOCRPipeline(TableRecPipeline):
             self.user_prompt_dict = None
         self.recovery = recovery
         self.img_reader = ReadImage()
-        self.pdf_reader = PDFReader()
         self.visual_info = None
         self.vector = None
         self._set_predictor(
@@ -197,37 +196,17 @@ class PPChatOCRPipeline(TableRecPipeline):
         self._set_predictor(
             curve_batch_size, oricls_batch_size, uvdoc_batch_size, device
         )
-        input_imgs = []
-        img_list = []
-        for file in inputs:
-            if isinstance(file, str) and file.endswith(".pdf"):
-                img_list = self.pdf_reader.read(file)
-                for page, img in enumerate(img_list):
-                    input_imgs.append(
-                        {
-                            "input_path": f"{Path(file).parent}/{Path(file).stem}_{page}.jpg",
-                            "img": img,
-                        }
-                    )
-            else:
-                for imgs in self.img_reader(file):
-                    input_imgs.extend(imgs)
         # get oricls and uvdoc results
+        img_info_list = list(self.img_reader(inputs))[0]
         oricls_results = []
         if self.oricls_predictor and kwargs.get("use_oricls_model", True):
-            img_list = [img["img"] for img in input_imgs]
-            oricls_results = get_oriclas_results(
-                input_imgs, self.oricls_predictor, img_list
-            )
+            oricls_results = get_oriclas_results(img_info_list, self.oricls_predictor)
         uvdoc_results = []
         if self.uvdoc_predictor and kwargs.get("use_uvdoc_model", True):
-            img_list = [img["img"] for img in input_imgs]
-            uvdoc_results = get_uvdoc_results(
-                input_imgs, self.uvdoc_predictor, img_list
-            )
-        img_list = [img["img"] for img in input_imgs]
-        for idx, (input_img, layout_pred) in enumerate(
-            zip(input_imgs, self.layout_predictor(img_list))
+            uvdoc_results = get_uvdoc_results(img_info_list, self.uvdoc_predictor)
+        img_list = [img_info["img"] for img_info in img_info_list]
+        for idx, (img_info, layout_pred) in enumerate(
+            zip(img_info_list, self.layout_predictor(img_list))
         ):
             single_img_res = {
                 "input_path": "",
@@ -249,7 +228,7 @@ class PPChatOCRPipeline(TableRecPipeline):
             # update layout result
             single_img_res["input_path"] = layout_pred["input_path"]
             single_img_res["layout_result"] = layout_pred
-            single_img = input_img["img"]
+            single_img = img_info["img"]
             if len(layout_pred["boxes"]) > 0:
                 subs_of_img = list(self._crop_by_boxes(layout_pred))
                 # get cropped images with label "table"
@@ -561,7 +540,8 @@ class PPChatOCRPipeline(TableRecPipeline):
             logging.debug(prompt)
             prompt_res["ocr_prompt"] = prompt
             res = self.get_llm_result(prompt)
-            final_results.update(res)
+            if res:
+                final_results.update(res)
         if not res and not final_results:
             final_results = self.llm_api.ERROR_MASSAGE
         if save_prompt:

+ 6 - 3
paddlex/inference/pipelines/ppchatocrv3/utils.py

@@ -36,8 +36,9 @@ def get_ocr_res(pipeline, input):
         return ocr_res_list
 
 
-def get_oriclas_results(inputs, predictor, img_list):
+def get_oriclas_results(inputs, predictor):
     results = []
+    img_list = [img_info["img"] for img_info in inputs]
     for input, pred in zip(inputs, predictor(img_list)):
         results.append(pred)
         angle = int(pred["label_names"][0])
@@ -45,11 +46,13 @@ def get_oriclas_results(inputs, predictor, img_list):
     return results
 
 
-def get_uvdoc_results(inputs, predictor, img_list):
+def get_uvdoc_results(inputs, predictor):
     results = []
+    img_list = [img_info["img"] for img_info in inputs]
     for input, pred in zip(inputs, predictor(img_list)):
         results.append(pred)
-        input["img"] = np.array(pred["doctr_img"], dtype=np.uint8)
+        img = np.array(pred["doctr_img"], dtype=np.uint8)
+        input["img"] = img
     return results