Explorar o código

feat(language-detection): add YOLOv11 language detection model

- Add YOLOv11 language detection model for PDF documents
- Implement language detection in PymuDocDataset
- Update app.py to include 'auto' language option
- Create language detection utilities and constants
myhloli hai 11 meses
pai
achega
20438bd2b7

+ 2 - 0
magic_pdf/config/constants.py

@@ -52,6 +52,8 @@ class MODEL_NAME:
 
     RAPID_TABLE = 'rapid_table'
 
+    YOLO_V11_LangDetect = 'yolo_v11n_langdetect'
+
 
 PARSE_TYPE_TXT = 'txt'
 PARSE_TYPE_OCR = 'ocr'

+ 12 - 1
magic_pdf/data/dataset.py

@@ -3,11 +3,13 @@ from abc import ABC, abstractmethod
 from typing import Callable, Iterator
 
 import fitz
+from loguru import logger
 
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.data.schemas import PageInfo
 from magic_pdf.data.utils import fitz_doc_to_image
 from magic_pdf.filter import classify
+from magic_pdf.model.sub_modules.language_detection.utils import auto_detect_lang
 
 
 class PageableData(ABC):
@@ -133,7 +135,7 @@ class Dataset(ABC):
 
 
 class PymuDocDataset(Dataset):
-    def __init__(self, bits: bytes):
+    def __init__(self, bits: bytes, lang=None):
         """Initialize the dataset, which wraps the pymudoc documents.
 
         Args:
@@ -144,6 +146,13 @@ class PymuDocDataset(Dataset):
         self._data_bits = bits
         self._raw_data = bits
 
+        if lang == '':
+            self._lang = None
+        elif lang == 'auto':
+            self._lang = auto_detect_lang(bits)
+            logger.info(f"lang: {lang}, detect_lang: {self._lang}")
+        else:
+            self._lang = lang
     def __len__(self) -> int:
         """The page number of the pdf."""
         return len(self._records)
@@ -197,6 +206,8 @@ class PymuDocDataset(Dataset):
         Returns:
             Any: return the result generated by proc
         """
+        if 'lang' in kwargs and self._lang is not None:
+            kwargs['lang'] = self._lang
         return proc(self, *args, **kwargs)
 
     def classify(self) -> SupportedPdfParseMethod:

+ 35 - 0
magic_pdf/data/utils.py

@@ -1,6 +1,7 @@
 
 import fitz
 import numpy as np
+from loguru import logger
 
 from magic_pdf.utils.annotations import ImportPIL
 
@@ -30,3 +31,37 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
     img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
 
     return img_dict
+
+@ImportPIL
+def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
+    from PIL import Image
+    images = []
+    with fitz.open('pdf', pdf_bytes) as doc:
+        pdf_page_num = doc.page_count
+        end_page_id = (
+            end_page_id
+            if end_page_id is not None and end_page_id >= 0
+            else pdf_page_num - 1
+        )
+        if end_page_id > pdf_page_num - 1:
+            logger.warning('end_page_id is out of range, use images length')
+            end_page_id = pdf_page_num - 1
+
+        for index in range(0, doc.page_count):
+            if start_page_id <= index <= end_page_id:
+                page = doc[index]
+                mat = fitz.Matrix(dpi / 72, dpi / 72)
+                pm = page.get_pixmap(matrix=mat, alpha=False)
+
+                # If the width or height exceeds 4500 after scaling, do not scale further.
+                if pm.width > 4500 or pm.height > 4500:
+                    pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
+
+                img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
+                img = np.array(img)
+                img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
+            else:
+                img_dict = {'img': [], 'width': 0, 'height': 0}
+
+            images.append(img_dict)
+    return images

+ 0 - 46
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -1,8 +1,6 @@
 import os
 import time
 
-import fitz
-import numpy as np
 from loguru import logger
 
 # 关闭paddle的信号处理
@@ -44,47 +42,6 @@ def remove_duplicates_dicts(lst):
     return unique_dicts
 
 
