test_pp_chatocrv4.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from paddlex import create_pipeline
  15. pipeline = create_pipeline(pipeline="PP-ChatOCRv4-doc")
  16. img_path = "./test_samples/研报2_11.jpg"
  17. key_list = ["三位一体养老生态系统包含哪些"]
  18. # img_path = "./test_samples/财报1.pdf"
  19. # key_list = ['公司全称是什么']
  20. def load_mllm_results():
  21. """load mllm results"""
  22. import json
  23. predict_file_path = "/paddle/icode/baidu/paddlex_closed/evaluation/pipelines/ppchatocr/backend_predict_files/predict_mix_doc_v1_2B-1209.json"
  24. mllm_predict_dict = {}
  25. with open(predict_file_path, "r") as fin:
  26. predict_infos_list = json.load(fin)
  27. for predict_infos in predict_infos_list:
  28. img_name = predict_infos["image_path"]
  29. predict_info_list = predict_infos["predict_info_list"]
  30. for predict_info in predict_info_list:
  31. key = img_name + "_" + predict_info["question"]
  32. mllm_predict_dict[key] = predict_info
  33. return mllm_predict_dict
  34. mllm_predict_dict_all = load_mllm_results()
  35. visual_predict_res = pipeline.visual_predict(
  36. img_path,
  37. use_doc_orientation_classify=False,
  38. use_doc_unwarping=False,
  39. use_common_ocr=True,
  40. use_seal_recognition=True,
  41. use_table_recognition=True,
  42. )
  43. # ####[TODO] 增加类别信息
  44. visual_info_list = []
  45. for res in visual_predict_res:
  46. # res['layout_parsing_result'].save_results("./output/")
  47. # print(res["visual_info"])
  48. visual_info_list.append(res["visual_info"])
  49. pipeline.save_visual_info_list(
  50. visual_info_list, "./res_visual_info/tmp_visual_info.json"
  51. )
  52. visual_info_list = pipeline.load_visual_info_list(
  53. "./res_visual_info/tmp_visual_info.json"
  54. )
  55. vector_info = pipeline.build_vector(visual_info_list)
  56. pipeline.save_vector(vector_info, "./res_visual_info/tmp_vector_info.json")
  57. vector_info = pipeline.load_vector("./res_visual_info/tmp_vector_info.json")
  58. mllm_predict_dict = {}
  59. image_name = img_path.split("/")[-1]
  60. for key in key_list:
  61. mllm_predict_key = image_name + "_" + key
  62. mllm_result = ""
  63. if mllm_predict_key in mllm_predict_dict_all:
  64. mllm_result = mllm_predict_dict_all[mllm_predict_key]["predicts"]
  65. mllm_predict_dict[key] = mllm_result
  66. chat_result = pipeline.chat(
  67. key_list,
  68. visual_info_list,
  69. vector_info=vector_info,
  70. mllm_predict_dict=mllm_predict_dict,
  71. )
  72. print(chat_result)