vlm_magic_model.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600
  1. import re
  2. from typing import Literal
  3. from loguru import logger
  4. from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
  5. from mineru.utils.enum_class import ContentType, BlockType
  6. from mineru.utils.guess_suffix_or_lang import guess_language_by_text
  7. from mineru.utils.magic_model_utils import reduct_overlap, tie_up_category_by_distance_v3
  8. class MagicModel:
  9. def __init__(self, page_blocks: list, width, height):
  10. self.page_blocks = page_blocks
  11. blocks = []
  12. self.all_spans = []
  13. # 解析每个块
  14. for index, block_info in enumerate(page_blocks):
  15. block_bbox = block_info["bbox"]
  16. try:
  17. x1, y1, x2, y2 = block_bbox
  18. x_1, y_1, x_2, y_2 = (
  19. int(x1 * width),
  20. int(y1 * height),
  21. int(x2 * width),
  22. int(y2 * height),
  23. )
  24. if x_2 < x_1:
  25. x_1, x_2 = x_2, x_1
  26. if y_2 < y_1:
  27. y_1, y_2 = y_2, y_1
  28. block_bbox = (x_1, y_1, x_2, y_2)
  29. block_type = block_info["type"]
  30. block_content = block_info["content"]
  31. block_angle = block_info["angle"]
  32. # print(f"坐标: {block_bbox}")
  33. # print(f"类型: {block_type}")
  34. # print(f"内容: {block_content}")
  35. # print("-" * 50)
  36. except Exception as e:
  37. # 如果解析失败,可能是因为格式不正确,跳过这个块
  38. logger.warning(f"Invalid block format: {block_info}, error: {e}")
  39. continue
  40. span_type = "unknown"
  41. line_type = None
  42. guess_lang = None
  43. if block_type in [
  44. "text",
  45. "title",
  46. "image_caption",
  47. "image_footnote",
  48. "table_caption",
  49. "table_footnote",
  50. "code_caption",
  51. "ref_text",
  52. "phonetic",
  53. "header",
  54. "footer",
  55. "page_number",
  56. "aside_text",
  57. "page_footnote",
  58. "list"
  59. ]:
  60. span_type = ContentType.TEXT
  61. elif block_type in ["image"]:
  62. block_type = BlockType.IMAGE_BODY
  63. span_type = ContentType.IMAGE
  64. elif block_type in ["table"]:
  65. block_type = BlockType.TABLE_BODY
  66. span_type = ContentType.TABLE
  67. elif block_type in ["code", "algorithm"]:
  68. block_content = code_content_clean(block_content)
  69. line_type = block_type
  70. block_type = BlockType.CODE_BODY
  71. span_type = ContentType.TEXT
  72. guess_lang = guess_language_by_text(block_content)
  73. elif block_type in ["equation"]:
  74. block_type = BlockType.INTERLINE_EQUATION
  75. span_type = ContentType.INTERLINE_EQUATION
  76. if span_type in ["image", "table"]:
  77. span = {
  78. "bbox": block_bbox,
  79. "type": span_type,
  80. }
  81. if span_type == ContentType.TABLE:
  82. span["html"] = block_content
  83. elif span_type in [ContentType.INTERLINE_EQUATION]:
  84. span = {
  85. "bbox": block_bbox,
  86. "type": span_type,
  87. "content": isolated_formula_clean(block_content),
  88. }
  89. else:
  90. if block_content:
  91. block_content = clean_content(block_content)
  92. if block_content and block_content.count("\\(") == block_content.count("\\)") and block_content.count("\\(") > 0:
  93. # 生成包含文本和公式的span列表
  94. spans = []
  95. last_end = 0
  96. # 查找所有公式
  97. for match in re.finditer(r'\\\((.+?)\\\)', block_content):
  98. start, end = match.span()
  99. # 添加公式前的文本
  100. if start > last_end:
  101. text_before = block_content[last_end:start]
  102. if text_before.strip():
  103. spans.append({
  104. "bbox": block_bbox,
  105. "type": ContentType.TEXT,
  106. "content": text_before
  107. })
  108. # 添加公式(去除\(和\))
  109. formula = match.group(1)
  110. spans.append({
  111. "bbox": block_bbox,
  112. "type": ContentType.INLINE_EQUATION,
  113. "content": formula.strip()
  114. })
  115. last_end = end
  116. # 添加最后一个公式后的文本
  117. if last_end < len(block_content):
  118. text_after = block_content[last_end:]
  119. if text_after.strip():
  120. spans.append({
  121. "bbox": block_bbox,
  122. "type": ContentType.TEXT,
  123. "content": text_after
  124. })
  125. span = spans
  126. else:
  127. span = {
  128. "bbox": block_bbox,
  129. "type": span_type,
  130. "content": block_content,
  131. }
  132. # 处理span类型并添加到all_spans
  133. if isinstance(span, dict) and "bbox" in span:
  134. self.all_spans.append(span)
  135. spans = [span]
  136. elif isinstance(span, list):
  137. self.all_spans.extend(span)
  138. spans = span
  139. else:
  140. raise ValueError(f"Invalid span type: {span_type}, expected dict or list, got {type(span)}")
  141. # 构造line对象
  142. if block_type in [BlockType.CODE_BODY]:
  143. line = {"bbox": block_bbox, "spans": spans, "extra": {"type": line_type, "guess_lang": guess_lang}}
  144. else:
  145. line = {"bbox": block_bbox, "spans": spans}
  146. blocks.append(
  147. {
  148. "bbox": block_bbox,
  149. "type": block_type,
  150. "angle": block_angle,
  151. "lines": [line],
  152. "index": index,
  153. }
  154. )
  155. self.image_blocks = []
  156. self.table_blocks = []
  157. self.interline_equation_blocks = []
  158. self.text_blocks = []
  159. self.title_blocks = []
  160. self.code_blocks = []
  161. self.discarded_blocks = []
  162. self.ref_text_blocks = []
  163. self.phonetic_blocks = []
  164. self.list_blocks = []
  165. for block in blocks:
  166. if block["type"] in [BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE]:
  167. self.image_blocks.append(block)
  168. elif block["type"] in [BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE]:
  169. self.table_blocks.append(block)
  170. elif block["type"] in [BlockType.CODE_BODY, BlockType.CODE_CAPTION]:
  171. self.code_blocks.append(block)
  172. elif block["type"] == BlockType.INTERLINE_EQUATION:
  173. self.interline_equation_blocks.append(block)
  174. elif block["type"] == BlockType.TEXT:
  175. self.text_blocks.append(block)
  176. elif block["type"] == BlockType.TITLE:
  177. self.title_blocks.append(block)
  178. elif block["type"] in [BlockType.REF_TEXT]:
  179. self.ref_text_blocks.append(block)
  180. elif block["type"] in [BlockType.PHONETIC]:
  181. self.phonetic_blocks.append(block)
  182. elif block["type"] in [BlockType.HEADER, BlockType.FOOTER, BlockType.PAGE_NUMBER, BlockType.ASIDE_TEXT, BlockType.PAGE_FOOTNOTE]:
  183. self.discarded_blocks.append(block)
  184. elif block["type"] == BlockType.LIST:
  185. self.list_blocks.append(block)
  186. else:
  187. continue
  188. self.list_blocks, self.text_blocks, self.ref_text_blocks = fix_list_blocks(self.list_blocks, self.text_blocks, self.ref_text_blocks)
  189. self.image_blocks, not_include_image_blocks = fix_two_layer_blocks(self.image_blocks, BlockType.IMAGE)
  190. self.table_blocks, not_include_table_blocks = fix_two_layer_blocks(self.table_blocks, BlockType.TABLE)
  191. self.code_blocks, not_include_code_blocks = fix_two_layer_blocks(self.code_blocks, BlockType.CODE)
  192. for code_block in self.code_blocks:
  193. for block in code_block['blocks']:
  194. if block['type'] == BlockType.CODE_BODY:
  195. if len(block["lines"]) > 0:
  196. line = block["lines"][0]
  197. code_block["sub_type"] = line["extra"]["type"]
  198. if code_block["sub_type"] in ["code"]:
  199. code_block["guess_lang"] = line["extra"]["guess_lang"]
  200. del line["extra"]
  201. else:
  202. code_block["sub_type"] = "code"
  203. code_block["guess_lang"] = "txt"
  204. for block in not_include_image_blocks + not_include_table_blocks + not_include_code_blocks:
  205. block["type"] = BlockType.TEXT
  206. self.text_blocks.append(block)
  207. def get_list_blocks(self):
  208. return self.list_blocks
  209. def get_image_blocks(self):
  210. return self.image_blocks
  211. def get_table_blocks(self):
  212. return self.table_blocks
  213. def get_code_blocks(self):
  214. return self.code_blocks
  215. def get_ref_text_blocks(self):
  216. return self.ref_text_blocks
  217. def get_phonetic_blocks(self):
  218. return self.phonetic_blocks
  219. def get_title_blocks(self):
  220. return self.title_blocks
  221. def get_text_blocks(self):
  222. return self.text_blocks
  223. def get_interline_equation_blocks(self):
  224. return self.interline_equation_blocks
  225. def get_discarded_blocks(self):
  226. return self.discarded_blocks
  227. def get_all_spans(self):
  228. return self.all_spans
  229. def isolated_formula_clean(txt):
  230. latex = txt[:]
  231. if latex.startswith("\\["): latex = latex[2:]
  232. if latex.endswith("\\]"): latex = latex[:-2]
  233. latex = latex.strip()
  234. return latex
  235. def code_content_clean(content):
  236. """清理代码内容,移除Markdown代码块的开始和结束标记"""
  237. if not content:
  238. return ""
  239. lines = content.splitlines()
  240. start_idx = 0
  241. end_idx = len(lines)
  242. # 处理开头的三个反引号
  243. if lines and lines[0].startswith("```"):
  244. start_idx = 1
  245. # 处理结尾的三个反引号
  246. if lines and end_idx > start_idx and lines[end_idx - 1].strip() == "```":
  247. end_idx -= 1
  248. # 只有在有内容时才进行join操作
  249. if start_idx < end_idx:
  250. return "\n".join(lines[start_idx:end_idx]).strip()
  251. return ""
  252. def clean_content(content):
  253. if content and content.count("\\[") == content.count("\\]") and content.count("\\[") > 0:
  254. # Function to handle each match
  255. def replace_pattern(match):
  256. # Extract content between \[ and \]
  257. inner_content = match.group(1)
  258. return f"[{inner_content}]"
  259. # Find all patterns of \[x\] and apply replacement
  260. pattern = r'\\\[(.*?)\\\]'
  261. content = re.sub(pattern, replace_pattern, content)
  262. return content
  263. def __tie_up_category_by_distance_v3(blocks, subject_block_type, object_block_type):
  264. # 定义获取主体和客体对象的函数
  265. def get_subjects():
  266. return reduct_overlap(
  267. list(
  268. map(
  269. lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"], "angle":x["angle"]},
  270. filter(
  271. lambda x: x["type"] == subject_block_type,
  272. blocks,
  273. ),
  274. )
  275. )
  276. )
  277. def get_objects():
  278. return reduct_overlap(
  279. list(
  280. map(
  281. lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"], "angle":x["angle"]},
  282. filter(
  283. lambda x: x["type"] == object_block_type,
  284. blocks,
  285. ),
  286. )
  287. )
  288. )
  289. # 调用通用方法
  290. return tie_up_category_by_distance_v3(
  291. get_subjects,
  292. get_objects
  293. )
  294. def get_type_blocks(blocks, block_type: Literal["image", "table", "code"]):
  295. with_captions = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_caption")
  296. with_footnotes = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_footnote")
  297. ret = []
  298. for v in with_captions:
  299. record = {
  300. f"{block_type}_body": v["sub_bbox"],
  301. f"{block_type}_caption_list": v["obj_bboxes"],
  302. }
  303. filter_idx = v["sub_idx"]
  304. d = next(filter(lambda x: x["sub_idx"] == filter_idx, with_footnotes))
  305. record[f"{block_type}_footnote_list"] = d["obj_bboxes"]
  306. ret.append(record)
  307. return ret
  308. def fix_two_layer_blocks_back(blocks, fix_type: Literal["image", "table", "code"]):
  309. need_fix_blocks = get_type_blocks(blocks, fix_type)
  310. fixed_blocks = []
  311. not_include_blocks = []
  312. processed_indices = set()
  313. # 处理需要组织成two_layer结构的blocks
  314. for block in need_fix_blocks:
  315. body = block[f"{fix_type}_body"]
  316. caption_list = block[f"{fix_type}_caption_list"]
  317. footnote_list = block[f"{fix_type}_footnote_list"]
  318. body["type"] = f"{fix_type}_body"
  319. for caption in caption_list:
  320. caption["type"] = f"{fix_type}_caption"
  321. processed_indices.add(caption["index"])
  322. for footnote in footnote_list:
  323. footnote["type"] = f"{fix_type}_footnote"
  324. processed_indices.add(footnote["index"])
  325. processed_indices.add(body["index"])
  326. two_layer_block = {
  327. "type": fix_type,
  328. "bbox": body["bbox"],
  329. "blocks": [
  330. body,
  331. ],
  332. "index": body["index"],
  333. }
  334. two_layer_block["blocks"].extend([*caption_list, *footnote_list])
  335. fixed_blocks.append(two_layer_block)
  336. # 添加未处理的blocks
  337. for block in blocks:
  338. if block["index"] not in processed_indices:
  339. # 直接添加未处理的block
  340. not_include_blocks.append(block)
  341. return fixed_blocks, not_include_blocks
  342. def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table", "code"]):
  343. need_fix_blocks = get_type_blocks(blocks, fix_type)
  344. fixed_blocks = []
  345. not_include_blocks = []
  346. processed_indices = set()
  347. # 特殊处理表格类型,确保标题在表格前,注脚在表格后
  348. if fix_type == "table":
  349. # 收集所有不合适的caption和footnote
  350. misplaced_captions = [] # 存储(caption, 原始block索引)
  351. misplaced_footnotes = [] # 存储(footnote, 原始block索引)
  352. # 第一步:移除不符合位置要求的caption和footnote
  353. for block_idx, block in enumerate(need_fix_blocks):
  354. body = block[f"{fix_type}_body"]
  355. body_index = body["index"]
  356. # 检查caption应在body前或同位置
  357. valid_captions = []
  358. for caption in block[f"{fix_type}_caption_list"]:
  359. if caption["index"] <= body_index:
  360. valid_captions.append(caption)
  361. else:
  362. misplaced_captions.append((caption, block_idx))
  363. block[f"{fix_type}_caption_list"] = valid_captions
  364. # 检查footnote应在body后或同位置
  365. valid_footnotes = []
  366. for footnote in block[f"{fix_type}_footnote_list"]:
  367. if footnote["index"] >= body_index:
  368. valid_footnotes.append(footnote)
  369. else:
  370. misplaced_footnotes.append((footnote, block_idx))
  371. block[f"{fix_type}_footnote_list"] = valid_footnotes
  372. # 第二步:重新分配不合规的caption到合适的body
  373. for caption, original_block_idx in misplaced_captions:
  374. caption_index = caption["index"]
  375. best_block_idx = None
  376. min_distance = float('inf')
  377. # 寻找索引大于等于caption_index的最近body
  378. for idx, block in enumerate(need_fix_blocks):
  379. body_index = block[f"{fix_type}_body"]["index"]
  380. if body_index >= caption_index and idx != original_block_idx:
  381. distance = body_index - caption_index
  382. if distance < min_distance:
  383. min_distance = distance
  384. best_block_idx = idx
  385. if best_block_idx is not None:
  386. # 找到合适的body,添加到对应block的caption_list
  387. need_fix_blocks[best_block_idx][f"{fix_type}_caption_list"].append(caption)
  388. else:
  389. # 没找到合适的body,作为普通block处理
  390. not_include_blocks.append(caption)
  391. # 第三步:重新分配不合规的footnote到合适的body
  392. for footnote, original_block_idx in misplaced_footnotes:
  393. footnote_index = footnote["index"]
  394. best_block_idx = None
  395. min_distance = float('inf')
  396. # 寻找索引小于等于footnote_index的最近body
  397. for idx, block in enumerate(need_fix_blocks):
  398. body_index = block[f"{fix_type}_body"]["index"]
  399. if body_index <= footnote_index and idx != original_block_idx:
  400. distance = footnote_index - body_index
  401. if distance < min_distance:
  402. min_distance = distance
  403. best_block_idx = idx
  404. if best_block_idx is not None:
  405. # 找到合适的body,添加到对应block的footnote_list
  406. need_fix_blocks[best_block_idx][f"{fix_type}_footnote_list"].append(footnote)
  407. else:
  408. # 没找到合适的body,作为普通block处理
  409. not_include_blocks.append(footnote)
  410. # 第四步:将每个block的caption_list和footnote_list中不连续index的元素提出来作为普通block处理
  411. for block in need_fix_blocks:
  412. caption_list = block[f"{fix_type}_caption_list"]
  413. footnote_list = block[f"{fix_type}_footnote_list"]
  414. body_index = block[f"{fix_type}_body"]["index"]
  415. # 处理caption_list (从body往前看,caption在body之前)
  416. if caption_list:
  417. # 按index降序排列,从最接近body的开始检查
  418. caption_list.sort(key=lambda x: x["index"], reverse=True)
  419. filtered_captions = [caption_list[0]]
  420. for i in range(1, len(caption_list)):
  421. # 检查是否与前一个caption连续(降序所以是-1)
  422. if caption_list[i]["index"] == caption_list[i - 1]["index"] - 1:
  423. filtered_captions.append(caption_list[i])
  424. else:
  425. # 出现gap,后续所有caption都作为普通block
  426. not_include_blocks.extend(caption_list[i:])
  427. break
  428. # 恢复升序
  429. filtered_captions.reverse()
  430. block[f"{fix_type}_caption_list"] = filtered_captions
  431. # 处理footnote_list (从body往后看,footnote在body之后)
  432. if footnote_list:
  433. # 按index升序排列,从最接近body的开始检查
  434. footnote_list.sort(key=lambda x: x["index"])
  435. filtered_footnotes = [footnote_list[0]]
  436. for i in range(1, len(footnote_list)):
  437. # 检查是否与前一个footnote连续
  438. if footnote_list[i]["index"] == footnote_list[i - 1]["index"] + 1:
  439. filtered_footnotes.append(footnote_list[i])
  440. else:
  441. # 出现gap,后续所有footnote都作为普通block
  442. not_include_blocks.extend(footnote_list[i:])
  443. break
  444. block[f"{fix_type}_footnote_list"] = filtered_footnotes
  445. # 构建两层结构blocks
  446. for block in need_fix_blocks:
  447. body = block[f"{fix_type}_body"]
  448. caption_list = block[f"{fix_type}_caption_list"]
  449. footnote_list = block[f"{fix_type}_footnote_list"]
  450. body["type"] = f"{fix_type}_body"
  451. for caption in caption_list:
  452. caption["type"] = f"{fix_type}_caption"
  453. processed_indices.add(caption["index"])
  454. for footnote in footnote_list:
  455. footnote["type"] = f"{fix_type}_footnote"
  456. processed_indices.add(footnote["index"])
  457. processed_indices.add(body["index"])
  458. two_layer_block = {
  459. "type": fix_type,
  460. "bbox": body["bbox"],
  461. "blocks": [body],
  462. "index": body["index"],
  463. }
  464. two_layer_block["blocks"].extend([*caption_list, *footnote_list])
  465. # 对blocks按index排序
  466. two_layer_block["blocks"].sort(key=lambda x: x["index"])
  467. fixed_blocks.append(two_layer_block)
  468. # 添加未处理的blocks
  469. for block in blocks:
  470. block.pop("type", None)
  471. if block["index"] not in processed_indices and block not in not_include_blocks:
  472. not_include_blocks.append(block)
  473. return fixed_blocks, not_include_blocks
  474. def fix_list_blocks(list_blocks, text_blocks, ref_text_blocks):
  475. for list_block in list_blocks:
  476. list_block["blocks"] = []
  477. if "lines" in list_block:
  478. del list_block["lines"]
  479. temp_text_blocks = text_blocks + ref_text_blocks
  480. need_remove_blocks = []
  481. for block in temp_text_blocks:
  482. for list_block in list_blocks:
  483. if calculate_overlap_area_in_bbox1_area_ratio(block["bbox"], list_block["bbox"]) >= 0.8:
  484. list_block["blocks"].append(block)
  485. need_remove_blocks.append(block)
  486. break
  487. for block in need_remove_blocks:
  488. if block in text_blocks:
  489. text_blocks.remove(block)
  490. elif block in ref_text_blocks:
  491. ref_text_blocks.remove(block)
  492. # 移除blocks为空的list_block
  493. list_blocks = [lb for lb in list_blocks if lb["blocks"]]
  494. for list_block in list_blocks:
  495. # 统计list_block["blocks"]中所有block的type,用众数作为list_block的sub_type
  496. type_count = {}
  497. line_content = []
  498. for sub_block in list_block["blocks"]:
  499. sub_block_type = sub_block["type"]
  500. if sub_block_type not in type_count:
  501. type_count[sub_block_type] = 0
  502. type_count[sub_block_type] += 1
  503. if type_count:
  504. list_block["sub_type"] = max(type_count, key=type_count.get)
  505. else:
  506. list_block["sub_type"] = "unknown"
  507. return list_blocks, text_blocks, ref_text_blocks