-def load_images_from_pdf(
-    pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None
-) -> list:
-    try:
-        from PIL import Image
-    except ImportError:
-        logger.error('Pillow not installed, please install by pip.')
-        exit(1)
-
-    images = []
-    with fitz.open('pdf', pdf_bytes) as doc:
-        pdf_page_num = doc.page_count
-        end_page_id = (
-            end_page_id
-            if end_page_id is not None and end_page_id >= 0
-            else pdf_page_num - 1
-        )
-        if end_page_id > pdf_page_num - 1:
-            logger.warning('end_page_id is out of range, use images length')
-            end_page_id = pdf_page_num - 1
-
-        for index in range(0, doc.page_count):
-            if start_page_id <= index <= end_page_id:
-                page = doc[index]
-                mat = fitz.Matrix(dpi / 72, dpi / 72)
-                pm = page.get_pixmap(matrix=mat, alpha=False)
-
-                # If the width or height exceeds 4500 after scaling, do not scale further.
-                if pm.width > 4500 or pm.height > 4500:
-                    pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
-
-                img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
-                img = np.array(img)
-                img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
-            else:
-                img_dict = {'img': [], 'width': 0, 'height': 0}
-
-            images.append(img_dict)
-    return images
-
-
 class ModelSingleton:
     _instance = None
     _models = {}
@@ -197,9 +154,6 @@ def doc_analyze(
     table_enable=None,
 ) -> InferenceResult:
 
-    if lang == '':
-        lang = None
-
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
         ocr, show_log, lang, layout_model, formula_enable, table_enable

+ 1 - 0
magic_pdf/model/sub_modules/language_detection/__init__.py

@@ -0,0 +1 @@
+# Copyright (c) Opendatalab. All rights reserved.

+ 73 - 0
magic_pdf/model/sub_modules/language_detection/utils.py

@@ -0,0 +1,73 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import os
+from pathlib import Path
+
+import yaml
+from PIL import Image
+
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
+
+from magic_pdf.config.constants import MODEL_NAME
+from magic_pdf.data.utils import load_images_from_pdf
+from magic_pdf.libs.config_reader import get_local_models_dir, get_device
+from magic_pdf.libs.pdf_check import extract_pages
+from magic_pdf.model.model_list import AtomicModel
+from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
+
+
+def get_model_config():
+    local_models_dir = get_local_models_dir()
+    device = get_device()
+    current_file_path = os.path.abspath(__file__)
+    root_dir = Path(current_file_path).parents[3]
+    model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
+    config_path = os.path.join(model_config_dir, 'model_configs.yaml')
+    with open(config_path, 'r', encoding='utf-8') as f:
+        configs = yaml.load(f, Loader=yaml.FullLoader)
+    return local_models_dir, device, configs
+
+
+def get_text_images(simple_images):
+    local_models_dir, device, configs = get_model_config()
+    atom_model_manager = AtomModelSingleton()
+    temp_layout_model = atom_model_manager.get_atom_model(
+        atom_model_name=AtomicModel.Layout,
+        layout_model_name=MODEL_NAME.DocLayout_YOLO,
+        doclayout_yolo_weights=str(
+            os.path.join(
+                local_models_dir, configs['weights'][MODEL_NAME.DocLayout_YOLO]
+            )
+        ),
+        device=device,
+    )
+    text_images = []
+    for simple_image in simple_images:
+        image = Image.fromarray(simple_image['img'])
+        layout_res = temp_layout_model.predict(image)
+        # 给textblock截图
+        for res in layout_res:
+            if res['category_id'] in [1]:
+                x1, y1, _, _, x2, y2, _, _ = res['poly']
+                # 初步清洗(宽和高都小于100)
+                if x2 - x1 < 100 and y2 - y1 < 100:
+                    continue
+                text_images.append(image.crop((x1, y1, x2, y2)))
+    return text_images
+
+
+def auto_detect_lang(pdf_bytes: bytes):
+    sample_docs = extract_pages(pdf_bytes)
+    sample_pdf_bytes = sample_docs.tobytes()
+    simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=96)
+    text_images = get_text_images(simple_images)
+    local_models_dir, device, configs = get_model_config()
+    # 用yolo11做语言分类
+    langdetect_model_weights = str(
+        os.path.join(
+            local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
+        )
+    )
+    langdetect_model = YOLOv11LangDetModel(langdetect_model_weights, device)
+    lang = langdetect_model.do_detect(text_images)
+    return lang

+ 134 - 0
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py

