layout_utils.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. from PIL import Image
  2. from typing import Dict, List
  3. import fitz
  4. from io import BytesIO
  5. import json
  6. from dots_ocr.utils.image_utils import smart_resize
  7. from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
  8. from dots_ocr.utils.output_cleaner import OutputCleaner
  9. # Define a color map (using RGBA format)
  10. dict_layout_type_to_color = {
  11. "Text": (0, 128, 0, 256), # Green, translucent
  12. "Picture": (255, 0, 255, 256), # Magenta, translucent
  13. "Caption": (255, 165, 0, 256), # Orange, translucent
  14. "Section-header": (0, 255, 255, 256), # Cyan, translucent
  15. "Footnote": (0, 128, 0, 256), # Green, translucent
  16. "Formula": (128, 128, 128, 256), # Gray, translucent
  17. "Table": (255, 192, 203, 256), # Pink, translucent
  18. "Title": (255, 0, 0, 256), # Red, translucent
  19. "List-item": (0, 0, 255, 256), # Blue, translucent
  20. "Page-header": (0, 128, 0, 256), # Green, translucent
  21. "Page-footer": (128, 0, 128, 256), # Purple, translucent
  22. "Other": (165, 42, 42, 256), # Brown, translucent
  23. "Unknown": (0, 0, 0, 0),
  24. }
  25. def draw_layout_on_image(image, cells, resized_height=None, resized_width=None, fill_bbox=True, draw_bbox=True):
  26. """
  27. Draw transparent boxes on an image.
  28. Args:
  29. image: The source PIL Image.
  30. cells: A list of cells containing bounding box information.
  31. resized_height: The resized height.
  32. resized_width: The resized width.
  33. fill_bbox: Whether to fill the bounding box.
  34. draw_bbox: Whether to draw the bounding box.
  35. Returns:
  36. PIL.Image: The image with drawings.
  37. """
  38. # origin_image = Image.open(image_path)
  39. original_width, original_height = image.size
  40. # Create a new PDF document
  41. doc = fitz.open()
  42. # Get image information
  43. img_bytes = BytesIO()
  44. image.save(img_bytes, format='PNG')
  45. # pix = fitz.Pixmap(image_path)
  46. pix = fitz.Pixmap(img_bytes)
  47. # Create a page
  48. page = doc.new_page(width=pix.width, height=pix.height)
  49. page.insert_image(
  50. fitz.Rect(0, 0, pix.width, pix.height),
  51. # filename=image_path
  52. pixmap=pix
  53. )
  54. for i, cell in enumerate(cells):
  55. bbox = cell['bbox']
  56. layout_type = cell['category']
  57. order = i
  58. top_left = (bbox[0], bbox[1])
  59. down_right = (bbox[2], bbox[3])
  60. if resized_height and resized_width:
  61. scale_x = resized_width / original_width
  62. scale_y = resized_height / original_height
  63. top_left = (int(bbox[0] / scale_x), int(bbox[1] / scale_y))
  64. down_right = (int(bbox[2] / scale_x), int(bbox[3] / scale_y))
  65. color = dict_layout_type_to_color.get(layout_type, (0, 128, 0, 256))
  66. color = [col/255 for col in color[:3]]
  67. x0, y0, x1, y1 = top_left[0], top_left[1], down_right[0], down_right[1]
  68. rect_coords = fitz.Rect(x0, y0, x1, y1)
  69. if draw_bbox:
  70. if fill_bbox:
  71. page.draw_rect(
  72. rect_coords,
  73. color=None,
  74. fill=color,
  75. fill_opacity=0.3,
  76. width=0.5,
  77. overlay=True,
  78. ) # Draw the rectangle
  79. else:
  80. page.draw_rect(
  81. rect_coords,
  82. color=color,
  83. fill=None,
  84. fill_opacity=1,
  85. width=0.5,
  86. overlay=True,
  87. ) # Draw the rectangle
  88. order_cate = f"{order}_{layout_type}"
  89. page.insert_text(
  90. (x1, y0 + 20), order_cate, fontsize=20, color=color
  91. ) # Insert the index in the top left corner of the rectangle
  92. # Convert to a Pixmap (maintaining original dimensions)
  93. mat = fitz.Matrix(1.0, 1.0)
  94. pix = page.get_pixmap(matrix=mat)
  95. return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
  96. def pre_process_bboxes(
  97. origin_image,
  98. bboxes,
  99. input_width,
  100. input_height,
  101. factor: int = 28,
  102. min_pixels: int = 3136,
  103. max_pixels: int = 11289600
  104. ):
  105. assert isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list)
  106. min_pixels = min_pixels or MIN_PIXELS
  107. max_pixels = max_pixels or MAX_PIXELS
  108. original_width, original_height = origin_image.size
  109. input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)
  110. scale_x = original_width / input_width
  111. scale_y = original_height / input_height
  112. bboxes_out = []
  113. for bbox in bboxes:
  114. bbox_resized = [
  115. int(float(bbox[0]) / scale_x),
  116. int(float(bbox[1]) / scale_y),
  117. int(float(bbox[2]) / scale_x),
  118. int(float(bbox[3]) / scale_y)
  119. ]
  120. bboxes_out.append(bbox_resized)
  121. return bboxes_out
  122. def post_process_cells(
  123. origin_image: Image.Image,
  124. cells: List[Dict],
  125. input_width, # server input width, also has smart_resize in server
  126. input_height,
  127. factor: int = 28,
  128. min_pixels: int = 3136,
  129. max_pixels: int = 11289600
  130. ) -> List[Dict]:
  131. """
  132. Post-processes cell bounding boxes, converting coordinates from the resized dimensions back to the original dimensions.
  133. Args:
  134. origin_image: The original PIL Image.
  135. cells: A list of cells containing bounding box information.
  136. input_width: The width of the input image sent to the server.
  137. input_height: The height of the input image sent to the server.
  138. factor: Resizing factor.
  139. min_pixels: Minimum number of pixels.
  140. max_pixels: Maximum number of pixels.
  141. Returns:
  142. A list of post-processed cells.
  143. """
  144. assert isinstance(cells, list) and len(cells) > 0 and isinstance(cells[0], dict)
  145. min_pixels = min_pixels or MIN_PIXELS
  146. max_pixels = max_pixels or MAX_PIXELS
  147. original_width, original_height = origin_image.size
  148. input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)
  149. scale_x = input_width / original_width
  150. scale_y = input_height / original_height
  151. cells_out = []
  152. for cell in cells:
  153. bbox = cell['bbox']
  154. bbox_resized = [
  155. int(float(bbox[0]) / scale_x),
  156. int(float(bbox[1]) / scale_y),
  157. int(float(bbox[2]) / scale_x),
  158. int(float(bbox[3]) / scale_y)
  159. ]
  160. cell_copy = cell.copy()
  161. cell_copy['bbox'] = bbox_resized
  162. cells_out.append(cell_copy)
  163. return cells_out
  164. def is_legal_bbox(cells):
  165. for cell in cells:
  166. bbox = cell['bbox']
  167. if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
  168. return False
  169. return True
  170. def post_process_output(response, prompt_mode, origin_image, input_image, min_pixels=None, max_pixels=None):
  171. if prompt_mode in ["prompt_ocr", "prompt_table_html", "prompt_table_latex", "prompt_formula_latex"]:
  172. return response
  173. json_load_failed = False
  174. cells = response
  175. try:
  176. cells = json.loads(cells)
  177. cells = post_process_cells(
  178. origin_image,
  179. cells,
  180. input_image.width,
  181. input_image.height,
  182. min_pixels=min_pixels,
  183. max_pixels=max_pixels
  184. )
  185. return cells, False
  186. except Exception as e:
  187. print(f"cells post process error: {e}, when using {prompt_mode}")
  188. json_load_failed = True
  189. if json_load_failed:
  190. cleaner = OutputCleaner()
  191. response_clean = cleaner.clean_model_output(cells)
  192. if isinstance(response_clean, list):
  193. response_clean = "\n\n".join([cell['text'] for cell in response_clean if 'text' in cell])
  194. return response_clean, True