test_pp_chatocrv4.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import os

from paddlex import create_pipeline
  15. pipeline = create_pipeline(pipeline="PP-ChatOCRv4-doc")
  16. img_path = "./test_samples/研报2_11.jpg"
  17. key_list = ["三位一体养老生态系统包含哪些"]
  18. # img_path = "./test_samples/财报1.pdf"
  19. # key_list = ['公司全称是什么']
  20. def load_mllm_results():
  21. """load mllm results"""
  22. import json
  23. predict_file_path = "/paddle/icode/baidu/paddlex_closed/evaluation/pipelines/ppchatocr/backend_predict_files/predict_mix_doc_v1_2B-1209.json"
  24. mllm_predict_dict = {}
  25. with open(predict_file_path, "r") as fin:
  26. predict_infos_list = json.load(fin)
  27. for predict_infos in predict_infos_list:
  28. img_name = predict_infos["image_path"]
  29. predict_info_list = predict_infos["predict_info_list"]
  30. for predict_info in predict_info_list:
  31. key = img_name + "_" + predict_info["question"]
  32. mllm_predict_dict[key] = predict_info
  33. return mllm_predict_dict
  34. mllm_predict_dict_all = load_mllm_results()
  35. visual_predict_res = pipeline.visual_predict(
  36. img_path,
  37. use_doc_orientation_classify=False,
  38. use_doc_unwarping=False,
  39. use_common_ocr=True,
  40. use_seal_recognition=True,
  41. use_table_recognition=True,
  42. )
  43. visual_info_list = []
  44. for res in visual_predict_res:
  45. visual_info_list.append(res["visual_info"])
  46. layout_parsing_result = res["layout_parsing_result"]
  47. print(layout_parsing_result)
  48. layout_parsing_result.print()
  49. layout_parsing_result.save_to_img("./output")
  50. layout_parsing_result.save_to_json("./output")
  51. layout_parsing_result.save_to_xlsx("./output")
  52. layout_parsing_result.save_to_html("./output")
  53. pipeline.save_visual_info_list(
  54. visual_info_list, "./res_visual_info/tmp_visual_info.json"
  55. )
  56. visual_info_list = pipeline.load_visual_info_list(
  57. "./res_visual_info/tmp_visual_info.json"
  58. )
  59. vector_info = pipeline.build_vector(visual_info_list, flag_save_bytes_vector=True)
  60. pipeline.save_vector(vector_info, "./res_visual_info/tmp_vector_info.json")
  61. vector_info = pipeline.load_vector("./res_visual_info/tmp_vector_info.json")
  62. mllm_predict_dict = {}
  63. image_name = img_path.split("/")[-1]
  64. for key in key_list:
  65. mllm_predict_key = image_name + "_" + key
  66. mllm_result = ""
  67. if mllm_predict_key in mllm_predict_dict_all:
  68. mllm_result = mllm_predict_dict_all[mllm_predict_key]["predicts"]
  69. mllm_predict_dict[key] = mllm_result
  70. chat_result = pipeline.chat(
  71. key_list,
  72. visual_info_list,
  73. vector_info=vector_info,
  74. mllm_predict_dict=mllm_predict_dict,
  75. )
  76. print(chat_result)