@@ -0,0 +1,134 @@
+# Copyright (c) Opendatalab. All rights reserved.
+from collections import Counter
+from uuid import uuid4
+
+from PIL import Image
+from loguru import logger
+from ultralytics import YOLO
+
+language_dict = {
+    "ch": "中文简体",
+    "en": "英语",
+    "japan": "日语",
+    "korean": "韩语",
+    "fr": "法语",
+    "german": "德语",
+    "ar": "阿拉伯语",
+    "ru": "俄语"
+}
+
+
+def split_images(image, result_images=None):
+    """
+    对输入文件夹内的图片进行处理,若图片竖向(y方向)分辨率超过400,则进行拆分,
+    每次平分图片,直至拆分出的图片竖向分辨率都满足400以下,将处理后的图片(拆分后的子图片)保存到输出文件夹。
+    避免保存因裁剪区域超出图片范围导致出现的无效黑色图片部分。
+    """
+    if result_images is None:
+        result_images = []
+
+    width, height = image.size
+    long_side = max(width, height)  # 获取较长边长度
+
+    if long_side <= 400:
+        result_images.append(image)
+        return result_images
+
+    new_long_side = long_side // 2
+    sub_images = []
+
+    if width >= height:  # 如果宽度是较长边
+        for x in range(0, width, new_long_side):
+            # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
+            if x + new_long_side > width:
+                continue
+            box = (x, 0, x + new_long_side, height)
+            sub_image = image.crop(box)
+            sub_images.append(sub_image)
+    else:  # 如果高度是较长边
+        for y in range(0, height, new_long_side):
+            # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
+            if y + new_long_side > height:
+                continue
+            box = (0, y, width, y + new_long_side)
+            sub_image = image.crop(box)
+            sub_images.append(sub_image)
+
+    for sub_image in sub_images:
+        split_images(sub_image, result_images)
+
+    return result_images
+
+
+def resize_images_to_224(image):
+    """
+    若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小,并保存到输出文件夹中。
+    """
+    try:
+        width, height = image.size
+        if width < 224 or height < 224:
+            new_image = Image.new('RGB', (224, 224), (0, 0, 0))
+            paste_x = (224 - width) // 2
+            paste_y = (224 - height) // 2
+            new_image.paste(image, (paste_x, paste_y))
+            image = new_image
+        else:
+            image = image.resize((224, 224), Image.Resampling.LANCZOS)
+
+        # uuid = str(uuid4())
+        # image.save(f"/tmp/{uuid}.jpg")
+        return image
+    except Exception as e:
+        logger.exception(e)
+
+
+class YOLOv11LangDetModel(object):
+    def __init__(self, weight, device):
+        self.model = YOLO(weight)
+        self.device = device
+
+    def do_detect(self, images: list):
+        all_images = []
+        for image in images:
+            width, height = image.size
+            # logger.info(f"image size: {width} x {height}")
+            if width < 100 and height < 100:
+                continue
+            temp_images = split_images(image)
+            for temp_image in temp_images:
+                all_images.append(resize_images_to_224(temp_image))
+
+        images_lang_res = self.batch_predict(all_images, batch_size=8)
+        logger.info(f"images_lang_res: {images_lang_res}")
+        if len(images_lang_res) > 0:
+            count_dict = Counter(images_lang_res)
+            language = max(count_dict, key=count_dict.get)
+        else:
+            language = None
+        return language
+
+
+    def predict(self, image):
+        results = self.model.predict(image, verbose=False, device=self.device)
+        predicted_class_id = int(results[0].probs.top1)
+        predicted_class_name = self.model.names[predicted_class_id]
+        return predicted_class_name
+
+
+    def batch_predict(self, images: list, batch_size: int) -> list:
+        images_lang_res = []
+        for index in range(0, len(images), batch_size):
+            lang_res = [
+                image_res.cpu()
+                for image_res in self.model.predict(
+                    images[index: index + batch_size],
+                    verbose = False,
+                    device=self.device,
+                )
+            ]
+            for res in lang_res:
+                predicted_class_id = int(res.probs.top1)
+                predicted_class_name = self.model.names[predicted_class_id]
+                images_lang_res.append(predicted_class_name)
+
+        return images_lang_res

+ 1 - 0
magic_pdf/model/sub_modules/language_detection/yolov11/__init__.py

@@ -0,0 +1 @@
+# Copyright (c) Opendatalab. All rights reserved.

+ 2 - 2
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -9,7 +9,7 @@ class DocLayoutYOLOModel(object):
     def predict(self, image):
         layout_res = []
         doclayout_yolo_res = self.model.predict(
-            image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device
+            image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
         )[0]
         for xyxy, conf, cla in zip(
             doclayout_yolo_res.boxes.xyxy.cpu(),
@@ -35,7 +35,7 @@ class DocLayoutYOLOModel(object):
                     imgsz=1024,
                     conf=0.25,
                     iou=0.45,
-                    verbose=True,
+                    verbose=False,
                     device=self.device,
                 )
             ]

