changdazhou 1 year ago
parent
commit
c3ebe9d1a2

+ 9 - 2
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -236,7 +236,14 @@ class PPChatOCRPipeline(_TableRecPipeline):
         recovery=True,
     ):
         # get oricls and unwarp results
-        img_info_list = list(self.img_reader(inputs))[0]
+        if isinstance(inputs, str):
+            img_info_list = list(self.img_reader(inputs))[0]
+        elif isinstance(inputs, list):
+            assert not any(
+                s.endswith(".pdf") for s in inputs
+            ), "List containing pdf is not supported; only a list of images or a single PDF is supported."
+            img_info_list = [x[0] for x in list(self.img_reader(inputs))]
+
         oricls_results = []
         if self.doc_image_ori_cls_predictor and use_doc_image_ori_cls_model:
             oricls_results = get_oriclas_results(
@@ -576,7 +583,7 @@ class PPChatOCRPipeline(_TableRecPipeline):
                 user_task_description,
             )
             logging.debug(prompt)
-            prompt_res["ocr_prompt"] = prompt
+            prompt_res["ocr_prompt"] = [prompt]
             res = self.get_llm_result(llm_api, prompt)
             if res:
                 final_results.update(res)

+ 4 - 1
paddlex/inference/results/chat_ocr.py

@@ -36,7 +36,10 @@ class VisualResult(BaseResult):
     def __init__(self, data, page_id=None, src_input_name=None):
         super().__init__(data)
         self.page_id = page_id
-        self.src_input_name = src_input_name
+        if isinstance(src_input_name, list):
+            self.src_input_name = src_input_name[page_id]
+        else:
+            self.src_input_name = src_input_name
 
     def _to_str(self, _, *args, **kwargs):
         return super()._to_str(

+ 0 - 1
paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/visualizer.py

@@ -144,7 +144,6 @@ def draw_multi_label(image, label, label_map_dict):
     draw = ImageDraw.Draw(new_image)
     font_color = tuple(font_colormap(3))
     for i, text in enumerate(text_lines):
-        text_width, _ = font.getsize(text)
         draw.text(
             (0, image_height + i * int(row_height * 1.2)),
             text,