Procházet zdrojové kódy

bugfix & support CI for chatocr

gaotingquan před 1 rokem
rodič
revize
3c90557980

+ 8 - 3
docs/pipeline_usage/tutorials/information_extration_pipelines/document_scene_information_extraction.md

@@ -203,14 +203,20 @@ for res in visual_result:
     res.save_to_html('./output')
     res.save_to_xlsx('./output')
 
-chat_result = pipeline.chat(["乙方", "手机号"])
+vector = pipeline.build_vector(visual_info=visual_info)
+
+chat_result = pipeline.chat(
+    key_list=["乙方", "手机号"],
+    visual_info=visual_info,
+    vector=vector,
+    )
 chat_result.print()
 ```
 **注**:请先在[百度云千帆平台](https://console.bce.baidu.com/qianfan/ais/console/onlineService)获取自己的ak与sk(详细流程请参考[AK和SK鉴权调用API流程](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Hlwerugt8)),将ak与sk填入至指定位置后才能正常调用大模型。
 
 运行后,输出结果如下:
 
-```
+```python
 {'chat_res': {'乙方': '股份测试有限公司', '手机号': '19331729920'}, 'prompt': ''}
 ```
 
@@ -729,4 +735,3 @@ predict = create_pipeline( pipeline="PP-ChatOCRv3-doc",
                             device = "npu:0" )
 ```
 若您想在更多种类的硬件上使用通用文档场景信息抽取产线,请参考[PaddleX多硬件使用指南](../../../other_devices_support/multi_devices_use_guide.md)。
-

+ 8 - 3
docs/pipeline_usage/tutorials/information_extration_pipelines/document_scene_information_extraction_en.md

@@ -199,14 +199,20 @@ for res in visual_result:
     res.save_to_html('./output')
     res.save_to_xlsx('./output')
 
-chat_result = pipeline.chat(["乙方", "手机号"])
+vector = pipeline.build_vector(visual_info=visual_info)
+
+chat_result = pipeline.chat(
+    key_list=["乙方", "手机号"],
+    visual_info=visual_info,
+    vector=vector,
+    )
 chat_result.print()
 ```
 **Note**: First obtain your own ak and sk from the [Baidu Cloud Qianfan Platform](https://console.bce.baidu.com/qianfan/ais/console/onlineService) (for detailed steps, see the [AK and SK Authentication API Call Process](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Hlwerugt8)). The large model can only be called successfully after you have filled in your ak and sk at the specified locations.
 
 After running, the output is as follows:
 
-```
+```python
 {'chat_res': {'乙方': '股份测试有限公司', '手机号': '19331729920'}, 'prompt': ''}
 ```
 
@@ -623,4 +629,3 @@ predict = create_pipeline(
 ```
 
 If you want to use the PP-ChatOCRv3-doc Pipeline on more types of hardware, please refer to the [PaddleX Multi-Device Usage Guide](../../../installation/multi_devices_use_guide_en.md).
-

+ 3 - 0
paddlex/inference/components/llm/__init__.py

@@ -16,6 +16,9 @@ from .erniebot import ErnieBot
 
 
 def create_llm_api(model_name: str, params={}) -> BaseLLM:
+    # for CI
+    if model_name == "paddlex_ci":
+        return
     return BaseLLM.get(model_name)(
         model_name=model_name,
         params=params,

+ 27 - 60
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -422,16 +422,9 @@ class PPChatOCRPipeline(_TableRecPipeline):
         if not any([visual_info, self.visual_info]):
             return VectorResult({"vector": None})
 
-        if visual_info:
-            # use for serving or local
-            _visual_info = visual_info
-        else:
-            # use for local
-            _visual_info = self.visual_info
-
-        ocr_text = _visual_info["ocr_text"]
-        html_list = _visual_info["table_html"]
-        table_text_list = _visual_info["table_text"]
+        ocr_text = visual_info["ocr_text"]
+        html_list = visual_info["table_html"]
+        table_text_list = visual_info["table_text"]
 
         # add table text to ocr text
         for html, table_text_rec in zip(html_list, table_text_list):
@@ -457,36 +450,16 @@ class PPChatOCRPipeline(_TableRecPipeline):
     def retrieval(
         self,
         key_list,
-        visual_info=None,
-        vector=None,
+        vector,
         llm_name=None,
         llm_params={},
         llm_request_interval=0.1,
     ):
-
-        if not any([visual_info, vector, self.visual_info, self.vector]):
-            return RetrievalResult({"retrieval": None})
-
+        assert "vector" in vector
         key_list = format_key(key_list)
 
-        is_seving = visual_info and llm_name
-
-        if self.visual_flag and not is_seving:
-            self.vector = self.build_vector()
-
-        if not any([vector, self.vector]):
-            logging.warning(
-                "The vector library is not created, and is being created automatically"
-            )
-            if is_seving:
-                # for serving
-                vector = self.build_vector(
-                    llm_name=llm_name, llm_params=llm_params, visual_info=visual_info
-                )
-            else:
-                self.vector = self.build_vector()
-
-        if vector and llm_name:
+        # for serving
+        if llm_name:
             _vector = vector["vector"]
             llm_api = create_llm_api(llm_name, llm_params)
             retrieval = llm_api.caculate_similar(
@@ -496,7 +469,7 @@ class PPChatOCRPipeline(_TableRecPipeline):
                 sleep_time=llm_request_interval,
             )
         else:
-            _vector = self.vector["vector"]
+            _vector = vector["vector"]
             retrieval = self.llm_api.caculate_similar(
                 vector=_vector, key_list=key_list, sleep_time=llm_request_interval
             )
@@ -512,33 +485,24 @@ class PPChatOCRPipeline(_TableRecPipeline):
         user_task_description="",
         rules="",
         few_shot="",
-        use_retrieval=True,
         save_prompt=False,
-        llm_name="ernie-3.5",
+        llm_name=None,
         llm_params={},
     ):
         """
         chat with key
 
         """
-        if not any(
-            [vector, visual_info, retrieval_result, self.visual_info, self.vector]
-        ):
+        if not any([vector, visual_info, retrieval_result]):
             return ChatResult(
                 {"chat_res": "请先完成图像解析再开始再对话", "prompt": ""}
             )
         key_list = format_key(key_list)
         # first get from table, then get from text in table, last get from all ocr
-        if visual_info:
-            # use for serving or local
-            _visual_info = visual_info
-        else:
-            # use for local
-            _visual_info = self.visual_info
 
-        ocr_text = _visual_info["ocr_text"]
-        html_list = _visual_info["table_html"]
-        table_text_list = _visual_info["table_text"]
+        ocr_text = visual_info["ocr_text"]
+        html_list = visual_info["table_html"]
+        table_text_list = visual_info["table_text"]
 
         prompt_res = {"ocr_prompt": "str", "table_prompt": [], "html_prompt": []}
 
@@ -571,18 +535,21 @@ class PPChatOCRPipeline(_TableRecPipeline):
             logging.debug("get result from ocr")
             if retrieval_result:
                 ocr_text = retrieval_result.get("retrieval")
-            elif use_retrieval and any([visual_info, vector]):
-                # for serving or local
-                ocr_text = self.retrieval(
-                    key_list=key_list,
-                    visual_info=visual_info,
-                    vector=vector,
-                    llm_name=llm_name,
-                    llm_params=llm_params,
-                )["retrieval"]
-            else:
+            elif vector:
+                # for serving
+                if llm_name:
+                    ocr_text = self.retrieval(
+                        key_list=key_list,
+                        vector=vector,
+                        llm_name=llm_name,
+                        llm_params=llm_params,
+                    )["retrieval"]
                 # for local
-                ocr_text = self.retrieval(key_list=key_list)["retrieval"]
+                else:
+                    ocr_text = self.retrieval(key_list=key_list, vector=vector)[
+                        "retrieval"
+                    ]
+
             prompt = self.get_prompt_for_ocr(
                 ocr_text,
                 key_list,

+ 1 - 1
paddlex/inference/pipelines/ppchatocrv3/utils.py

@@ -46,7 +46,7 @@ def get_oriclas_results(inputs, predictor):
     return results
 
 
-def get_uvdoc_results(inputs, predictor):
+def get_unwarp_results(inputs, predictor):
     results = []
     img_list = [img_info["img"] for img_info in inputs]
     for input, pred in zip(inputs, predictor(img_list)):

+ 8 - 8
paddlex/inference/results/chat_ocr.py

@@ -63,31 +63,31 @@ class VisualResult(BaseResult):
             oricls_result._HARD_FLAG = True
             oricls_result.save_to_img(oricls_save_path)
         uvdoc_save_path = f"{save_path}_uvdoc.jpg"
-        uvdoc_result = self["uvdoc_result"]
-        if uvdoc_result:
-            # uvdoc_result._HARD_FLAG = True
-            uvdoc_result.save_to_img(uvdoc_save_path)
+        unwarp_result = self["unwarp_result"]
+        if unwarp_result:
+            # unwarp_result._HARD_FLAG = True
+            unwarp_result.save_to_img(uvdoc_save_path)
         curve_save_path = f"{save_path}_curve.jpg"
         curve_results = self["curve_result"]
         # TODO(): support list of result
         if isinstance(curve_results, dict):
             curve_results = [curve_results]
         for curve_result in curve_results:
-            curve_result._HARD_FLAG = True if not uvdoc_result else False
+            curve_result._HARD_FLAG = True if not unwarp_result else False
             curve_result.save_to_img(curve_save_path)
         layout_save_path = f"{save_path}_layout.jpg"
         layout_result = self["layout_result"]
         if layout_result:
-            layout_result._HARD_FLAG = True if not uvdoc_result else False
+            layout_result._HARD_FLAG = True if not unwarp_result else False
             layout_result.save_to_img(layout_save_path)
         ocr_save_path = f"{save_path}_ocr.jpg"
         table_save_path = f"{save_path}_table.jpg"
         ocr_result = self["ocr_result"]
         if ocr_result:
-            ocr_result._HARD_FLAG = True if not uvdoc_result else False
+            ocr_result._HARD_FLAG = True if not unwarp_result else False
             ocr_result.save_to_img(ocr_save_path)
         for table_result in self["table_result"]:
-            table_result._HARD_FLAG = True if not uvdoc_result else False
+            table_result._HARD_FLAG = True if not unwarp_result else False
             table_result.save_to_img(table_save_path)
 
 

+ 3 - 3
paddlex/pipelines/PP-ChatOCRv3-doc.yaml

@@ -21,7 +21,7 @@ Pipeline:
   text_det_batch_size: 1
   text_rec_batch_size: 1
   table_batch_size: 1
-  uvdoc_batch_size: 1
-  curve_batch_size: 1
-  oricls_batch_size: 1
+  doc_image_ori_cls_batch_size: 1
+  doc_image_unwarp_batch_size: 1
+  seal_text_det_batch_size: 1
   recovery: True