# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import json
from copy import deepcopy

import numpy as np

from .utils import *
from ...results import *
from ...components import *
from ..ocr import OCRPipeline
from ....utils import logging
from ..table_recognition import _TableRecPipeline
from ...components.llm import create_llm_api, ErnieBot
from ....utils.file_interface import read_yaml_file
from ..table_recognition.utils import convert_4point2rect, get_ori_coordinate_for_table

PROMPT_FILE = os.path.join(os.path.dirname(__file__), "ch_prompt.yaml")
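
# The PP-ChatOCRv3 pipeline pairs a visual stage (layout detection, OCR and
# table recognition, plus optional seal-text detection, document-orientation
# classification and image unwarping) with an LLM stage that extracts key
# information from the visual results. The intended call order (see
# `predict()` below) is:
#   visual_predict() -> [build_vector() -> retrieval()] -> chat()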

class PPChatOCRPipeline(_TableRecPipeline):
    """PP-ChatOCRv3 Pipeline"""

    entities = "PP-ChatOCRv3-doc"

    def __init__(
        self,
        layout_model,
        text_det_model,
        text_rec_model,
        table_model,
        doc_image_ori_cls_model=None,
        doc_image_unwarp_model=None,
        seal_text_det_model=None,
        llm_name="ernie-3.5",
        llm_params={},
        task_prompt_yaml=None,
        user_prompt_yaml=None,
        layout_batch_size=1,
        text_det_batch_size=1,
        text_rec_batch_size=1,
        table_batch_size=1,
        doc_image_ori_cls_batch_size=1,
        doc_image_unwarp_batch_size=1,
        seal_text_det_batch_size=1,
        recovery=True,
        device=None,
        predictor_kwargs=None,
        _build_models=True,
    ):
        super().__init__(device, predictor_kwargs)
        if _build_models:
            self._build_predictor(
                layout_model=layout_model,
                text_det_model=text_det_model,
                text_rec_model=text_rec_model,
                table_model=table_model,
                doc_image_ori_cls_model=doc_image_ori_cls_model,
                doc_image_unwarp_model=doc_image_unwarp_model,
                seal_text_det_model=seal_text_det_model,
                llm_name=llm_name,
                llm_params=llm_params,
            )
            self.set_predictor(
                layout_batch_size=layout_batch_size,
                text_det_batch_size=text_det_batch_size,
                text_rec_batch_size=text_rec_batch_size,
                table_batch_size=table_batch_size,
                doc_image_ori_cls_batch_size=doc_image_ori_cls_batch_size,
                doc_image_unwarp_batch_size=doc_image_unwarp_batch_size,
                seal_text_det_batch_size=seal_text_det_batch_size,
            )
        else:
            self.llm_api = create_llm_api(
                llm_name,
                llm_params,
            )
        # get the base prompt from the yaml info
        if task_prompt_yaml:
            self.task_prompt_dict = read_yaml_file(task_prompt_yaml)
        else:
            self.task_prompt_dict = read_yaml_file(PROMPT_FILE)
        # get the user prompt from the yaml info
        if user_prompt_yaml:
            self.user_prompt_dict = read_yaml_file(user_prompt_yaml)
        else:
            self.user_prompt_dict = None
        self.recovery = recovery
        self.visual_info = None
        self.vector = None
        self.visual_flag = False

    def _build_predictor(
        self,
        layout_model,
        text_det_model,
        text_rec_model,
        table_model,
        llm_name,
        llm_params,
        seal_text_det_model=None,
        doc_image_ori_cls_model=None,
        doc_image_unwarp_model=None,
    ):
        super()._build_predictor(
            layout_model, text_det_model, text_rec_model, table_model
        )
        if seal_text_det_model:
            self.curve_pipeline = self._create(
                pipeline=OCRPipeline,
                text_det_model=seal_text_det_model,
                text_rec_model=text_rec_model,
            )
        else:
            self.curve_pipeline = None
        if doc_image_ori_cls_model:
            self.doc_image_ori_cls_predictor = self._create(doc_image_ori_cls_model)
        else:
            self.doc_image_ori_cls_predictor = None
        if doc_image_unwarp_model:
            self.doc_image_unwarp_predictor = self._create(doc_image_unwarp_model)
        else:
            self.doc_image_unwarp_predictor = None
        self.img_reader = ReadImage(format="BGR")
        self.llm_api = create_llm_api(
            llm_name,
            llm_params,
        )
        self.cropper = CropByBoxes()

    def set_predictor(
        self,
        layout_batch_size=None,
        text_det_batch_size=None,
        text_rec_batch_size=None,
        table_batch_size=None,
        doc_image_ori_cls_batch_size=None,
        doc_image_unwarp_batch_size=None,
        seal_text_det_batch_size=None,
        device=None,
    ):
        if text_det_batch_size and text_det_batch_size > 1:
            logging.warning(
                f"The text detection model only supports batch_size=1 for now; "
                f"the setting text_det_batch_size={text_det_batch_size} will be ignored!"
            )
        if layout_batch_size:
            self.layout_predictor.set_predictor(batch_size=layout_batch_size)
        if text_rec_batch_size:
            self.ocr_pipeline.text_rec_model.set_predictor(
                batch_size=text_rec_batch_size
            )
        if table_batch_size:
            self.table_predictor.set_predictor(batch_size=table_batch_size)
        if self.curve_pipeline and seal_text_det_batch_size:
            self.curve_pipeline.text_det_model.set_predictor(
                batch_size=seal_text_det_batch_size
            )
        if self.doc_image_ori_cls_predictor and doc_image_ori_cls_batch_size:
            self.doc_image_ori_cls_predictor.set_predictor(
                batch_size=doc_image_ori_cls_batch_size
            )
        if self.doc_image_unwarp_predictor and doc_image_unwarp_batch_size:
            self.doc_image_unwarp_predictor.set_predictor(
                batch_size=doc_image_unwarp_batch_size
            )
        if device:
            if self.curve_pipeline:
                self.curve_pipeline.set_predictor(device=device)
            if self.doc_image_ori_cls_predictor:
                self.doc_image_ori_cls_predictor.set_predictor(device=device)
            if self.doc_image_unwarp_predictor:
                self.doc_image_unwarp_predictor.set_predictor(device=device)
            self.layout_predictor.set_predictor(device=device)
            self.ocr_pipeline.set_predictor(device=device)

    def predict(self, *args, **kwargs):
        logging.error(
            "The PP-ChatOCRv3-doc pipeline does not support calling `predict()` "
            "directly! Please call `visual_predict(input)` first to get the visual "
            "prediction of `input`, then call `chat(key_list)` to get the result "
            "for the query specified by `key_list`."
        )
        return

    def visual_predict(
        self,
        input,
        use_doc_image_ori_cls_model=True,
        use_doc_image_unwarp_model=True,
        use_seal_text_det_model=True,
        recovery=True,
        **kwargs,
    ):
        self.set_predictor(**kwargs)
        visual_info = {"ocr_text": [], "table_html": [], "table_text": []}
        # get all visual results
        visual_result = list(
            self.get_visual_result(
                input,
                use_doc_image_ori_cls_model=use_doc_image_ori_cls_model,
                use_doc_image_unwarp_model=use_doc_image_unwarp_model,
                use_seal_text_det_model=use_seal_text_det_model,
                recovery=recovery,
            )
        )
        # decode the visual results to get table_html, table_text and ocr_text
        ocr_text, table_text, table_html = self.decode_visual_result(visual_result)
        visual_info["ocr_text"] = ocr_text
        visual_info["table_html"] = table_html
        visual_info["table_text"] = table_text
        visual_info = VisualInfoResult(visual_info)
        # for local use, cache the visual info on the pipeline
        self.visual_info = visual_info
        self.visual_flag = True
        return visual_result, visual_info

    def get_visual_result(
        self,
        inputs,
        use_doc_image_ori_cls_model=True,
        use_doc_image_unwarp_model=True,
        use_seal_text_det_model=True,
        recovery=True,
    ):
        # read the input images
        if isinstance(inputs, str):
            img_info_list = list(self.img_reader(inputs))[0]
        elif isinstance(inputs, list):
            assert not any(
                isinstance(s, str) and s.endswith(".pdf") for s in inputs
            ), "A list containing PDFs is not supported; only a list of images or a single PDF is supported."
            img_info_list = [x[0] for x in list(self.img_reader(inputs))]
        # get orientation-classification and unwarp results
        oricls_results = []
        if self.doc_image_ori_cls_predictor and use_doc_image_ori_cls_model:
            oricls_results = get_oriclas_results(
                img_info_list, self.doc_image_ori_cls_predictor
            )
        unwarp_results = []
        if self.doc_image_unwarp_predictor and use_doc_image_unwarp_model:
            unwarp_results = get_unwarp_results(
                img_info_list, self.doc_image_unwarp_predictor
            )
        img_list = [img_info["img"] for img_info in img_info_list]
        for idx, (img_info, layout_pred) in enumerate(
            zip(img_info_list, self.layout_predictor(img_list))
        ):
            page_id = idx
            single_img_res = {
                "input_path": "",
                "layout_result": DetResult({}),
                "ocr_result": OCRResult({}),
                "table_ocr_result": [],
                "table_result": StructureTableResult([]),
                "layout_parsing_result": {},
                "oricls_result": TopkResult({}),
                "unwarp_result": DocTrResult({}),
                "curve_result": [],
            }
            # update oricls and unwarp results
            if oricls_results:
                single_img_res["oricls_result"] = oricls_results[idx]
            if unwarp_results:
                single_img_res["unwarp_result"] = unwarp_results[idx]
            # update layout result
            single_img_res["input_path"] = layout_pred["input_path"]
            single_img_res["layout_result"] = layout_pred
            single_img = img_info["img"]
            table_subs = []
            curve_subs = []
            structure_res = []
            ocr_res_with_layout = []
            if len(layout_pred["boxes"]) > 0:
                subs_of_img = list(self._crop_by_boxes(layout_pred))
                # process the cropped images
                for sub in subs_of_img:
                    box = sub["box"]
                    xmin, ymin, xmax, ymax = [int(i) for i in box]
                    mask_flag = True
                    if sub["label"].lower() == "table":
                        table_subs.append(sub)
                    elif sub["label"].lower() == "seal":
                        curve_subs.append(sub)
                    else:
                        if self.recovery and recovery:
                            # TODO: Why use the entire image?
                            wht_im = (
                                np.ones(single_img.shape, dtype=single_img.dtype) * 255
                            )
                            wht_im[ymin:ymax, xmin:xmax, :] = sub["img"]
                            sub_ocr_res = get_ocr_res(self.ocr_pipeline, wht_im)
                        else:
                            sub_ocr_res = get_ocr_res(self.ocr_pipeline, sub)
                            # map the crop-local polygons back to original coordinates
                            sub_ocr_res["dt_polys"] = get_ori_coordinate_for_table(
                                xmin, ymin, sub_ocr_res["dt_polys"]
                            )
                        layout_label = sub["label"].lower()
                        if sub_ocr_res and sub["label"].lower() in [
                            "image",
                            "figure",
                            "img",
                            "fig",
                        ]:
                            mask_flag = False
                        else:
                            ocr_res_with_layout.append(sub_ocr_res)
                            structure_res.append(
                                {
                                    "layout_bbox": box,
                                    f"{layout_label}": "\n".join(
                                        sub_ocr_res["rec_text"]
                                    ),
                                }
                            )
                    # mask out regions already handled so the full-image OCR
                    # below does not read them twice
                    if mask_flag:
                        single_img[ymin:ymax, xmin:xmax, :] = 255
            # run seal (curve) text recognition on the seal crops
            curve_pipeline = self.ocr_pipeline
            if self.curve_pipeline and use_seal_text_det_model:
                curve_pipeline = self.curve_pipeline
            all_curve_res = get_ocr_res(curve_pipeline, curve_subs)
            single_img_res["curve_result"] = all_curve_res
            if isinstance(all_curve_res, dict):
                all_curve_res = [all_curve_res]
            for sub, curve_res in zip(curve_subs, all_curve_res):
                dt_polys_list = [
                    list(map(list, sublist)) for sublist in curve_res["dt_polys"]
                ]
                sorted_items = sorted(
                    zip(dt_polys_list, curve_res["rec_text"]),
                    key=lambda x: (x[0][0][1], x[0][0][0]),
                )
                _, sorted_text = zip(*sorted_items)
                structure_res.append(
                    {
                        "layout_bbox": sub["box"],
                        "印章": " ".join(sorted_text),
                    }
                )
            # run OCR on the masked full image and table recognition on the crops
            ocr_res = get_ocr_res(self.ocr_pipeline, single_img)
            ocr_res["input_path"] = layout_pred["input_path"]
            all_table_res, _ = self.get_table_result(table_subs)
            for text_idx, single_dt_poly in enumerate(ocr_res["dt_polys"]):
                structure_res.append(
                    {
                        "layout_bbox": convert_4point2rect(single_dt_poly),
                        "words in text block": ocr_res["rec_text"][text_idx],
                    }
                )
            # merge the layout-level OCR results into the page OCR result
            for layout_ocr_res in ocr_res_with_layout:
                ocr_res["dt_polys"].extend(layout_ocr_res["dt_polys"])
                ocr_res["rec_text"].extend(layout_ocr_res["rec_text"])
            ocr_res["input_path"] = single_img_res["input_path"]
            all_table_ocr_res = []
            # get the table text from the html
            structure_res_table, all_table_ocr_res = get_table_text_from_html(
                all_table_res
            )
            structure_res.extend(structure_res_table)
            # sort the layout results by the top-left point of each box
            structure_res = sorted_layout_boxes(structure_res, w=single_img.shape[1])
            structure_res = LayoutParsingResult(
                {
                    "input_path": layout_pred["input_path"],
                    "parsing_result": structure_res,
                }
            )
            single_img_res["table_result"] = all_table_res
            single_img_res["ocr_result"] = ocr_res
            single_img_res["table_ocr_result"] = all_table_ocr_res
            single_img_res["layout_parsing_result"] = structure_res
            single_img_res["layout_parsing_result"]["page_id"] = page_id + 1
            yield VisualResult(single_img_res, page_id, inputs)

    def decode_visual_result(self, visual_result):
        ocr_text = []
        table_text_list = []
        table_html = []
        for single_img_pred in visual_result:
            layout_res = single_img_pred["layout_parsing_result"]["parsing_result"]
            layout_res_copy = deepcopy(layout_res)
            # layout_res is a list like:
            # [{"layout_bbox": [x1, y1, x2, y2], "layout": "single", "words in text block": "xxx"},
            #  {"layout_bbox": [x1, y1, x2, y2], "layout": "double", "印章": "xxx"}]
            ocr_res = {}
            for block in layout_res_copy:
                block.pop("layout_bbox")
                block.pop("layout")
                for layout_type, text in block.items():
                    if text == "":
                        continue
                    # table results are used separately
                    if layout_type == "table":
                        continue
                    if layout_type not in ocr_res:
                        ocr_res[layout_type] = text
                    else:
                        ocr_res[layout_type] += f"\n {text}"
            single_table_text = " ".join(single_img_pred["table_ocr_result"])
            for table_pred in single_img_pred["table_result"]:
                html = table_pred["html"]
                table_html.append(html)
            if ocr_res:
                ocr_text.append(ocr_res)
            table_text_list.append(single_table_text)
        return ocr_text, table_text_list, table_html

    def build_vector(
        self,
        llm_name=None,
        llm_params={},
        visual_info=None,
        min_characters=3500,
        llm_request_interval=1.0,
    ):
        """get vector for ocr"""
        if isinstance(self.llm_api, ErnieBot):
            get_vector_flag = True
        else:
            logging.warning("The LLM is not an ErnieBot, so no text vector will be built.")
            get_vector_flag = False
        # fall back to the visual info cached by visual_predict() for local use
        if visual_info is None:
            visual_info = self.visual_info
        if visual_info is None:
            return VectorResult({"vector": None})
        ocr_text = visual_info["ocr_text"]
        html_list = visual_info["table_html"]
        table_text_list = visual_info["table_text"]
        # add the table text to the ocr text
        for html, table_text_rec in zip(html_list, table_text_list):
            if len(html) > 3000:
                ocr_text.append({"table": table_text_rec})
        ocr_all_result = "".join(["\n".join(e.values()) for e in ocr_text])
        if len(ocr_all_result) > min_characters and get_vector_flag:
            if visual_info and llm_name:
                # for serving or local
                llm_api = create_llm_api(llm_name, llm_params)
                text_result = llm_api.get_vector(ocr_text, llm_request_interval)
            else:
                # for local
                text_result = self.llm_api.get_vector(ocr_text, llm_request_interval)
        else:
            text_result = str(ocr_text)
        self.visual_flag = False
        return VectorResult({"vector": text_result})

    def retrieval(
        self,
        key_list,
        vector,
        llm_name=None,
        llm_params={},
        llm_request_interval=0.1,
    ):
        assert "vector" in vector
        key_list = format_key(key_list)
        _vector = vector["vector"]
        if llm_name:
            # for serving
            llm_api = create_llm_api(llm_name, llm_params)
            retrieval = llm_api.caculate_similar(
                vector=_vector,
                key_list=key_list,
                llm_params=llm_params,
                sleep_time=llm_request_interval,
            )
        else:
            retrieval = self.llm_api.caculate_similar(
                vector=_vector, key_list=key_list, sleep_time=llm_request_interval
            )
        return RetrievalResult({"retrieval": retrieval})

    def chat(
        self,
        key_list,
        vector=None,
        visual_info=None,
        retrieval_result=None,
        user_task_description="",
        rules="",
        few_shot="",
        save_prompt=False,
        llm_name=None,
        llm_params={},
    ):
        """chat with key"""
        if not any([vector, visual_info, retrieval_result]):
            return ChatResult(
                {"chat_res": "请先完成图像解析再开始对话", "prompt": ""}
            )
        key_list = format_key(key_list)
        # first query the table html, then the table text, and finally all ocr text
        ocr_text = visual_info["ocr_text"]
        html_list = visual_info["table_html"]
        table_text_list = visual_info["table_text"]
        prompt_res = {"ocr_prompt": "str", "table_prompt": [], "html_prompt": []}
        if llm_name:
            llm_api = create_llm_api(llm_name, llm_params)
        else:
            llm_api = self.llm_api
        final_results = {}
        res = {}  # holds the most recent LLM parse; checked at the end
        failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
        if html_list:
            prompt_list = self.get_prompt_for_table(
                html_list, key_list, rules, few_shot
            )
            prompt_res["html_prompt"] = prompt_list
            for prompt, table_text in zip(prompt_list, table_text_list):
                logging.debug(prompt)
                res = self.get_llm_result(llm_api, prompt)
                # TODO: why use one html but the whole table_text in the next step?
                if not res or list(res.values())[0] in failed_results:
                    logging.debug(
                        "The table html sequence is too long; using ocr directly!"
                    )
                    prompt = self.get_prompt_for_ocr(
                        table_text, key_list, rules, few_shot, user_task_description
                    )
                    logging.debug(prompt)
                    prompt_res["table_prompt"].append(prompt)
                    res = self.get_llm_result(llm_api, prompt)
                for key, value in res.items():
                    if value not in failed_results and key in key_list:
                        key_list.remove(key)
                        final_results[key] = value
        if len(key_list) > 0:
            logging.debug("get result from ocr")
            if retrieval_result:
                ocr_text = retrieval_result.get("retrieval")
            elif vector:
                if llm_name:
                    # for serving
                    ocr_text = self.retrieval(
                        key_list=key_list,
                        vector=vector,
                        llm_name=llm_name,
                        llm_params=llm_params,
                    )["retrieval"]
                else:
                    # for local
                    ocr_text = self.retrieval(key_list=key_list, vector=vector)[
                        "retrieval"
                    ]
            prompt = self.get_prompt_for_ocr(
                ocr_text,
                key_list,
                rules,
                few_shot,
                user_task_description,
            )
            logging.debug(prompt)
            prompt_res["ocr_prompt"] = [prompt]
            res = self.get_llm_result(llm_api, prompt)
            if res:
                final_results.update(res)
        if not res and not final_results:
            final_results = {"error": llm_api.ERROR_MASSAGE}
        if save_prompt:
            return ChatResult({"chat_res": final_results, "prompt": prompt_res})
        else:
            return ChatResult({"chat_res": final_results, "prompt": ""})

    def get_llm_result(self, llm_api, prompt):
        """get the llm result and decode it to a dict"""
        llm_result = llm_api.pred(prompt)
        # when the llm prediction failed, return an empty dict
        if not llm_result:
            return {}
        # strip markdown code fences and list brackets before parsing
        if "json" in llm_result or "```" in llm_result:
            llm_result = (
                llm_result.replace("```", "").replace("json", "").replace("/n", "")
            )
        llm_result = llm_result.replace("[", "").replace("]", "")
        try:
            llm_result = json.loads(llm_result)
            llm_result_final = {}
            for key in llm_result:
                value = llm_result[key]
                if isinstance(value, list):
                    if len(value) > 0:
                        llm_result_final[key] = value[0]
                else:
                    llm_result_final[key] = value
            return llm_result_final
        except Exception:
            # fall back to regex extraction of "key": "value" pairs
            results = (
                llm_result.replace("\n", "")
                .replace("    ", "")
                .replace("{", "")
                .replace("}", "")
            )
            if not results.endswith('"'):
                results = results + '"'
            pattern = r'"(.*?)": "([^"]*)"'
            matches = re.findall(pattern, str(results))
            llm_result = {k: v for k, v in matches}
            return llm_result

    def get_prompt_for_table(self, table_result, key_list, rules="", few_shot=""):
        """get prompt for table"""
        prompt_key_information = []
        merge_table = ""
        for idx, result in enumerate(table_result):
            if len(merge_table + result) < 2000:
                merge_table += result
            if len(merge_table + result) > 2000 or idx == len(table_result) - 1:
                single_prompt = self.get_kie_prompt(
                    merge_table,
                    key_list,
                    rules_str=rules,
                    few_shot_demo_str=few_shot,
                    prompt_type="table",
                )
                prompt_key_information.append(single_prompt)
                merge_table = ""
        return prompt_key_information

    def get_prompt_for_ocr(
        self,
        ocr_result,
        key_list,
        rules="",
        few_shot="",
        user_task_description="",
    ):
        """get prompt for ocr"""
        prompt_key_information = self.get_kie_prompt(
            ocr_result, key_list, user_task_description, rules, few_shot
        )
        return prompt_key_information

    def get_kie_prompt(
        self,
        text_result,
        key_list,
        user_task_description="",
        rules_str="",
        few_shot_demo_str="",
        prompt_type="common",
    ):
        """get_kie_prompt"""
        if prompt_type == "table":
            task_description = self.task_prompt_dict["kie_table_prompt"][
                "task_description"
            ]
        else:
            task_description = self.task_prompt_dict["kie_common_prompt"][
                "task_description"
            ]
        output_format = self.task_prompt_dict["kie_common_prompt"]["output_format"]
        if len(user_task_description) > 0:
            task_description = user_task_description
        task_description = task_description + output_format
        few_shot_demo_key_value = ""
        if self.user_prompt_dict:
            logging.info("======= using custom user prompt ========")
            task_description = self.user_prompt_dict["task_description"]
            rules_str = self.user_prompt_dict["rules_str"]
            few_shot_demo_str = self.user_prompt_dict["few_shot_demo_str"]
            few_shot_demo_key_value = self.user_prompt_dict["few_shot_demo_key_value"]
        prompt = f"""{task_description}{rules_str}{few_shot_demo_str}{few_shot_demo_key_value}"""
        if prompt_type == "table":
            prompt += f"""\n结合上面,下面正式开始:\
                表格内容:```{text_result}```\
                关键词列表:[{key_list}]。""".replace(
                "    ", ""
            )
        else:
            prompt += f"""\n结合上面的例子,下面正式开始:\
                OCR文字:```{text_result}```\
                关键词列表:[{key_list}]。""".replace(
                "    ", ""
            )
        return prompt
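

# ---------------------------------------------------------------------------
# Minimal usage sketch (a non-authoritative example, kept out of library
# import paths by the __main__ guard). The model names and the input file
# below are placeholders, not verified model-zoo entries; substitute the
# models and LLM credentials available in your installation.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipeline = PPChatOCRPipeline(
        layout_model="<layout-model>",  # placeholder
        text_det_model="<text-det-model>",  # placeholder
        text_rec_model="<text-rec-model>",  # placeholder
        table_model="<table-model>",  # placeholder
        llm_name="ernie-3.5",
        llm_params={},  # e.g. the API credentials expected by create_llm_api
    )
    # 1. Visual stage: layout, OCR and table results plus a VisualInfoResult.
    visual_result, visual_info = pipeline.visual_predict("document.jpg")
    # 2. Optional: build a vector over the OCR text (useful for long documents).
    vector = pipeline.build_vector(visual_info=visual_info)
    # 3. LLM stage: query the parsed document for the given keys.
    chat_res = pipeline.chat(
        key_list=["发票号码"],  # fields to extract, e.g. the invoice number
        vector=vector,
        visual_info=visual_info,
    )
    print(chat_res["chat_res"])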