| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- from PIL import Image
- from typing import Dict, List
- import fitz
- from io import BytesIO
- import json
- from dots_ocr.utils.image_utils import smart_resize
- from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
- from dots_ocr.utils.output_cleaner import OutputCleaner
# Color map for layout categories, as RGBA tuples with 8-bit channels (0-255).
# Only the RGB components are read when drawing boxes/labels; the translucency of
# filled boxes is applied at draw time (fill opacity), not via the alpha value here.
dict_layout_type_to_color = {
    "Text": (0, 128, 0, 255),  # Green
    "Picture": (255, 0, 255, 255),  # Magenta
    "Caption": (255, 165, 0, 255),  # Orange
    "Section-header": (0, 255, 255, 255),  # Cyan
    "Footnote": (0, 128, 0, 255),  # Green
    "Formula": (128, 128, 128, 255),  # Gray
    "Table": (255, 192, 203, 255),  # Pink
    "Title": (255, 0, 0, 255),  # Red
    "List-item": (0, 0, 255, 255),  # Blue
    "Page-header": (0, 128, 0, 255),  # Green
    "Page-footer": (128, 0, 128, 255),  # Purple
    "Other": (165, 42, 42, 255),  # Brown
    "Unknown": (0, 0, 0, 0),  # Black, fully transparent alpha
}
def draw_layout_on_image(image, cells, resized_height=None, resized_width=None, fill_bbox=True, draw_bbox=True):
    """
    Draw layout boxes and reading-order labels on top of an image.

    The image is embedded in a temporary single-page PDF (PyMuPDF), the boxes and
    labels are drawn on that page, and the page is rendered back to a PIL image at
    a 1:1 scale, so the output has the same pixel size as ``image``.

    Args:
        image: The source PIL Image.
        cells: A list of cells; each must provide ``'bbox'`` as ``[x0, y0, x1, y1]``
            and ``'category'`` (the layout type used to pick the color).
        resized_height: Height of the image the bboxes refer to. When both
            ``resized_height`` and ``resized_width`` are truthy, bboxes are scaled
            from that size back to ``image``'s size and truncated to ints.
        resized_width: Width of the image the bboxes refer to.
        fill_bbox: If True, boxes are filled with a 30%-opacity color and no
            outline; otherwise only an outline is drawn.
        draw_bbox: Whether to draw the boxes at all. The ``"{index}_{category}"``
            label is drawn for every cell regardless of this flag.

    Returns:
        PIL.Image: A new RGB image with the drawings.
    """
    original_width, original_height = image.size

    # PyMuPDF cannot consume a PIL image directly, so round-trip through PNG bytes.
    img_bytes = BytesIO()
    image.save(img_bytes, format='PNG')

    # Bboxes are rescaled only when both resized dimensions are supplied. The
    # factors are identical for every cell, so compute them once.
    rescale = bool(resized_height and resized_width)
    if rescale:
        scale_x = resized_width / original_width
        scale_y = resized_height / original_height

    # Scratch PDF used purely as a drawing canvas; closed in ``finally`` so it is
    # released even if drawing or rendering raises.
    doc = fitz.open()
    try:
        pix = fitz.Pixmap(img_bytes)
        page = doc.new_page(width=pix.width, height=pix.height)
        page.insert_image(fitz.Rect(0, 0, pix.width, pix.height), pixmap=pix)

        for order, cell in enumerate(cells):
            bbox = cell['bbox']
            layout_type = cell['category']

            if rescale:
                # Map from resized-image coordinates back to original pixels.
                x0, y0 = int(bbox[0] / scale_x), int(bbox[1] / scale_y)
                x1, y1 = int(bbox[2] / scale_x), int(bbox[3] / scale_y)
            else:
                x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]

            # PyMuPDF takes RGB components as floats in [0, 1]; alpha is ignored.
            rgba = dict_layout_type_to_color.get(layout_type, (0, 128, 0, 255))
            color = [channel / 255 for channel in rgba[:3]]
            rect_coords = fitz.Rect(x0, y0, x1, y1)

            if draw_bbox:
                if fill_bbox:
                    # Translucent filled box without a border (color=None).
                    page.draw_rect(
                        rect_coords,
                        color=None,
                        fill=color,
                        fill_opacity=0.3,
                        width=0.5,
                        overlay=True,
                    )
                else:
                    # Border only, no fill.
                    page.draw_rect(
                        rect_coords,
                        color=color,
                        fill=None,
                        fill_opacity=1,
                        width=0.5,
                        overlay=True,
                    )

            # Label "<index>_<category>": the text starts at the box's right edge
            # (x1), with its baseline 20 units below the box's top edge.
            page.insert_text(
                (x1, y0 + 20), f"{order}_{layout_type}", fontsize=20, color=color
            )

        # Render at 1:1 so the output keeps the original image dimensions.
        rendered = page.get_pixmap(matrix=fitz.Matrix(1.0, 1.0))
        return Image.frombytes("RGB", [rendered.width, rendered.height], rendered.samples)
    finally:
        doc.close()
