# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import json
from copy import deepcopy

import numpy as np

from .utils import *
from ...components import *
from ..ocr import OCRPipeline
from ....utils import logging
from ...results import *
from ...utils.io import ImageReader, PDFReader
from ..table_recognition import TableRecPipeline
from ...components.llm import create_llm_api, ErnieBot
from ....utils.file_interface import read_yaml_file
from ..table_recognition.utils import convert_4point2rect, get_ori_coordinate_for_table

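# The default (Chinese) prompt templates live in ch_prompt.yaml next to this
# module; callers can override them via task_prompt_yaml / user_prompt_yaml.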
PROMPT_FILE = os.path.join(os.path.dirname(__file__), "ch_prompt.yaml")


class PPChatOCRPipeline(TableRecPipeline):
    """PP-ChatOCRv3 Pipeline"""

    entities = "PP-ChatOCRv3-doc"

    def __init__(
        self,
        layout_model,
        text_det_model,
        text_rec_model,
        table_model,
        doc_image_ori_cls_model=None,
        doc_image_unwarp_model=None,
        seal_text_det_model=None,
        llm_name="ernie-3.5",
        llm_params={},
        task_prompt_yaml=None,
        user_prompt_yaml=None,
        layout_batch_size=1,
        text_det_batch_size=1,
        text_rec_batch_size=1,
        table_batch_size=1,
        uvdoc_batch_size=1,
        curve_batch_size=1,
        oricls_batch_size=1,
        recovery=True,
        device="gpu",
        predictor_kwargs=None,
    ):
        self.layout_model = layout_model
        self.text_det_model = text_det_model
        self.text_rec_model = text_rec_model
        self.table_model = table_model
        self.doc_image_ori_cls_model = doc_image_ori_cls_model
        self.doc_image_unwarp_model = doc_image_unwarp_model
        self.seal_text_det_model = seal_text_det_model
        self.llm_name = llm_name
        self.llm_params = llm_params
        self.task_prompt_yaml = task_prompt_yaml
        self.user_prompt_yaml = user_prompt_yaml
        self.layout_batch_size = layout_batch_size
        self.text_det_batch_size = text_det_batch_size
        self.text_rec_batch_size = text_rec_batch_size
        self.table_batch_size = table_batch_size
        self.uvdoc_batch_size = uvdoc_batch_size
        self.curve_batch_size = curve_batch_size
        self.oricls_batch_size = oricls_batch_size
        self.recovery = recovery
        self.device = device
        self.predictor_kwargs = predictor_kwargs
        super().__init__(
            layout_model=layout_model,
            text_det_model=text_det_model,
            text_rec_model=text_rec_model,
            table_model=table_model,
            layout_batch_size=layout_batch_size,
            text_det_batch_size=text_det_batch_size,
            text_rec_batch_size=text_rec_batch_size,
            table_batch_size=table_batch_size,
            predictor_kwargs=predictor_kwargs,
        )
        self._build_predictor()
        self.llm_api = create_llm_api(
            llm_name,
            llm_params,
        )
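        # a single LLM client is kept for local use; the serving paths in
        # get_vector_text()/get_retrieval_text()/chat() create per-request
        # clients via create_llm_api when an llm_name is passed in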
        self.cropper = CropByBoxes()
        # get base prompt from yaml info
        if task_prompt_yaml:
            self.task_prompt_dict = read_yaml_file(task_prompt_yaml)
        else:
            self.task_prompt_dict = read_yaml_file(PROMPT_FILE)
        # get user prompt from yaml info
        if user_prompt_yaml:
            self.user_prompt_dict = read_yaml_file(user_prompt_yaml)
        else:
            self.user_prompt_dict = None
        self.img_reader = ReadImage()
        self.visual_info = None
        self.vector = None
        self.visual_flag = False

    def _build_predictor(self):
        super()._build_predictor()
        if self.seal_text_det_model:
            self.curve_pipeline = OCRPipeline(
                text_det_model=self.seal_text_det_model,
                text_rec_model=self.text_rec_model,
                text_det_batch_size=self.text_det_batch_size,
                text_rec_batch_size=self.text_rec_batch_size,
                predictor_kwargs=self.predictor_kwargs,
            )
        else:
            self.curve_pipeline = None
        if self.doc_image_ori_cls_model:
            self.oricls_predictor = self._create_model(self.doc_image_ori_cls_model)
        else:
            self.oricls_predictor = None
        if self.doc_image_unwarp_model:
            self.uvdoc_predictor = self._create_model(self.doc_image_unwarp_model)
        else:
            self.uvdoc_predictor = None
        if self.curve_pipeline and self.curve_batch_size:
            self.curve_pipeline.text_det_model.set_predictor(
                batch_size=self.curve_batch_size, device=self.device
            )
        if self.oricls_predictor and self.oricls_batch_size:
            self.oricls_predictor.set_predictor(
                batch_size=self.oricls_batch_size, device=self.device
            )
        if self.uvdoc_predictor and self.uvdoc_batch_size:
            self.uvdoc_predictor.set_predictor(
                batch_size=self.uvdoc_batch_size, device=self.device
            )

    def set_predictor(
        self,
        layout_batch_size=None,
        text_det_batch_size=None,
        text_rec_batch_size=None,
        table_batch_size=None,
        curve_batch_size=None,
        oricls_batch_size=None,
        uvdoc_batch_size=None,
        device=None,
    ):
        if text_det_batch_size and text_det_batch_size > 1:
            logging.warning(
                "The text detection model only supports batch_size=1 for now; "
                f"the setting text_det_batch_size={text_det_batch_size} will be ignored!"
            )
        if layout_batch_size:
            self.layout_predictor.set_predictor(batch_size=layout_batch_size)
        if text_rec_batch_size:
            self.ocr_pipeline.text_rec_model.set_predictor(
                batch_size=text_rec_batch_size
            )
        if table_batch_size:
            self.table_predictor.set_predictor(batch_size=table_batch_size)
        if self.curve_pipeline and curve_batch_size:
            self.curve_pipeline.text_det_model.set_predictor(
                batch_size=curve_batch_size, device=device
            )
        if self.oricls_predictor and oricls_batch_size:
            self.oricls_predictor.set_predictor(
                batch_size=oricls_batch_size, device=device
            )
        if self.uvdoc_predictor and uvdoc_batch_size:
            self.uvdoc_predictor.set_predictor(
                batch_size=uvdoc_batch_size, device=device
            )

    def predict(self, input, **kwargs):
        visual_info = {"ocr_text": [], "table_html": [], "table_text": []}
        # get all visual results
        visual_result = list(self.get_visual_result(input, **kwargs))
        # decode the visual results into table_html, table_text and ocr_text
        ocr_text, table_text, table_html = self.decode_visual_result(visual_result)
        visual_info["ocr_text"] = ocr_text
        visual_info["table_html"] = table_html
        visual_info["table_text"] = table_text
        visual_info = VisualInfoResult(visual_info)
        # for local use, cache the visual info on the pipeline instance
        self.visual_info = visual_info
        self.visual_flag = True
        return visual_result, visual_info
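
    # Typical interactive flow: predict() -> get_vector_text() ->
    # get_retrieval_text() -> chat(); a hedged usage sketch sits at the
    # bottom of this file.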

    def get_visual_result(self, inputs, **kwargs):
        layout_batch_size = kwargs.get("layout_batch_size")
        text_det_batch_size = kwargs.get("text_det_batch_size")
        text_rec_batch_size = kwargs.get("text_rec_batch_size")
        table_batch_size = kwargs.get("table_batch_size")
        curve_batch_size = kwargs.get("curve_batch_size")
        oricls_batch_size = kwargs.get("oricls_batch_size")
        uvdoc_batch_size = kwargs.get("uvdoc_batch_size")
        device = kwargs.get("device")
        self.set_predictor(
            layout_batch_size,
            text_det_batch_size,
            text_rec_batch_size,
            table_batch_size,
            curve_batch_size,
            oricls_batch_size,
            uvdoc_batch_size,
            device,
        )
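        # the image reader yields batches; the first batch holds the decoded
        # page images for this input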
        img_info_list = list(self.img_reader(inputs))[0]
        # get oricls and uvdoc results
        oricls_results = []
        if self.oricls_predictor and kwargs.get("use_doc_image_ori_cls_model", True):
            oricls_results = get_oriclas_results(img_info_list, self.oricls_predictor)
        uvdoc_results = []
        if self.uvdoc_predictor and kwargs.get("use_doc_image_unwarp_model", True):
            uvdoc_results = get_uvdoc_results(img_info_list, self.uvdoc_predictor)
        img_list = [img_info["img"] for img_info in img_info_list]
        for idx, (img_info, layout_pred) in enumerate(
            zip(img_info_list, self.layout_predictor(img_list))
        ):
            single_img_res = {
                "input_path": "",
                "layout_result": DetResult({}),
                "ocr_result": OCRResult({}),
                "table_ocr_result": [],
                "table_result": StructureTableResult([]),
                "structure_result": [],
                "oricls_result": TopkResult({}),
                "uvdoc_result": DocTrResult({}),
                "curve_result": [],
            }
            # update oricls and uvdoc results
            if oricls_results:
                single_img_res["oricls_result"] = oricls_results[idx]
            if uvdoc_results:
                single_img_res["uvdoc_result"] = uvdoc_results[idx]
            # update layout result
            single_img_res["input_path"] = layout_pred["input_path"]
            single_img_res["layout_result"] = layout_pred
            single_img = img_info["img"]
            table_subs = []
            curve_subs = []
            structure_res = []
            ocr_res_with_layout = []
            if len(layout_pred["boxes"]) > 0:
                subs_of_img = list(self._crop_by_boxes(layout_pred))
                # process the cropped layout regions
                for sub in subs_of_img:
                    box = sub["box"]
                    xmin, ymin, xmax, ymax = [int(i) for i in box]
                    mask_flag = True
                    if sub["label"].lower() == "table":
                        table_subs.append(sub)
                    elif sub["label"].lower() == "seal":
                        curve_subs.append(sub)
                    else:
                        if self.recovery and kwargs.get("recovery", True):
                            # TODO: Why use the entire image?
                            wht_im = (
                                np.ones(single_img.shape, dtype=single_img.dtype) * 255
                            )
                            wht_im[ymin:ymax, xmin:xmax, :] = sub["img"]
                            sub_ocr_res = get_ocr_res(self.ocr_pipeline, wht_im)
                        else:
                            sub_ocr_res = get_ocr_res(self.ocr_pipeline, sub)
                            sub_ocr_res["dt_polys"] = get_ori_coordinate_for_table(
                                xmin, ymin, sub_ocr_res["dt_polys"]
                            )
                        layout_label = sub["label"].lower()
                        if sub_ocr_res and sub["label"].lower() in [
                            "image",
                            "figure",
                            "img",
                            "fig",
                        ]:
                            mask_flag = False
                        else:
                            ocr_res_with_layout.append(sub_ocr_res)
                            structure_res.append(
                                {
                                    "layout_bbox": box,
                                    f"{layout_label}": "\n".join(
                                        sub_ocr_res["rec_text"]
                                    ),
                                }
                            )
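                    # mask out regions that were already handled above so the
                    # later full-image OCR pass does not read them twice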
                    if mask_flag:
                        single_img[ymin:ymax, xmin:xmax, :] = 255
            curve_pipeline = self.ocr_pipeline
            if self.curve_pipeline and kwargs.get("use_seal_text_det_model", True):
                curve_pipeline = self.curve_pipeline
            all_curve_res = get_ocr_res(curve_pipeline, curve_subs)
            single_img_res["curve_result"] = all_curve_res
            if isinstance(all_curve_res, dict):
                all_curve_res = [all_curve_res]
            for sub, curve_res in zip(curve_subs, all_curve_res):
                structure_res.append(
                    {
                        "layout_bbox": sub["box"],
                        "印章": "".join(curve_res["rec_text"]),
                    }
                )
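            # OCR the masked page image: only text outside the table/seal and
            # already-collected layout regions is left at this point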
            ocr_res = get_ocr_res(self.ocr_pipeline, single_img)
            ocr_res["input_path"] = layout_pred["input_path"]
            all_table_res, _ = self.get_table_result(table_subs)
            for poly_idx, single_dt_poly in enumerate(ocr_res["dt_polys"]):
                structure_res.append(
                    {
                        "layout_bbox": convert_4point2rect(single_dt_poly),
                        "words in text block": ocr_res["rec_text"][poly_idx],
                    }
                )
            # merge the layout-level OCR results back into the page-level result
            for layout_ocr_res in ocr_res_with_layout:
                ocr_res["dt_polys"].extend(layout_ocr_res["dt_polys"])
                ocr_res["rec_text"].extend(layout_ocr_res["rec_text"])
            ocr_res["input_path"] = single_img_res["input_path"]
            all_table_ocr_res = []
            # get table text from html
            structure_res_table, all_table_ocr_res = get_table_text_from_html(
                all_table_res
            )
            structure_res.extend(structure_res_table)
            # sort the layout result by the top-left point of each box
            structure_res = sorted_layout_boxes(structure_res, w=single_img.shape[1])
            structure_res = [LayoutStructureResult(item) for item in structure_res]
            single_img_res["table_result"] = all_table_res
            single_img_res["ocr_result"] = ocr_res
            single_img_res["table_ocr_result"] = all_table_ocr_res
            single_img_res["structure_result"] = structure_res
            yield VisualResult(single_img_res)

    def decode_visual_result(self, visual_result):
        ocr_text = []
        table_text_list = []
        table_html = []
        for single_img_pred in visual_result:
            layout_res = single_img_pred["structure_result"]
            layout_res_copy = deepcopy(layout_res)
            # layout_res is a list like:
            # [{"layout_bbox": [x1, y1, x2, y2], "layout": "single", "words in text block": "xxx"},
            #  {"layout_bbox": [x1, y1, x2, y2], "layout": "double", "印章": "xxx"}]
            ocr_res = {}
            for block in layout_res_copy:
                block.pop("layout_bbox")
                block.pop("layout")
                for layout_type, text in block.items():
                    if text == "":
                        continue
                    # table results are used separately
                    if layout_type == "table":
                        continue
                    if layout_type not in ocr_res:
                        ocr_res[layout_type] = text
                    else:
                        ocr_res[layout_type] += f"\n {text}"
            single_table_text = " ".join(single_img_pred["table_ocr_result"])
            for table_pred in single_img_pred["table_result"]:
                html = table_pred["html"]
                table_html.append(html)
            if ocr_res:
                ocr_text.append(ocr_res)
            table_text_list.append(single_table_text)
        return ocr_text, table_text_list, table_html

    def get_vector_text(
        self,
        llm_name=None,
        llm_params={},
        visual_info=None,
        min_characters=3500,
        llm_request_interval=1.0,
    ):
        """get vector for the ocr text"""
        if isinstance(self.llm_api, ErnieBot):
            get_vector_flag = True
        else:
            logging.warning(
                "The LLM is not ErnieBot, so vector text will not be generated."
            )
            get_vector_flag = False
        if not any([visual_info, self.visual_info]):
            return VectorResult({"vector": None})
        if visual_info:
            # used for serving or local calls that pass visual info explicitly
            _visual_info = visual_info
        else:
            # used for local calls that rely on the cached visual info
            _visual_info = self.visual_info
        ocr_text = _visual_info["ocr_text"]
        html_list = _visual_info["table_html"]
        table_text_list = _visual_info["table_text"]
        # add the text of very long tables to the ocr text
        for html, table_text_rec in zip(html_list, table_text_list):
            if len(html) > 3000:
                ocr_text.append({"table": table_text_rec})
        ocr_all_result = "".join(["\n".join(e.values()) for e in ocr_text])
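        # only build a vector store for long documents; short documents are
        # passed to the LLM verbatim as stringified OCR text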
        if len(ocr_all_result) > min_characters and get_vector_flag:
            if visual_info and llm_name:
                # for serving or local calls with an explicit LLM
                llm_api = create_llm_api(llm_name, llm_params)
                text_result = llm_api.get_vector(ocr_text, llm_request_interval)
            else:
                # for local calls
                text_result = self.llm_api.get_vector(ocr_text, llm_request_interval)
        else:
            text_result = str(ocr_text)
        self.visual_flag = False
        return VectorResult({"vector": text_result})

    def get_retrieval_text(
        self,
        key_list,
        visual_info=None,
        vector=None,
        llm_name=None,
        llm_params={},
        llm_request_interval=0.1,
    ):
        if not any([visual_info, vector, self.visual_info, self.vector]):
            return RetrievalResult({"retrieval": None})
        key_list = format_key(key_list)
        is_serving = visual_info and llm_name
        if self.visual_flag and not is_serving:
            self.vector = self.get_vector_text()
        if not any([vector, self.vector]):
            logging.warning(
                "The vector store has not been created yet; creating it automatically."
            )
            if is_serving:
                # for serving
                vector = self.get_vector_text(
                    llm_name=llm_name, llm_params=llm_params, visual_info=visual_info
                )
            else:
                self.vector = self.get_vector_text()
        if vector and llm_name:
            _vector = vector["vector"]
            llm_api = create_llm_api(llm_name, llm_params)
            retrieval = llm_api.caculate_similar(
                vector=_vector,
                key_list=key_list,
                llm_params=llm_params,
                sleep_time=llm_request_interval,
            )
        else:
            _vector = self.vector["vector"]
            retrieval = self.llm_api.caculate_similar(
                vector=_vector, key_list=key_list, sleep_time=llm_request_interval
            )
        return RetrievalResult({"retrieval": retrieval})

    def chat(
        self,
        key_list,
        vector=None,
        visual_info=None,
        retrieval_result=None,
        user_task_description="",
        rules="",
        few_shot="",
        use_vector=True,
        save_prompt=False,
        llm_name="ernie-3.5",
        llm_params={},
    ):
        """chat with the LLM to extract the values of the given keys"""
        if not any(
            [vector, visual_info, retrieval_result, self.visual_info, self.vector]
        ):
            return ChatResult(
                {"chat_res": "请先完成图像解析再开始对话", "prompt": ""}
            )
        key_list = format_key(key_list)
        # query the table html first, then the table text, and finally all ocr text
        if visual_info:
            # used for serving or local calls that pass visual info explicitly
            _visual_info = visual_info
        else:
            # used for local calls that rely on the cached visual info
            _visual_info = self.visual_info
        ocr_text = _visual_info["ocr_text"]
        html_list = _visual_info["table_html"]
        table_text_list = _visual_info["table_text"]
        prompt_res = {"ocr_prompt": "str", "table_prompt": [], "html_prompt": []}
        final_results = {}
        failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
        if html_list:
            prompt_list = self.get_prompt_for_table(
                html_list, key_list, rules, few_shot
            )
            prompt_res["html_prompt"] = prompt_list
            for prompt, table_text in zip(prompt_list, table_text_list):
                logging.debug(prompt)
                res = self.get_llm_result(prompt)
                # TODO: why use one html but the whole table_text in the next step
                # get_llm_result may return None when the LLM call fails
                if not res or list(res.values())[0] in failed_results:
                    logging.info(
                        "table html sequence is too long, falling back to ocr text"
                    )
                    prompt = self.get_prompt_for_ocr(
                        table_text, key_list, rules, few_shot, user_task_description
                    )
                    logging.debug(prompt)
                    prompt_res["table_prompt"].append(prompt)
                    res = self.get_llm_result(prompt)
                for key, value in (res or {}).items():
                    if value not in failed_results and key in key_list:
                        key_list.remove(key)
                        final_results[key] = value
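        # any keys still unresolved after the table pass are answered from the
        # (retrieved) ocr text below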
        if len(key_list) > 0:
            logging.info("get result from ocr")
            if retrieval_result:
                ocr_text = retrieval_result.get("retrieval")
            elif use_vector and any([visual_info, vector]):
                # for serving or local calls with an explicit LLM
                ocr_text = self.get_retrieval_text(
                    key_list=key_list,
                    visual_info=visual_info,
                    vector=vector,
                    llm_name=llm_name,
                    llm_params=llm_params,
                )["retrieval"]
            else:
                # for local calls
                ocr_text = self.get_retrieval_text(key_list=key_list)["retrieval"]
            prompt = self.get_prompt_for_ocr(
                ocr_text,
                key_list,
                rules,
                few_shot,
                user_task_description,
            )
            logging.debug(prompt)
            prompt_res["ocr_prompt"] = prompt
            res = self.get_llm_result(prompt)
            if res:
                final_results.update(res)
            if not res and not final_results:
                final_results = self.llm_api.ERROR_MASSAGE
        if save_prompt:
            return ChatResult({"chat_res": final_results, "prompt": prompt_res})
        else:
            return ChatResult({"chat_res": final_results, "prompt": ""})

    def get_llm_result(self, prompt):
        """get the llm result and decode it into a dict"""
        llm_result = self.llm_api.pred(prompt)
        # when the llm call fails, return None
        if not llm_result:
            return None
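        # strip the markdown code fences and brackets that LLMs often wrap
        # around JSON output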
        if "json" in llm_result or "```" in llm_result:
            llm_result = (
                llm_result.replace("```", "").replace("json", "").replace("\n", "")
            )
            llm_result = llm_result.replace("[", "").replace("]", "")
        try:
            llm_result = json.loads(llm_result)
            llm_result_final = {}
            for key in llm_result:
                value = llm_result[key]
                if isinstance(value, list):
                    if len(value) > 0:
                        llm_result_final[key] = value[0]
                else:
                    llm_result_final[key] = value
            return llm_result_final
        except Exception:
            # fall back to regex extraction when the output is not valid JSON
            results = (
                llm_result.replace("\n", "")
                .replace(" ", "")
                .replace("{", "")
                .replace("}", "")
            )
            if not results.endswith('"'):
                results = results + '"'
            # the colon separator is matched with optional whitespace, since
            # all spaces were stripped above
            pattern = r'"(.*?)":\s*"([^"]*)"'
            matches = re.findall(pattern, str(results))
            llm_result = {k: v for k, v in matches}
            return llm_result

    def get_prompt_for_table(self, table_result, key_list, rules="", few_shot=""):
        """get prompts for tables"""
        prompt_key_information = []
        merge_table = ""
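        # pack consecutive table htmls into chunks of roughly 2000 characters
        # and emit one prompt per chunk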
        for idx, result in enumerate(table_result):
            if len(merge_table + result) < 2000:
                merge_table += result
            if len(merge_table + result) > 2000 or idx == len(table_result) - 1:
                single_prompt = self.get_kie_prompt(
                    merge_table,
                    key_list,
                    rules_str=rules,
                    few_shot_demo_str=few_shot,
                    prompt_type="table",
                )
                prompt_key_information.append(single_prompt)
                merge_table = ""
        return prompt_key_information

    def get_prompt_for_ocr(
        self,
        ocr_result,
        key_list,
        rules="",
        few_shot="",
        user_task_description="",
    ):
        """get the prompt for ocr text"""
        prompt_key_information = self.get_kie_prompt(
            ocr_result, key_list, user_task_description, rules, few_shot
        )
        return prompt_key_information

    def get_kie_prompt(
        self,
        text_result,
        key_list,
        user_task_description="",
        rules_str="",
        few_shot_demo_str="",
        prompt_type="common",
    ):
        """build the key information extraction (kie) prompt"""
        if prompt_type == "table":
            task_description = self.task_prompt_dict["kie_table_prompt"][
                "task_description"
            ]
        else:
            task_description = self.task_prompt_dict["kie_common_prompt"][
                "task_description"
            ]
        output_format = self.task_prompt_dict["kie_common_prompt"]["output_format"]
        if len(user_task_description) > 0:
            task_description = user_task_description
        task_description = task_description + output_format
        few_shot_demo_key_value = ""
        if self.user_prompt_dict:
            logging.info("======= using custom user prompts ========")
            task_description = self.user_prompt_dict["task_description"]
            rules_str = self.user_prompt_dict["rules_str"]
            few_shot_demo_str = self.user_prompt_dict["few_shot_demo_str"]
            few_shot_demo_key_value = self.user_prompt_dict["few_shot_demo_key_value"]
        prompt = f"""{task_description}{rules_str}{few_shot_demo_str}{few_shot_demo_key_value}"""
        if prompt_type == "table":
            prompt += f"""\n结合上面,下面正式开始:\
表格内容:```{text_result}```\
关键词列表:[{key_list}]。""".replace(
                " ", ""
            )
        else:
            prompt += f"""\n结合上面的例子,下面正式开始:\
OCR文字:```{text_result}```\
关键词列表:[{key_list}]。""".replace(
                " ", ""
            )
        return prompt
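

# ---------------------------------------------------------------------------
# A minimal usage sketch (commented out; assumes locally available models and
# valid LLM credentials -- the model names and llm_params keys below are
# hypothetical placeholders, not values this module prescribes):
#
#     pipeline = PPChatOCRPipeline(
#         layout_model="<layout-model-name>",
#         text_det_model="<text-det-model-name>",
#         text_rec_model="<text-rec-model-name>",
#         table_model="<table-model-name>",
#         llm_name="ernie-3.5",
#         llm_params={"<credential-key>": "<credential-value>"},
#     )
#     visual_result, visual_info = pipeline.predict("document.pdf")
#     chat_res = pipeline.chat(key_list=["<key-to-extract>"])
#     print(chat_res["chat_res"])
# ---------------------------------------------------------------------------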