+ 2 - 1
magic_pdf/resources/model_config/model_configs.yaml

@@ -5,4 +5,5 @@ weights:
   unimernet_small: MFR/unimernet_small
   struct_eqtable: TabRec/StructEqTable
   tablemaster: TabRec/TableMaster
-  rapid_table: TabRec/RapidTable
+  rapid_table: TabRec/RapidTable
+  yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_cls_ft.pt

+ 13 - 16
magic_pdf/tools/common.py

@@ -95,9 +95,6 @@ def do_parse(
         f_draw_model_bbox = True
         f_draw_line_sort_bbox = True
 
-    if lang == '':
-        lang = None
-
     pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
         pdf_bytes, start_page_id, end_page_id
     )
@@ -109,7 +106,7 @@ def do_parse(
     )
     image_dir = str(os.path.basename(local_image_dir))
 
-    ds = PymuDocDataset(pdf_bytes)
+    ds = PymuDocDataset(pdf_bytes, lang=lang)
 
     if len(model_list) == 0:
         if model_config.__use_inside_model__:
@@ -118,50 +115,50 @@ def do_parse(
                     infer_result = ds.apply(
                         doc_analyze,
                         ocr=False,
-                        lang=lang,
+                        lang=ds._lang,
                         layout_model=layout_model,
                         formula_enable=formula_enable,
                         table_enable=table_enable,
                     )
                     pipe_result = infer_result.pipe_txt_mode(
-                        image_writer, debug_mode=True, lang=lang
+                        image_writer, debug_mode=True, lang=ds._lang
                     )
                 else:
                     infer_result = ds.apply(
                         doc_analyze,
                         ocr=True,
-                        lang=lang,
+                        lang=ds._lang,
                         layout_model=layout_model,
                         formula_enable=formula_enable,
                         table_enable=table_enable,
                     )
                     pipe_result = infer_result.pipe_ocr_mode(
-                        image_writer, debug_mode=True, lang=lang
+                        image_writer, debug_mode=True, lang=ds._lang
                     )
 
             elif parse_method == 'txt':
                 infer_result = ds.apply(
                     doc_analyze,
                     ocr=False,
-                    lang=lang,
+                    lang=ds._lang,
                     layout_model=layout_model,
                     formula_enable=formula_enable,
                     table_enable=table_enable,
                 )
                 pipe_result = infer_result.pipe_txt_mode(
-                    image_writer, debug_mode=True, lang=lang
+                    image_writer, debug_mode=True, lang=ds._lang
                 )
             elif parse_method == 'ocr':
                 infer_result = ds.apply(
                     doc_analyze,
                     ocr=True,
-                    lang=lang,
+                    lang=ds._lang,
                     layout_model=layout_model,
                     formula_enable=formula_enable,
                     table_enable=table_enable,
                 )
                 pipe_result = infer_result.pipe_ocr_mode(
-                    image_writer, debug_mode=True, lang=lang
+                    image_writer, debug_mode=True, lang=ds._lang
                 )
             else:
                 logger.error('unknown parse method')
@@ -174,20 +171,20 @@ def do_parse(
         infer_result = InferenceResult(model_list, ds)
         if parse_method == 'ocr':
             pipe_result = infer_result.pipe_ocr_mode(
-                image_writer, debug_mode=True, lang=lang
+                image_writer, debug_mode=True, lang=ds._lang
             )
         elif parse_method == 'txt':
             pipe_result = infer_result.pipe_txt_mode(
-                image_writer, debug_mode=True, lang=lang
+                image_writer, debug_mode=True, lang=ds._lang
             )
         else:
             if ds.classify() == SupportedPdfParseMethod.TXT:
                 pipe_result = infer_result.pipe_txt_mode(
-                        image_writer, debug_mode=True, lang=lang
+                        image_writer, debug_mode=True, lang=ds._lang
                     )
             else:
                 pipe_result = infer_result.pipe_ocr_mode(
-                        image_writer, debug_mode=True, lang=lang
+                        image_writer, debug_mode=True, lang=ds._lang
                     )
             
 

+ 1 - 1
projects/gradio_app/app.py

@@ -159,7 +159,7 @@ devanagari_lang = [
 ]
 other_lang = ['ch', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
 
-all_lang = ['']
+all_lang = ['', 'auto']
 all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])