瀏覽代碼

feat: enhance image processing by introducing ImageType enum and updating related functions

myhloli 3 月之前
父節點
當前提交
ff11e602fc

+ 3 - 1
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -17,6 +17,7 @@ from mineru.utils.llm_aided import llm_aided_title
 from mineru.utils.model_utils import clean_memory
 from mineru.backend.pipeline.pipeline_magic_model import MagicModel
 from mineru.utils.ocr_utils import OcrConfidence
+from mineru.utils.pdf_reader import image_to_b64str
 from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
 from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
     remove_overlaps_min_spans, txt_spans_extract
@@ -27,7 +28,8 @@ from mineru.utils.hash_utils import str_md5
 def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr_enable=False, formula_enabled=True):
     scale = image_dict["scale"]
     page_pil_img = image_dict["img_pil"]
-    page_img_md5 = str_md5(image_dict["img_base64"])
+    # page_img_md5 = str_md5(image_dict["img_base64"])
+    page_img_md5 = str_md5(image_to_b64str(page_pil_img))
     page_w, page_h = map(int, page.get_size())
     magic_model = MagicModel(page_model_info, scale)
 

+ 2 - 1
mineru/backend/pipeline/pipeline_analyze.py

@@ -6,6 +6,7 @@ from loguru import logger
 
 from .model_init import MineruPipelineModel
 from mineru.utils.config_reader import get_device
+from ...utils.enum_class import ImageType
 from ...utils.pdf_classify import classify
 from ...utils.pdf_image_tools import load_images_from_pdf
 from ...utils.model_utils import get_vram, clean_memory
@@ -98,7 +99,7 @@ def doc_analyze(
         _lang = lang_list[pdf_idx]
 
         # 收集每个数据集中的页面
-        images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
+        images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
         all_image_lists.append(images_list)
         all_pdf_docs.append(pdf_doc)
         for page_idx in range(len(images_list)):

+ 3 - 1
mineru/backend/vlm/token_to_middle_json.py

@@ -8,6 +8,7 @@ from mineru.utils.enum_class import ContentType
 from mineru.utils.hash_utils import str_md5
 from mineru.backend.vlm.vlm_magic_model import MagicModel
 from mineru.utils.pdf_image_tools import get_crop_img
+from mineru.utils.pdf_reader import base64_to_pil_image
 from mineru.version import __version__
 
 heading_level_import_success = False
@@ -32,7 +33,8 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
     # 提取所有完整块,每个块从<|box_start|>开始到<|md_end|>或<|im_end|>结束
 
     scale = image_dict["scale"]
-    page_pil_img = image_dict["img_pil"]
+    # page_pil_img = image_dict["img_pil"]
+    page_pil_img = base64_to_pil_image(image_dict["img_base64"])
     page_img_md5 = str_md5(image_dict["img_base64"])
     width, height = map(int, page.get_size())
 

+ 2 - 1
mineru/backend/vlm/vlm_analyze.py

@@ -8,6 +8,7 @@ from mineru.utils.pdf_image_tools import load_images_from_pdf
 from .base_predictor import BasePredictor
 from .predictor import get_predictor
 from .token_to_middle_json import result_to_middle_json
+from ...utils.enum_class import ImageType
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
 
 
@@ -53,7 +54,7 @@ def doc_analyze(
         predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
 
     # load_images_start = time.time()
-    images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
+    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.BASE64)
     images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
     # load_images_time = round(time.time() - load_images_start, 2)
     # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")

+ 6 - 1
mineru/utils/enum_class.py

@@ -66,4 +66,9 @@ class ModelPath:
 
 class SplitFlag:
     CROSS_PAGE = 'cross_page'
-    LINES_DELETED = 'lines_deleted'
+    LINES_DELETED = 'lines_deleted'
+
+
+class ImageType:
+    PIL = 'pil_img'
+    BASE64 = 'base64_img'

+ 10 - 6
mineru/utils/pdf_image_tools.py

@@ -7,27 +7,30 @@ from PIL import Image
 
 from mineru.data.data_reader_writer import FileBasedDataWriter
 from mineru.utils.pdf_reader import image_to_b64str, image_to_bytes, page_to_image
+from .enum_class import ImageType
 from .hash_utils import str_sha256
 
 
-def pdf_page_to_image(page: pdfium.PdfPage, dpi=200) -> dict:
+def pdf_page_to_image(page: pdfium.PdfPage, dpi=200, image_type=ImageType.PIL) -> dict:
     """Convert pdfium.PdfDocument to image, Then convert the image to base64.
 
     Args:
         page (pdfium.PdfPage): The PDF page to render.
         dpi (int, optional): DPI passed to page_to_image when rendering the page. Defaults to 200.
+        image_type (ImageType, optional): The type of image to return. Defaults to ImageType.PIL.
 
     Returns:
         dict: {'scale': float} plus 'img_base64' (str) when image_type is ImageType.BASE64, otherwise 'img_pil' (the PIL image).
     """
     pil_img, scale = page_to_image(page, dpi=dpi)
-    img_base64 = image_to_b64str(pil_img)
-
     image_dict = {
-        "img_base64": img_base64,
-        "img_pil": pil_img,
         "scale": scale,
     }
+    if image_type == ImageType.BASE64:
+        image_dict["img_base64"] = image_to_b64str(pil_img)
+    else:
+        image_dict["img_pil"] = pil_img
+
     return image_dict
 
 
@@ -36,6 +39,7 @@ def load_images_from_pdf(
     dpi=200,
     start_page_id=0,
     end_page_id=None,
+    image_type=ImageType.PIL,  # PIL or BASE64
 ):
     images_list = []
     pdf_doc = pdfium.PdfDocument(pdf_bytes)
@@ -48,7 +52,7 @@ def load_images_from_pdf(
     for index in range(0, pdf_page_num):
         if start_page_id <= index <= end_page_id:
             page = pdf_doc[index]
-            image_dict = pdf_page_to_image(page, dpi=dpi)
+            image_dict = pdf_page_to_image(page, dpi=dpi, image_type=image_type)
             images_list.append(image_dict)
 
     return images_list, pdf_doc

+ 14 - 7
mineru/utils/pdf_reader.py

@@ -19,16 +19,14 @@ def page_to_image(
         scale = max_width_or_height / long_side_length
 
     bitmap: PdfBitmap = page.render(scale=scale)  # type: ignore
-    try:
-        image = bitmap.to_pil()
-    finally:
-        try:
-            bitmap.close()
-        except Exception:
-            pass
+
+    image = bitmap.to_pil()
+    bitmap.close()
     return image, scale
 
 
+
+
 def image_to_bytes(
     image: Image.Image,
     # image_format: str = "PNG",  # 也可以用 "JPEG"
@@ -48,6 +46,15 @@ def image_to_b64str(
     return base64.b64encode(image_bytes).decode("utf-8")
 
 
+def base64_to_pil_image(
+    base64_str: str,
+) -> Image.Image:
+    """Convert base64 string to PIL Image."""
+    image_bytes = base64.b64decode(base64_str)
+    with BytesIO(image_bytes) as image_buffer:
+        return Image.open(image_buffer).convert("RGB")
+
+
 def pdf_to_images(
     pdf: str | bytes | PdfDocument,
     dpi: int = 200,