# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import json

import numpy as np

from .utils import *
from copy import deepcopy
from ...components import *
from ..ocr import OCRPipeline
from ....utils import logging
from ...results import *
from ...utils.io import ImageReader, PDFReader
from ..table_recognition import TableRecPipeline
from ...components.llm import create_llm_api, ErnieBot
from ....utils.file_interface import read_yaml_file
from ..table_recognition.utils import convert_4point2rect, get_ori_coordinate_for_table

PROMPT_FILE = os.path.join(os.path.dirname(__file__), "ch_prompt.yaml")


class PPChatOCRPipeline(TableRecPipeline):
    """PP-ChatOCRv3 Pipeline"""

    entities = "PP-ChatOCRv3-doc"

    def __init__(
        self,
        layout_model,
        text_det_model,
        text_rec_model,
        table_model,
        oricls_model=None,
        uvdoc_model=None,
        curve_model=None,
        llm_name="ernie-3.5",
        llm_params={},
        task_prompt_yaml=None,
        user_prompt_yaml=None,
        layout_batch_size=1,
        text_det_batch_size=1,
        text_rec_batch_size=1,
        table_batch_size=1,
        uvdoc_batch_size=1,
        curve_batch_size=1,
        oricls_batch_size=1,
        recovery=True,
        device="gpu",
        predictor_kwargs=None,
    ):
        self.layout_model = layout_model
        self.text_det_model = text_det_model
        self.text_rec_model = text_rec_model
        self.table_model = table_model
        self.oricls_model = oricls_model
        self.uvdoc_model = uvdoc_model
        self.curve_model = curve_model
        self.llm_name = llm_name
        self.llm_params = llm_params
        self.task_prompt_yaml = task_prompt_yaml
        self.user_prompt_yaml = user_prompt_yaml
        self.layout_batch_size = layout_batch_size
        self.text_det_batch_size = text_det_batch_size
        self.text_rec_batch_size = text_rec_batch_size
        self.table_batch_size = table_batch_size
        self.uvdoc_batch_size = uvdoc_batch_size
        self.curve_batch_size = curve_batch_size
        self.oricls_batch_size = oricls_batch_size
        self.recovery = recovery
        self.device = device
        self.predictor_kwargs = predictor_kwargs
        super().__init__(
            layout_model=layout_model,
            text_det_model=text_det_model,
            text_rec_model=text_rec_model,
            table_model=table_model,
            layout_batch_size=layout_batch_size,
            text_det_batch_size=text_det_batch_size,
            text_rec_batch_size=text_rec_batch_size,
            table_batch_size=table_batch_size,
            predictor_kwargs=predictor_kwargs,
        )
        self._build_predictor()
        self.llm_api = create_llm_api(
            llm_name,
            llm_params,
        )
        self.cropper = CropByBoxes()
        # get base prompt from yaml info
        if task_prompt_yaml:
            self.task_prompt_dict = read_yaml_file(task_prompt_yaml)
        else:
            self.task_prompt_dict = read_yaml_file(PROMPT_FILE)
        # get user prompt from yaml info
        if user_prompt_yaml:
            self.user_prompt_dict = read_yaml_file(user_prompt_yaml)
        else:
            self.user_prompt_dict = None
        self.img_reader = ReadImage()
        self.visual_info = None
        self.vector = None

    def _build_predictor(self):
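        """Build the sub-predictors: the base table-recognition predictors plus
        the optional seal-text OCR pipeline, document-orientation classifier,
        and document-unwarping model."""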
        super()._build_predictor()
        if self.curve_model:
            self.curve_pipeline = OCRPipeline(
                text_det_model=self.curve_model,
                text_rec_model=self.text_rec_model,
                text_det_batch_size=self.text_det_batch_size,
                text_rec_batch_size=self.text_rec_batch_size,
                predictor_kwargs=self.predictor_kwargs,
            )
        else:
            self.curve_pipeline = None
        if self.oricls_model:
            self.oricls_predictor = self._create_model(self.oricls_model)
        else:
            self.oricls_predictor = None
        if self.uvdoc_model:
            self.uvdoc_predictor = self._create_model(self.uvdoc_model)
        else:
            self.uvdoc_predictor = None
        if self.curve_pipeline and self.curve_batch_size:
            self.curve_pipeline.text_det_model.set_predictor(
                batch_size=self.curve_batch_size, device=self.device
            )
        if self.oricls_predictor and self.oricls_batch_size:
            self.oricls_predictor.set_predictor(
                batch_size=self.oricls_batch_size, device=self.device
            )
        if self.uvdoc_predictor and self.uvdoc_batch_size:
            self.uvdoc_predictor.set_predictor(
                batch_size=self.uvdoc_batch_size, device=self.device
            )

    def set_predictor(
        self,
        layout_batch_size=None,
        text_det_batch_size=None,
        text_rec_batch_size=None,
        table_batch_size=None,
        curve_batch_size=None,
        oricls_batch_size=None,
        uvdoc_batch_size=None,
        device=None,
    ):
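        """Update the batch sizes (and optionally the device) of the underlying
        predictors; arguments left as None keep their current settings."""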
        if text_det_batch_size and text_det_batch_size > 1:
            logging.warning(
                f"The text detection model only supports batch_size=1 for now; the setting text_det_batch_size={text_det_batch_size} will be ignored!"
            )
        if layout_batch_size:
            self.layout_predictor.set_predictor(batch_size=layout_batch_size)
        if text_rec_batch_size:
            self.ocr_pipeline.text_rec_model.set_predictor(
                batch_size=text_rec_batch_size
            )
        if table_batch_size:
            self.table_predictor.set_predictor(batch_size=table_batch_size)
        if self.curve_pipeline and curve_batch_size:
            self.curve_pipeline.text_det_model.set_predictor(
                batch_size=curve_batch_size, device=device
            )
        if self.oricls_predictor and oricls_batch_size:
            self.oricls_predictor.set_predictor(
                batch_size=oricls_batch_size, device=device
            )
        if self.uvdoc_predictor and uvdoc_batch_size:
            self.uvdoc_predictor.set_predictor(
                batch_size=uvdoc_batch_size, device=device
            )

    def predict(self, input, **kwargs):
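        """Run the visual stage on the input and return (visual_result,
        visual_info), where visual_info gathers the ocr_text, table_html,
        and table_text consumed by the vector/retrieval/chat stages."""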
        visual_info = {"ocr_text": [], "table_html": [], "table_text": []}
        # get all visual results
        visual_result = list(self.get_visual_result(input, **kwargs))
        # decode visual results to get table_html, table_text, ocr_text
        ocr_text, table_text, table_html = self.decode_visual_result(visual_result)
        visual_info["ocr_text"] = ocr_text
        visual_info["table_html"] = table_html
        visual_info["table_text"] = table_text
        visual_info = VisualInfoResult(visual_info)
        # for local use, save the visual info in self
        self.visual_info = visual_info
        return visual_result, visual_info

    def get_visual_result(self, inputs, **kwargs):
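        """Yield a VisualResult per input image, combining layout detection,
        OCR, table recognition, and the optional seal-text, orientation,
        and unwarping models."""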
        layout_batch_size = kwargs.get("layout_batch_size")
        text_det_batch_size = kwargs.get("text_det_batch_size")
        text_rec_batch_size = kwargs.get("text_rec_batch_size")
        table_batch_size = kwargs.get("table_batch_size")
        curve_batch_size = kwargs.get("curve_batch_size")
        oricls_batch_size = kwargs.get("oricls_batch_size")
        uvdoc_batch_size = kwargs.get("uvdoc_batch_size")
        device = kwargs.get("device")
        self.set_predictor(
            layout_batch_size,
            text_det_batch_size,
            text_rec_batch_size,
            table_batch_size,
            curve_batch_size,
            oricls_batch_size,
            uvdoc_batch_size,
            device,
        )
        # get oricls and uvdoc results
        img_info_list = list(self.img_reader(inputs))[0]
        oricls_results = []
        if self.oricls_predictor and kwargs.get("use_oricls_model", True):
            oricls_results = get_oriclas_results(img_info_list, self.oricls_predictor)
        uvdoc_results = []
        if self.uvdoc_predictor and kwargs.get("use_uvdoc_model", True):
            uvdoc_results = get_uvdoc_results(img_info_list, self.uvdoc_predictor)
        img_list = [img_info["img"] for img_info in img_info_list]
        for idx, (img_info, layout_pred) in enumerate(
            zip(img_info_list, self.layout_predictor(img_list))
        ):
            single_img_res = {
                "input_path": "",
                "layout_result": {},
                "ocr_result": {},
                "table_ocr_result": [],
                "table_result": [],
                "structure_result": [],
                "oricls_result": {},
                "uvdoc_result": {},
                "curve_result": [],
            }
            # update oricls and uvdoc results
            if oricls_results:
                single_img_res["oricls_result"] = oricls_results[idx]
            if uvdoc_results:
                single_img_res["uvdoc_result"] = uvdoc_results[idx]
            # update layout result
            single_img_res["input_path"] = layout_pred["input_path"]
            single_img_res["layout_result"] = layout_pred
            single_img = img_info["img"]
            table_subs = []
            curve_subs = []
            structure_res = []
            ocr_res_with_layout = []
            if len(layout_pred["boxes"]) > 0:
                subs_of_img = list(self._crop_by_boxes(layout_pred))
                # handle each cropped region according to its layout label
                for sub in subs_of_img:
                    box = sub["box"]
                    xmin, ymin, xmax, ymax = [int(i) for i in box]
                    mask_flag = True
                    if sub["label"].lower() == "table":
                        table_subs.append(sub)
                    elif sub["label"].lower() == "seal":
                        curve_subs.append(sub)
                    else:
                        if self.recovery and kwargs.get("recovery", True):
                            # TODO: Why use the entire image?
                            wht_im = (
                                np.ones(single_img.shape, dtype=single_img.dtype) * 255
                            )
                            wht_im[ymin:ymax, xmin:xmax, :] = sub["img"]
                            sub_ocr_res = get_ocr_res(self.ocr_pipeline, wht_im)
                        else:
                            sub_ocr_res = get_ocr_res(self.ocr_pipeline, sub)
                            sub_ocr_res["dt_polys"] = get_ori_coordinate_for_table(
                                xmin, ymin, sub_ocr_res["dt_polys"]
                            )
                        layout_label = sub["label"].lower()
                        if sub_ocr_res and layout_label in [
                            "image",
                            "figure",
                            "img",
                            "fig",
                        ]:
                            mask_flag = False
                        else:
                            ocr_res_with_layout.append(sub_ocr_res)
                            structure_res.append(
                                {
                                    "layout_bbox": box,
                                    layout_label: "\n".join(sub_ocr_res["rec_text"]),
                                }
                            )
                    if mask_flag:
                        single_img[ymin:ymax, xmin:xmax, :] = 255
            curve_pipeline = self.ocr_pipeline
            if self.curve_pipeline and kwargs.get("use_curve_model", True):
                curve_pipeline = self.curve_pipeline
            all_curve_res = get_ocr_res(curve_pipeline, curve_subs)
            single_img_res["curve_result"] = all_curve_res
            if isinstance(all_curve_res, dict):
                all_curve_res = [all_curve_res]
            for sub, curve_res in zip(curve_subs, all_curve_res):
                structure_res.append(
                    {
                        "layout_bbox": sub["box"],
                        "印章": "".join(curve_res["rec_text"]),  # "印章" = seal
                    }
                )
            ocr_res = get_ocr_res(self.ocr_pipeline, single_img)
            ocr_res["input_path"] = layout_pred["input_path"]
            all_table_res, _ = self.get_table_result(table_subs)
            for poly_idx, single_dt_poly in enumerate(ocr_res["dt_polys"]):
                structure_res.append(
                    {
                        "layout_bbox": convert_4point2rect(single_dt_poly),
                        "words in text block": ocr_res["rec_text"][poly_idx],
                    }
                )
            # merge the layout-level ocr results into the full ocr result
            for layout_ocr_res in ocr_res_with_layout:
                ocr_res["dt_polys"].extend(layout_ocr_res["dt_polys"])
                ocr_res["rec_text"].extend(layout_ocr_res["rec_text"])
            ocr_res["input_path"] = single_img_res["input_path"]
            # get table text from html
            structure_res_table, all_table_ocr_res = get_table_text_from_html(
                all_table_res
            )
            structure_res.extend(structure_res_table)
            # sort the layout result by the left-top point of each box
            structure_res = sorted_layout_boxes(structure_res, w=single_img.shape[1])
            structure_res = [LayoutStructureResult(item) for item in structure_res]
            single_img_res["table_result"] = all_table_res
            single_img_res["ocr_result"] = ocr_res
            single_img_res["table_ocr_result"] = all_table_ocr_res
            single_img_res["structure_result"] = structure_res
            yield VisualResult(single_img_res)

    def decode_visual_result(self, visual_result):
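        """Decode the per-image structure results into plain OCR text,
        table text, and table HTML lists."""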
        ocr_text = []
        table_text_list = []
        table_html = []
        for single_img_pred in visual_result:
            layout_res = single_img_pred["structure_result"]
            layout_res_copy = deepcopy(layout_res)
            # layout_res is a list like:
            # [{"layout_bbox": [x1, y1, x2, y2], "layout": "single", "words in text block": "xxx"},
            #  {"layout_bbox": [x1, y1, x2, y2], "layout": "double", "印章": "xxx"}]
            ocr_res = {}
            for block in layout_res_copy:
                block.pop("layout_bbox")
                block.pop("layout")
                for layout_type, text in block.items():
                    if text == "":
                        continue
                    # table results are used separately
                    if layout_type == "table":
                        continue
                    if layout_type not in ocr_res:
                        ocr_res[layout_type] = text
                    else:
                        ocr_res[layout_type] += f"\n {text}"
            single_table_text = " ".join(single_img_pred["table_ocr_result"])
            for table_pred in single_img_pred["table_result"]:
                html = table_pred["html"]
                table_html.append(html)
            if ocr_res:
                ocr_text.append(ocr_res)
            table_text_list.append(single_table_text)
        return ocr_text, table_text_list, table_html

    def get_vector_text(
        self,
        llm_name=None,
        llm_params={},
        visual_info=None,
        min_characters=0,
        llm_request_interval=1.0,
    ):
  372. """get vector for ocr"""
  373. if isinstance(self.llm_api, ErnieBot):
  374. get_vector_flag = True
  375. else:
  376. logging.warning("Do not use ErnieBot, will not get vector text.")
  377. get_vector_flag = False
  378. if not any([visual_info, self.visual_info]):
  379. return VectorResult({"vector": None})
  380. if visual_info:
  381. # use for serving or local
  382. _visual_info = visual_info
  383. else:
  384. # use for local
  385. _visual_info = self.visual_info
  386. ocr_text = _visual_info["ocr_text"]
  387. html_list = _visual_info["table_html"]
  388. table_text_list = _visual_info["table_text"]
  389. # add table text to ocr text
  390. for html, table_text_rec in zip(html_list, table_text_list):
  391. if len(html) > 3000:
  392. ocr_text.append({"table": table_text_rec})
  393. ocr_all_result = "".join(["\n".join(e.values()) for e in ocr_text])
  394. if len(ocr_all_result) > min_characters and get_vector_flag:
  395. if visual_info and llm_name:
  396. # for serving or local
  397. llm_api = create_llm_api(llm_name, llm_params)
  398. text_result = llm_api.get_vector(ocr_text, llm_request_interval)
  399. else:
  400. # for local
  401. text_result = self.llm_api.get_vector(ocr_text, llm_request_interval)
  402. else:
  403. text_result = str(ocr_text)
  404. return VectorResult({"vector": text_result})

    def get_retrieval_text(
        self,
        key_list,
        visual_info=None,
        vector=None,
        llm_name=None,
        llm_params={},
        llm_request_interval=0.1,
    ):
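        """Retrieve the text most relevant to key_list from the vector
        library, creating the library first if it does not exist yet."""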
        if not any([visual_info, vector, self.visual_info, self.vector]):
            return RetrievalResult({"retrieval": None})
        key_list = format_key(key_list)
        if not any([vector, self.vector]):
            logging.warning(
                "The vector library has not been created; creating it automatically."
            )
            if visual_info and llm_name:
                # for serving
                vector = self.get_vector_text(
                    llm_name=llm_name, llm_params=llm_params, visual_info=visual_info
                )
            else:
                self.vector = self.get_vector_text()
        if vector and llm_name:
            _vector = vector["vector"]
            llm_api = create_llm_api(llm_name, llm_params)
            retrieval = llm_api.caculate_similar(
                vector=_vector,
                key_list=key_list,
                llm_params=llm_params,
                sleep_time=llm_request_interval,
            )
        else:
            _vector = self.vector["vector"]
            retrieval = self.llm_api.caculate_similar(
                vector=_vector, key_list=key_list, sleep_time=llm_request_interval
            )
        return RetrievalResult({"retrieval": retrieval})

    def chat(
        self,
        key_list,
        vector=None,
        visual_info=None,
        retrieval_result=None,
        user_task_description="",
        rules="",
        few_shot="",
        use_vector=True,
        save_prompt=False,
        llm_name="ernie-3.5",
        llm_params={},
    ):
  457. """
  458. chat with key
  459. """
        if not any(
            [vector, visual_info, retrieval_result, self.visual_info, self.vector]
        ):
            return ChatResult(
                # "Please finish image parsing before starting the chat."
                {"chat_res": "请先完成图像解析再开始对话", "prompt": ""}
            )
        key_list = format_key(key_list)
        # try the table html first, then the table's ocr text, and finally the full ocr text
        if visual_info:
            # used for serving or local
            _visual_info = visual_info
        else:
            # used for local
            _visual_info = self.visual_info
        ocr_text = _visual_info["ocr_text"]
        html_list = _visual_info["table_html"]
        table_text_list = _visual_info["table_text"]
        prompt_res = {"ocr_prompt": "str", "table_prompt": [], "html_prompt": []}
        final_results = {}
        # failure markers: "LLM call failed", "unknown", "key information not found"
        failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
        if html_list:
            prompt_list = self.get_prompt_for_table(
                html_list, key_list, rules, few_shot
            )
            prompt_res["html_prompt"] = prompt_list
            for prompt, table_text in zip(prompt_list, table_text_list):
                logging.debug(prompt)
                res = self.get_llm_result(prompt)
                # TODO: why use one html but the whole table_text in the next step?
                if not res or list(res.values())[0] in failed_results:
                    logging.info(
                        "The table html is too long; using the table's ocr text instead."
                    )
                    prompt = self.get_prompt_for_ocr(
                        table_text, key_list, rules, few_shot, user_task_description
                    )
                    logging.debug(prompt)
                    prompt_res["table_prompt"].append(prompt)
                    res = self.get_llm_result(prompt)
                for key, value in (res or {}).items():
                    if value not in failed_results and key in key_list:
                        key_list.remove(key)
                        final_results[key] = value
        if len(key_list) > 0:
            logging.info("get result from ocr")
            if retrieval_result:
                ocr_text = retrieval_result
            elif use_vector and any([visual_info, vector]):
                # for serving or local
                ocr_text = self.get_retrieval_text(
                    key_list=key_list,
                    visual_info=visual_info,
                    vector=vector,
                    llm_name=llm_name,
                    llm_params=llm_params,
                )
            else:
                # for local
                ocr_text = self.get_retrieval_text(key_list=key_list)
            prompt = self.get_prompt_for_ocr(
                ocr_text,
                key_list,
                rules,
                few_shot,
                user_task_description,
            )
            logging.debug(prompt)
            prompt_res["ocr_prompt"] = prompt
            res = self.get_llm_result(prompt)
            if res:
                final_results.update(res)
        if not res and not final_results:
            final_results = self.llm_api.ERROR_MASSAGE
        if save_prompt:
            return ChatResult({"chat_res": final_results, "prompt": prompt_res})
        else:
            return ChatResult({"chat_res": final_results, "prompt": ""})

    def get_llm_result(self, prompt):
        """get llm result and decode it to a dict"""
        llm_result = self.llm_api.pred(prompt)
        # when the llm pred failed, return None
        if not llm_result:
            return None
        if "json" in llm_result or "```" in llm_result:
            # strip markdown code-fence wrappers around the json payload
            llm_result = (
                llm_result.replace("```", "").replace("json", "").replace("\n", "")
            )
            llm_result = llm_result.replace("[", "").replace("]", "")
        try:
            llm_result = json.loads(llm_result)
            llm_result_final = {}
            for key in llm_result:
                value = llm_result[key]
                if isinstance(value, list):
                    if len(value) > 0:
                        llm_result_final[key] = value[0]
                else:
                    llm_result_final[key] = value
            return llm_result_final
        except Exception:
            # fall back to extracting "key": "value" pairs with a regex
            results = (
                llm_result.replace("\n", "")
                .replace(" ", "")
                .replace("{", "")
                .replace("}", "")
            )
            if not results.endswith('"'):
                results = results + '"'
            pattern = r'"(.*?)":\s*"([^"]*)"'
            matches = re.findall(pattern, str(results))
            llm_result = {k: v for k, v in matches}
            return llm_result

    def get_prompt_for_table(self, table_result, key_list, rules="", few_shot=""):
        """get prompt for table"""
        prompt_key_information = []
        merge_table = ""
        for idx, result in enumerate(table_result):
            if len(merge_table + result) < 2000:
                merge_table += result
            if len(merge_table + result) > 2000 or idx == len(table_result) - 1:
                single_prompt = self.get_kie_prompt(
                    merge_table,
                    key_list,
                    rules_str=rules,
                    few_shot_demo_str=few_shot,
                    prompt_type="table",
                )
                prompt_key_information.append(single_prompt)
                merge_table = ""
        return prompt_key_information

    def get_prompt_for_ocr(
        self,
        ocr_result,
        key_list,
        rules="",
        few_shot="",
        user_task_description="",
    ):
        """get prompt for ocr"""
        prompt_key_information = self.get_kie_prompt(
            ocr_result, key_list, user_task_description, rules, few_shot
        )
        return prompt_key_information

    def get_kie_prompt(
        self,
        text_result,
        key_list,
        user_task_description="",
        rules_str="",
        few_shot_demo_str="",
        prompt_type="common",
    ):
        """get kie prompt"""
        if prompt_type == "table":
            task_description = self.task_prompt_dict["kie_table_prompt"][
                "task_description"
            ]
        else:
            task_description = self.task_prompt_dict["kie_common_prompt"][
                "task_description"
            ]
        output_format = self.task_prompt_dict["kie_common_prompt"]["output_format"]
        if len(user_task_description) > 0:
            task_description = user_task_description
        task_description = task_description + output_format
        few_shot_demo_key_value = ""
        if self.user_prompt_dict:
            logging.info("======= using custom user prompt ========")
            task_description = self.user_prompt_dict["task_description"]
            rules_str = self.user_prompt_dict["rules_str"]
            few_shot_demo_str = self.user_prompt_dict["few_shot_demo_str"]
            few_shot_demo_key_value = self.user_prompt_dict["few_shot_demo_key_value"]
        prompt = f"""{task_description}{rules_str}{few_shot_demo_str}{few_shot_demo_key_value}"""
        if prompt_type == "table":
            # "Following the above, now begin. Table content: ... Key list: ..."
            prompt += f"""\n结合上面,下面正式开始:\
表格内容:```{text_result}```\
关键词列表:[{key_list}]。""".replace(" ", "")
        else:
            # "Following the examples above, now begin. OCR text: ... Key list: ..."
            prompt += f"""\n结合上面的例子,下面正式开始:\
OCR文字:```{text_result}```\
关键词列表:[{key_list}]。""".replace(" ", "")
        return prompt
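

# A minimal usage sketch, not part of the pipeline itself: the model names and
# llm_params shown below are illustrative assumptions; substitute the models,
# credentials, and keys appropriate for your own deployment.
if __name__ == "__main__":
    pipeline = PPChatOCRPipeline(
        layout_model="PicoDet_layout_1x",  # assumed layout-detection model name
        text_det_model="PP-OCRv4_mobile_det",  # assumed text-detection model name
        text_rec_model="PP-OCRv4_mobile_rec",  # assumed text-recognition model name
        table_model="SLANet",  # assumed table-recognition model name
        llm_name="ernie-3.5",
        llm_params={"api_type": "qianfan", "ak": "YOUR_AK", "sk": "YOUR_SK"},  # assumed credential format
    )
    # visual stage: layout detection + OCR + table recognition over the input image
    visual_result, visual_info = pipeline.predict("example.png")
    # chat stage: ask the LLM to extract the listed keys ("发票号码" = invoice number)
    chat_result = pipeline.chat(key_list=["发票号码"])
    print(chat_result)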