ppchatocrv3.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import json
import numpy as np
from copy import deepcopy
from pathlib import Path  # used for the per-page PDF paths in get_visual_result

from .utils import *
from ...components import *
from ..ocr import OCRPipeline
from ....utils import logging
from ...results import (
    TableResult,
    LayoutStructureResult,
    VisualInfoResult,
    ChatOCRResult,
)
from ...components.llm import create_llm_api, ErnieBot
from ...utils.io import ImageReader, PDFReader
from ..table_recognition import TableRecPipeline
from ....utils.file_interface import read_yaml_file
from ..table_recognition.utils import convert_4point2rect, get_ori_coordinate_for_table

PROMPT_FILE = os.path.join(os.path.dirname(__file__), "ch_prompt.yaml")
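
# The bundled ch_prompt.yaml supplies the Chinese prompt templates consumed by
# get_kie_prompt below: the "kie_table_prompt" and "kie_common_prompt" task
# descriptions, plus the "kie_common_prompt" output_format.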


class PPChatOCRPipeline(TableRecPipeline):
    """PP-ChatOCRv3 Pipeline"""

    entities = "PP-ChatOCRv3"
    def __init__(
        self,
        layout_model,
        text_det_model,
        text_rec_model,
        table_model,
        oricls_model=None,
        uvdoc_model=None,
        curve_model=None,
        llm_name="ernie-3.5",
        llm_params={},
        task_prompt_yaml=None,
        user_prompt_yaml=None,
        layout_batch_size=1,
        text_det_batch_size=1,
        text_rec_batch_size=1,
        table_batch_size=1,
        uvdoc_batch_size=1,
        curve_batch_size=1,
        oricls_batch_size=1,
        recovery=True,
        device="gpu",
        predictor_kwargs=None,
    ):
        self.layout_model = layout_model
        self.text_det_model = text_det_model
        self.text_rec_model = text_rec_model
        self.table_model = table_model
        self.oricls_model = oricls_model
        self.uvdoc_model = uvdoc_model
        self.curve_model = curve_model
        self.llm_name = llm_name
        self.llm_params = llm_params
        self.task_prompt_yaml = task_prompt_yaml
        self.user_prompt_yaml = user_prompt_yaml
        self.layout_batch_size = layout_batch_size
        self.text_det_batch_size = text_det_batch_size
        self.text_rec_batch_size = text_rec_batch_size
        self.table_batch_size = table_batch_size
        self.uvdoc_batch_size = uvdoc_batch_size
        self.curve_batch_size = curve_batch_size
        self.oricls_batch_size = oricls_batch_size
        self.recovery = recovery
        self.device = device
        self.predictor_kwargs = predictor_kwargs
        super().__init__(
            layout_model=layout_model,
            text_det_model=text_det_model,
            text_rec_model=text_rec_model,
            table_model=table_model,
            layout_batch_size=layout_batch_size,
            text_det_batch_size=text_det_batch_size,
            text_rec_batch_size=text_rec_batch_size,
            table_batch_size=table_batch_size,
            predictor_kwargs=predictor_kwargs,
        )
        self._build_predictor()
        self.llm_api = create_llm_api(
            llm_name,
            llm_params,
        )
        self.cropper = CropByBoxes()
        # get the base prompts from yaml (falls back to the bundled Chinese prompts)
        if task_prompt_yaml:
            self.task_prompt_dict = read_yaml_file(task_prompt_yaml)
        else:
            self.task_prompt_dict = read_yaml_file(PROMPT_FILE)
        # get the optional user prompts from yaml
        if user_prompt_yaml:
            self.user_prompt_dict = read_yaml_file(user_prompt_yaml)
        else:
            self.user_prompt_dict = None
        self.img_reader = ReadImage()
        self.pdf_reader = PDFReader()
        self.visual_info = None
        self.vector = None
        # NOTE: _set_predictor expects the batch sizes in (curve, oricls, uvdoc) order
        self._set_predictor(
            curve_batch_size, oricls_batch_size, uvdoc_batch_size, device=device
        )
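
    # A minimal construction sketch (the model names below are illustrative
    # placeholders, not bundled defaults; any compatible layout / text det /
    # text rec / table models can be used, and the llm_params keys depend on
    # the chosen LLM backend):
    #
    #   pipeline = PPChatOCRPipeline(
    #       layout_model="RT-DETR-H_layout_3cls",
    #       text_det_model="PP-OCRv4_server_det",
    #       text_rec_model="PP-OCRv4_server_rec",
    #       table_model="SLANet_plus",
    #       llm_name="ernie-3.5",
    #       llm_params={"api_type": "qianfan", "ak": "...", "sk": "..."},
    #   )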

    def _build_predictor(self):
        super()._build_predictor()
        if self.curve_model:
            self.curve_pipeline = OCRPipeline(
                text_det_model=self.curve_model,
                text_rec_model=self.text_rec_model,
                text_det_batch_size=self.text_det_batch_size,
                text_rec_batch_size=self.text_rec_batch_size,
                predictor_kwargs=self.predictor_kwargs,
            )
        else:
            self.curve_pipeline = None
        if self.oricls_model:
            self.oricls_predictor = self._create_model(self.oricls_model)
        else:
            self.oricls_predictor = None
        if self.uvdoc_model:
            self.uvdoc_predictor = self._create_model(self.uvdoc_model)
        else:
            self.uvdoc_predictor = None
        if self.curve_pipeline and self.curve_batch_size:
            self.curve_pipeline.text_det_model.set_predictor(
                batch_size=self.curve_batch_size, device=self.device
            )
        if self.oricls_predictor and self.oricls_batch_size:
            self.oricls_predictor.set_predictor(
                batch_size=self.oricls_batch_size, device=self.device
            )
        if self.uvdoc_predictor and self.uvdoc_batch_size:
            self.uvdoc_predictor.set_predictor(
                batch_size=self.uvdoc_batch_size, device=self.device
            )

    def _set_predictor(
        self, curve_batch_size, oricls_batch_size, uvdoc_batch_size, device
    ):
        if self.curve_pipeline and curve_batch_size:
            self.curve_pipeline.text_det_model.set_predictor(
                batch_size=curve_batch_size, device=device
            )
        if self.oricls_predictor and oricls_batch_size:
            self.oricls_predictor.set_predictor(
                batch_size=oricls_batch_size, device=device
            )
        if self.uvdoc_predictor and uvdoc_batch_size:
            self.uvdoc_predictor.set_predictor(
                batch_size=uvdoc_batch_size, device=device
            )

    def predict(self, input, **kwargs):
        visual_info = {"ocr_text": [], "table_html": [], "table_text": []}
        # get all visual results
        visual_result = list(self.get_visual_result(input, **kwargs))
        # decode the visual results into table_html, table_text and ocr_text
        ocr_text, table_text, table_html = self.decode_visual_result(visual_result)
        visual_info["ocr_text"] = ocr_text
        visual_info["table_html"] = table_html
        visual_info["table_text"] = table_text
        visual_info = VisualInfoResult(visual_info)
        # for local use, cache the visual info on the pipeline instance
        self.visual_info = visual_info
        return visual_result, visual_info
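
    # Typical local flow (a sketch; the file path and keys are illustrative):
    #
    #   visual_result, visual_info = pipeline.predict(["./contract.pdf"])
    #   chat_res = pipeline.chat(key_list=["甲方", "乙方"])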

    def get_visual_result(self, inputs, **kwargs):
        curve_batch_size = kwargs.get("curve_batch_size")
        oricls_batch_size = kwargs.get("oricls_batch_size")
        uvdoc_batch_size = kwargs.get("uvdoc_batch_size")
        device = kwargs.get("device")
        self._set_predictor(
            curve_batch_size, oricls_batch_size, uvdoc_batch_size, device
        )
        input_imgs = []
        img_list = []
        for file in inputs:
            if isinstance(file, str) and file.endswith(".pdf"):
                img_list = self.pdf_reader.read(file)
                for page, img in enumerate(img_list):
                    input_imgs.append(
                        {
                            "input_path": f"{Path(file).parent}/{Path(file).stem}_{page}.jpg",
                            "img": img,
                        }
                    )
            else:
                for imgs in self.img_reader(file):
                    input_imgs.extend(imgs)
        # get oricls and uvdoc results
        oricls_results = []
        if self.oricls_predictor and kwargs.get("use_oricls_model", True):
            img_list = [img["img"] for img in input_imgs]
            oricls_results = get_oriclas_results(
                input_imgs, self.oricls_predictor, img_list
            )
        uvdoc_results = []
        if self.uvdoc_predictor and kwargs.get("use_uvdoc_model", True):
            img_list = [img["img"] for img in input_imgs]
            uvdoc_results = get_uvdoc_results(
                input_imgs, self.uvdoc_predictor, img_list
            )
        img_list = [img["img"] for img in input_imgs]
        for idx, (input_img, layout_pred) in enumerate(
            zip(input_imgs, self.layout_predictor(img_list))
        ):
            single_img_res = {
                "input_path": "",
                "layout_result": {},
                "ocr_result": {},
                "table_ocr_result": [],
                "table_result": [],
                "structure_result": [],
                "oricls_result": {},
                "uvdoc_result": {},
                "curve_result": [],
            }
            # update oricls and uvdoc results
            if oricls_results:
                single_img_res["oricls_result"] = oricls_results[idx]
            if uvdoc_results:
                single_img_res["uvdoc_result"] = uvdoc_results[idx]
            # update layout result
            single_img_res["input_path"] = layout_pred["input_path"]
            single_img_res["layout_result"] = layout_pred
            single_img = input_img["img"]
            if len(layout_pred["boxes"]) > 0:
                subs_of_img = list(self._crop_by_boxes(layout_pred))
                # collect cropped regions by label: tables, seals and the rest
                table_subs = []
                curve_subs = []
                structure_res = []
                ocr_res_with_layout = []
                for sub in subs_of_img:
                    box = sub["box"]
                    xmin, ymin, xmax, ymax = [int(i) for i in box]
                    mask_flag = True
                    if sub["label"].lower() == "table":
                        table_subs.append(sub)
                    elif sub["label"].lower() == "seal":
                        curve_subs.append(sub)
                    else:
                        if self.recovery and kwargs.get("recovery", True):
                            # TODO: Why use the entire image?
                            wht_im = (
                                np.ones(single_img.shape, dtype=single_img.dtype) * 255
                            )
                            wht_im[ymin:ymax, xmin:xmax, :] = sub["img"]
                            sub_ocr_res = get_ocr_res(self.ocr_pipeline, wht_im)
                        else:
                            sub_ocr_res = get_ocr_res(self.ocr_pipeline, sub)
                            # map crop coordinates back to the page frame
                            sub_ocr_res["dt_polys"] = get_ori_coordinate_for_table(
                                xmin, ymin, sub_ocr_res["dt_polys"]
                            )
                        layout_label = sub["label"].lower()
                        if sub_ocr_res and sub["label"].lower() in [
                            "image",
                            "figure",
                            "img",
                            "fig",
                        ]:
                            mask_flag = False
                        else:
                            ocr_res_with_layout.append(sub_ocr_res)
                            structure_res.append(
                                {
                                    "layout_bbox": box,
                                    f"{layout_label}": "\n".join(
                                        sub_ocr_res["rec_text"]
                                    ),
                                }
                            )
                    if mask_flag:
                        single_img[ymin:ymax, xmin:xmax, :] = 255
                curve_pipeline = self.ocr_pipeline
                if self.curve_pipeline and kwargs.get("use_curve_model", True):
                    curve_pipeline = self.curve_pipeline
                all_curve_res = get_ocr_res(curve_pipeline, curve_subs)
                single_img_res["curve_result"] = all_curve_res
                for sub, curve_res in zip(curve_subs, all_curve_res):
                    structure_res.append(
                        {
                            "layout_bbox": sub["box"],
                            "印章": "".join(curve_res["rec_text"]),  # 印章 = "seal"
                        }
                    )
                ocr_res = get_ocr_res(self.ocr_pipeline, single_img)
                ocr_res["input_path"] = layout_pred["input_path"]
                all_table_res, _ = self.get_table_result(table_subs)
                for poly_idx, single_dt_poly in enumerate(ocr_res["dt_polys"]):
                    structure_res.append(
                        {
                            "layout_bbox": convert_4point2rect(single_dt_poly),
                            "words in text block": ocr_res["rec_text"][poly_idx],
                        }
                    )
                # update ocr result
                for layout_ocr_res in ocr_res_with_layout:
                    ocr_res["dt_polys"].extend(layout_ocr_res["dt_polys"])
                    ocr_res["rec_text"].extend(layout_ocr_res["rec_text"])
                ocr_res["input_path"] = single_img_res["input_path"]
                all_table_ocr_res = []
                # get table text from html
                structure_res_table, all_table_ocr_res = get_table_text_from_html(
                    all_table_res
                )
                structure_res.extend(structure_res_table)
                # sort the layout result by the left-top point of each box
                structure_res = sorted_layout_boxes(structure_res, w=single_img.shape[1])
                structure_res = [LayoutStructureResult(item) for item in structure_res]
                single_img_res["table_result"] = all_table_res
                single_img_res["ocr_result"] = ocr_res
                single_img_res["table_ocr_result"] = all_table_ocr_res
                single_img_res["structure_result"] = structure_res
            yield ChatOCRResult(single_img_res)
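
    # Note on the method above: table, seal and ordinary text regions are
    # masked to white in the page image once handled, so the final full-page
    # OCR pass only sees leftover text (e.g. inside figures, which are
    # deliberately left unmasked).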

    def decode_visual_result(self, visual_result):
        ocr_text = []
        table_text_list = []
        table_html = []
        for single_img_pred in visual_result:
            layout_res = single_img_pred["structure_result"]
            layout_res_copy = deepcopy(layout_res)
            # layout_res is a list like:
            # [{"layout_bbox": [x1, y1, x2, y2], "layout": "single", "words in text block": "xxx"},
            #  {"layout_bbox": [x1, y1, x2, y2], "layout": "double", "印章": "xxx"}]
            ocr_res = {}
            for block in layout_res_copy:
                block.pop("layout_bbox")
                block.pop("layout")
                for layout_type, text in block.items():
                    if text == "":
                        continue
                    # table results are used separately
                    if layout_type == "table":
                        continue
                    if layout_type not in ocr_res:
                        ocr_res[layout_type] = text
                    else:
                        ocr_res[layout_type] += f"\n {text}"
            single_table_text = " ".join(single_img_pred["table_ocr_result"])
            for table_pred in single_img_pred["table_result"]:
                html = table_pred["html"]
                table_html.append(html)
            if ocr_res:
                ocr_text.append(ocr_res)
            table_text_list.append(single_table_text)
        return ocr_text, table_text_list, table_html
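
    # Return shapes (a sketch): ocr_text holds one dict per page, keyed by
    # layout label; table_text_list holds one space-joined string per page;
    # table_html holds one HTML string per detected table.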

    def get_vector_text(
        self,
        llm_name=None,
        llm_params={},
        visual_info=None,
        min_characters=0,
        llm_request_interval=1.0,
    ):
        """get vector for ocr"""
        if isinstance(self.llm_api, ErnieBot):
            get_vector_flag = True
        else:
            logging.warning("Not using ErnieBot; vector text will not be generated.")
            get_vector_flag = False
        if not any([visual_info, self.visual_info]):
            return {"vector": None}
        if visual_info:
            # use for serving or local
            _visual_info = visual_info
        else:
            # use for local
            _visual_info = self.visual_info
        ocr_text = _visual_info["ocr_text"]
        html_list = _visual_info["table_html"]
        table_text_list = _visual_info["table_text"]
        # add table text to the ocr text when the html is too long to prompt with
        for html, table_text_rec in zip(html_list, table_text_list):
            if len(html) > 3000:
                ocr_text.append({"table": table_text_rec})
        ocr_all_result = "".join(["\n".join(e.values()) for e in ocr_text])
        if len(ocr_all_result) > min_characters and get_vector_flag:
            if visual_info and llm_name:
                # for serving or local
                llm_api = create_llm_api(llm_name, llm_params)
                text_result = llm_api.get_vector(ocr_text, llm_request_interval)
            else:
                # for local
                text_result = self.llm_api.get_vector(ocr_text, llm_request_interval)
        else:
            text_result = str(ocr_text)
        return {"vector": text_result}
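
    # Local usage sketch (after predict(...) has cached visual info on self):
    #
    #   vector = pipeline.get_vector_text()   # -> {"vector": ...}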

    def get_retrieval_text(
        self,
        key_list,
        visual_info=None,
        vector=None,
        llm_name=None,
        llm_params={},
        llm_request_interval=0.1,
    ):
        if not any([visual_info, vector, self.visual_info, self.vector]):
            return {"retrieval": None}
        key_list = format_key(key_list)
        if not any([vector, self.vector]):
            logging.warning(
                "The vector library has not been created; creating it automatically."
            )
            if visual_info and llm_name:
                # for serving
                vector = self.get_vector_text(
                    llm_name=llm_name, llm_params=llm_params, visual_info=visual_info
                )
            else:
                self.vector = self.get_vector_text()
        if vector and llm_name:
            _vector = vector["vector"]
            llm_api = create_llm_api(llm_name, llm_params)
            retrieval = llm_api.caculate_similar(
                vector=_vector,
                key_list=key_list,
                llm_params=llm_params,
                sleep_time=llm_request_interval,
            )
        else:
            _vector = self.vector["vector"]
            retrieval = self.llm_api.caculate_similar(
                vector=_vector, key_list=key_list, sleep_time=llm_request_interval
            )
        return {"retrieval": retrieval}
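
    # Serving usage sketch (stateless: pass the vector and the LLM config on
    # every call; the key is illustrative):
    #
    #   res = pipeline.get_retrieval_text(
    #       ["甲方"], vector=vector, llm_name="ernie-3.5", llm_params={...}
    #   )   # -> {"retrieval": ...}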

    def chat(
        self,
        key_list,
        vector=None,
        visual_info=None,
        retrieval_result=None,
        user_task_description="",
        rules="",
        few_shot="",
        use_vector=True,
        save_prompt=False,
        llm_name="ernie-3.5",
        llm_params={},
    ):
        """
        chat with key
        """
        if not any(
            [vector, visual_info, retrieval_result, self.visual_info, self.vector]
        ):
            # "Please finish parsing the image before starting the chat."
            return {"chat_res": "请先完成图像解析再开始对话", "prompt": ""}
        key_list = format_key(key_list)
        # extraction order: table html first, then table text, finally all ocr text
        if visual_info:
            # use for serving or local
            _visual_info = visual_info
        else:
            # use for local
            _visual_info = self.visual_info
        ocr_text = _visual_info["ocr_text"]
        html_list = _visual_info["table_html"]
        table_text_list = _visual_info["table_text"]
        if retrieval_result:
            ocr_text = retrieval_result
        elif use_vector and any([visual_info, vector]):
            # for serving or local
            ocr_text = self.get_retrieval_text(
                key_list=key_list,
                visual_info=visual_info,
                vector=vector,
                llm_name=llm_name,
                llm_params=llm_params,
            )
        else:
            # for local
            ocr_text = self.get_retrieval_text(key_list=key_list)
        prompt_res = {"ocr_prompt": "", "table_prompt": [], "html_prompt": []}
        final_results = {}
        # answers treated as failures: "LLM call failed", "unknown",
        # "key information not found", "None", ""
        failed_results = ["大模型调用失败", "未知", "未找到关键信息", "None", ""]
        res = None
        if html_list:
            prompt_list = self.get_prompt_for_table(
                html_list, key_list, rules, few_shot
            )
            prompt_res["html_prompt"] = prompt_list
            for prompt, table_text in zip(prompt_list, table_text_list):
                logging.debug(prompt)
                res = self.get_llm_result(prompt)
                # TODO: why use one html but the whole table_text in next step
                if not res or list(res.values())[0] in failed_results:
                    logging.info(
                        "table html is too long; using the table ocr text directly"
                    )
                    prompt = self.get_prompt_for_ocr(
                        table_text, key_list, rules, few_shot, user_task_description
                    )
                    logging.debug(prompt)
                    prompt_res["table_prompt"].append(prompt)
                    res = self.get_llm_result(prompt)
                for key, value in (res or {}).items():
                    if value not in failed_results and key in key_list:
                        key_list.remove(key)
                        final_results[key] = value
        if len(key_list) > 0:
            logging.info("get result from ocr")
            prompt = self.get_prompt_for_ocr(
                ocr_text,
                key_list,
                rules,
                few_shot,
                user_task_description,
            )
            logging.debug(prompt)
            prompt_res["ocr_prompt"] = prompt
            res = self.get_llm_result(prompt)
            final_results.update(res or {})
        if not res and not final_results:
            final_results = self.llm_api.ERROR_MASSAGE
        if save_prompt:
            return {"chat_res": final_results, "prompt": prompt_res}
        else:
            return {"chat_res": final_results, "prompt": ""}
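
    # chat() returns a dict like {"chat_res": {...extracted key/values...},
    # "prompt": ""}; with save_prompt=True the "prompt" field carries the
    # prompt_res dict instead.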

    def get_llm_result(self, prompt):
        """get llm result and decode to dict"""
        llm_result = self.llm_api.pred(prompt)
        # when the llm prediction fails, return None
        if not llm_result:
            return None
        if "json" in llm_result or "```" in llm_result:
            # strip markdown fences, the "json" language tag and stray newlines
            llm_result = (
                llm_result.replace("```", "").replace("json", "").replace("\n", "")
            )
            llm_result = llm_result.replace("[", "").replace("]", "")
        try:
            llm_result = json.loads(llm_result)
            llm_result_final = {}
            for key in llm_result:
                value = llm_result[key]
                if isinstance(value, list):
                    if len(value) > 0:
                        llm_result_final[key] = value[0]
                else:
                    llm_result_final[key] = value
            return llm_result_final
        except Exception:
            # fall back to regex extraction of "key": "value" pairs
            results = (
                llm_result.replace("\n", "")
                .replace(" ", "")
                .replace("{", "")
                .replace("}", "")
            )
            if not results.endswith('"'):
                results = results + '"'
            pattern = r'"(.*?)":\s*"([^"]*)"'
            matches = re.findall(pattern, str(results))
            llm_result = {k: v for k, v in matches}
            return llm_result
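
    # Illustrative decode, assuming the model answered with fenced JSON:
    #   raw:    '```json\n{"甲方": ["XX公司"]}\n```'
    #   parsed: {"甲方": "XX公司"}   (list values are unwrapped to their first element)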

    def get_prompt_for_table(self, table_result, key_list, rules="", few_shot=""):
        """get prompt for table"""
        prompt_key_information = []
        merge_table = ""
        for idx, result in enumerate(table_result):
            # batch table htmls into chunks of roughly 2000 characters per prompt
            if len(merge_table + result) < 2000:
                merge_table += result
            else:
                # the current chunk is full: emit a prompt for it, then start
                # a new chunk with this table so it is not dropped
                if merge_table:
                    prompt_key_information.append(
                        self.get_kie_prompt(
                            merge_table,
                            key_list,
                            rules_str=rules,
                            few_shot_demo_str=few_shot,
                            prompt_type="table",
                        )
                    )
                merge_table = result
            if idx == len(table_result) - 1 and merge_table:
                prompt_key_information.append(
                    self.get_kie_prompt(
                        merge_table,
                        key_list,
                        rules_str=rules,
                        few_shot_demo_str=few_shot,
                        prompt_type="table",
                    )
                )
        return prompt_key_information
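
    # The 2000-character cap is a heuristic to keep each table prompt within
    # the LLM context budget: tables are concatenated until the cap is reached
    # and each chunk becomes one prompt.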

    def get_prompt_for_ocr(
        self,
        ocr_result,
        key_list,
        rules="",
        few_shot="",
        user_task_description="",
    ):
        """get prompt for ocr"""
        prompt_key_information = self.get_kie_prompt(
            ocr_result, key_list, user_task_description, rules, few_shot
        )
        return prompt_key_information

    def get_kie_prompt(
        self,
        text_result,
        key_list,
        user_task_description="",
        rules_str="",
        few_shot_demo_str="",
        prompt_type="common",
    ):
        """get_kie_prompt"""
        if prompt_type == "table":
            task_description = self.task_prompt_dict["kie_table_prompt"][
                "task_description"
            ]
        else:
            task_description = self.task_prompt_dict["kie_common_prompt"][
                "task_description"
            ]
        output_format = self.task_prompt_dict["kie_common_prompt"]["output_format"]
        if len(user_task_description) > 0:
            task_description = user_task_description
        task_description = task_description + output_format
        few_shot_demo_key_value = ""
        if self.user_prompt_dict:
            logging.info("======= using custom user prompts ========")
            task_description = self.user_prompt_dict["task_description"]
            rules_str = self.user_prompt_dict["rules_str"]
            few_shot_demo_str = self.user_prompt_dict["few_shot_demo_str"]
            few_shot_demo_key_value = self.user_prompt_dict["few_shot_demo_key_value"]
        prompt = f"""{task_description}{rules_str}{few_shot_demo_str}{few_shot_demo_key_value}"""
        # the trailing replace(" ", "") strips spaces, including any introduced
        # by the line continuations below
        if prompt_type == "table":
            # "Building on the above, now begin: table content: ...; key list: [...]"
            prompt += f"""\n结合上面,下面正式开始:\
表格内容:```{text_result}```\
关键词列表:[{key_list}]。""".replace(
                " ", ""
            )
        else:
            # "Building on the example above, now begin: OCR text: ...; key list: [...]"
            prompt += f"""\n结合上面的例子,下面正式开始:\
OCR文字:```{text_result}```\
关键词列表:[{key_list}]。""".replace(
                " ", ""
            )
        return prompt
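

# Roughly what an assembled "table" prompt looks like (the task description,
# output format, rules and few-shot text come from the prompt YAMLs, so the
# header line below is only a placeholder):
#
#   <task_description + output_format + rules_str + few_shot_demo_str>
#   结合上面,下面正式开始:表格内容:```<table html>```关键词列表:[['甲方']]。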