def pre_process_bboxes(
    origin_image,
    bboxes,
    input_width,
    input_height,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600
):
    """
    Map bounding boxes from ``origin_image`` pixel coordinates into the
    coordinate space of the resized model input.

    The target size is obtained by passing ``(input_height, input_width)``
    through ``smart_resize``. Each coordinate is then divided by
    ``original / resized`` and truncated toward zero with ``int``.

    Args:
        origin_image: The original PIL Image (only ``.size`` is read).
        bboxes: Non-empty list of ``[x0, y0, x1, y1]`` boxes (numbers or numeric
            strings; each value goes through ``float``).
        input_width: Width of the model input before ``smart_resize``.
        input_height: Height of the model input before ``smart_resize``.
        factor: Accepted for signature compatibility but unused.
            NOTE(review): it is not forwarded to ``smart_resize``; confirm intent.
        min_pixels: Forwarded to ``smart_resize``; a falsy value falls back to
            ``MIN_PIXELS``.
        max_pixels: Forwarded to ``smart_resize``; a falsy value falls back to
            ``MAX_PIXELS``.

    Returns:
        A new list of integer ``[x0, y0, x1, y1]`` boxes.

    Raises:
        AssertionError: If ``bboxes`` is not a non-empty list whose first
            element is a list.
    """
    assert isinstance(bboxes, list) and bboxes and isinstance(bboxes[0], list)
    min_pixels = min_pixels or MIN_PIXELS
    max_pixels = max_pixels or MAX_PIXELS
    src_w, src_h = origin_image.size
    target_h, target_w = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)

    ratio_x = src_w / target_w
    ratio_y = src_h / target_h

    def to_target(box):
        # Truncating division keeps the original int() rounding behaviour.
        return [
            int(float(box[0]) / ratio_x),
            int(float(box[1]) / ratio_y),
            int(float(box[2]) / ratio_x),
            int(float(box[3]) / ratio_y),
        ]

    return [to_target(box) for box in bboxes]
- def post_process_cells(
- origin_image: Image.Image,
- cells: List[Dict],
- input_width, # server input width, also has smart_resize in server
- input_height,
- factor: int = 28,
- min_pixels: int = 3136,
- max_pixels: int = 11289600
- ) -> List[Dict]:
- """
- Post-processes cell bounding boxes, converting coordinates from the resized dimensions back to the original dimensions.
-
- Args:
- origin_image: The original PIL Image.
- cells: A list of cells containing bounding box information.
- input_width: The width of the input image sent to the server.
- input_height: The height of the input image sent to the server.
- factor: Resizing factor.
- min_pixels: Minimum number of pixels.
- max_pixels: Maximum number of pixels.
-
- Returns:
- A list of post-processed cells.
- """
- assert isinstance(cells, list) and len(cells) > 0 and isinstance(cells[0], dict)
- min_pixels = min_pixels or MIN_PIXELS
- max_pixels = max_pixels or MAX_PIXELS
- original_width, original_height = origin_image.size
- input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)
-
- scale_x = input_width / original_width
- scale_y = input_height / original_height
-
- cells_out = []
- for cell in cells:
- bbox = cell['bbox']
- bbox_resized = [
- int(float(bbox[0]) / scale_x),
- int(float(bbox[1]) / scale_y),
- int(float(bbox[2]) / scale_x),
- int(float(bbox[3]) / scale_y)
- ]
- cell_copy = cell.copy()
- cell_copy['bbox'] = bbox_resized
- cells_out.append(cell_copy)
-
- return cells_out
def is_legal_bbox(cells):
    """
    Return True when every cell's ``'bbox'`` has positive width and height.

    A box ``[x0, y0, x1, y1]`` is rejected when ``x1 <= x0`` or ``y1 <= y0``.
    Checking stops at the first rejected box; an empty ``cells`` yields True.
    """
    return not any(
        box[2] <= box[0] or box[3] <= box[1]
        for box in (cell['bbox'] for cell in cells)
    )
def post_process_output(response, prompt_mode, origin_image, input_image, min_pixels=None, max_pixels=None):
    """
    Convert a raw model response into the output for ``prompt_mode``.

    For the text-only modes (``prompt_ocr``, ``prompt_table_html``,
    ``prompt_table_latex``, ``prompt_formula_latex``) the response is returned
    unchanged, as a bare value rather than a tuple.

    For any other mode, the response is parsed as JSON and the cells' bboxes are
    mapped back to ``origin_image`` coordinates with ``post_process_cells``.
    On success, ``(cells, False)`` is returned.

    If parsing or post-processing raises, the error is printed and the payload
    is passed through ``OutputCleaner.clean_model_output``. If the cleaner
    returns a list, the ``'text'`` fields of its items are joined with blank
    lines. The result is returned as ``(cleaned, True)``.

    Note: when JSON parsing succeeded but post-processing failed, the cleaner
    receives the parsed object, not the raw response string.
    """
    text_only_modes = ["prompt_ocr", "prompt_table_html", "prompt_table_latex", "prompt_formula_latex"]
    if prompt_mode in text_only_modes:
        return response

    payload = response
    try:
        payload = json.loads(payload)
        processed = post_process_cells(
            origin_image,
            payload,
            input_image.width,
            input_image.height,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    except Exception as e:
        # Best-effort boundary: report the failure and fall back to cleaning.
        print(f"cells post process error: {e}, when using {prompt_mode}")
    else:
        return processed, False

    cleaned = OutputCleaner().clean_model_output(payload)
    if isinstance(cleaned, list):
        cleaned = "\n\n".join([item['text'] for item in cleaned if 'text' in item])
    return cleaned, True
|