Sfoglia il codice sorgente

Merge pull request #3365 from myhloli/dev

Dev
Xiaomeng Zhao 2 mesi fa
parent
commit
532cfd20f8

+ 1 - 1
mineru/backend/pipeline/model_init.py

@@ -4,7 +4,7 @@ import torch
 from loguru import logger
 from loguru import logger
 
 
 from .model_list import AtomicModel
 from .model_list import AtomicModel
-from ...model.layout.doclayout_yolo import DocLayoutYOLOModel
+from ...model.layout.doclayoutyolo import DocLayoutYOLOModel
 from ...model.mfd.yolo_v8 import YOLOv8MFDModel
 from ...model.mfd.yolo_v8 import YOLOv8MFDModel
 from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR

+ 44 - 2
mineru/model/layout/doclayout_yolo.py → mineru/model/layout/doclayoutyolo.py

@@ -1,8 +1,13 @@
+import os
 from typing import List, Dict, Union
 from typing import List, Dict, Union
+
 from doclayout_yolo import YOLOv10
 from doclayout_yolo import YOLOv10
 from tqdm import tqdm
 from tqdm import tqdm
 import numpy as np
 import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw
+
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 
 
 
 class DocLayoutYOLOModel:
 class DocLayoutYOLOModel:
@@ -74,4 +79,41 @@ class DocLayoutYOLOModel:
                 for pred in predictions:
                 for pred in predictions:
                     results.append(self._parse_prediction(pred))
                     results.append(self._parse_prediction(pred))
                 pbar.update(len(batch))
                 pbar.update(len(batch))
-        return results
+        return results
+
+    def visualize(
+            self,
+            image: Union[np.ndarray, Image.Image],
+            results: List
+    ) -> Image.Image:
+
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+
+        draw = ImageDraw.Draw(image)
+        for res in results:
+            poly = res['poly']
+            xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
+            print(
+                f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category ID: {res['category_id']}, Score: {res['score']}")
+            # 使用PIL在图像上画框
+            draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
+            # 在框旁边画置信度
+            draw.text((xmax + 10, ymin + 10), f"{res['score']:.2f}", fill="red", font_size=22)
+        return image
+
+
+if __name__ == '__main__':
+    image_path = r"C:\Users\zhaoxiaomeng\Downloads\下载1.jpg"
+    doclayout_yolo_weights = os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
+    device = 'cuda'
+    model = DocLayoutYOLOModel(
+        weight=doclayout_yolo_weights,
+        device=device,
+    )
+    image = Image.open(image_path)
+    results = model.predict(image)
+
+    image = model.visualize(image, results)
+
+    image.show()  # 显示图像

+ 55 - 2
mineru/model/mfd/yolo_v8.py

@@ -1,8 +1,12 @@
+import os
 from typing import List, Union
 from typing import List, Union
 from tqdm import tqdm
 from tqdm import tqdm
 from ultralytics import YOLO
 from ultralytics import YOLO
 import numpy as np
 import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw
+
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 
 
 
 class YOLOv8MFDModel:
 class YOLOv8MFDModel:
@@ -50,4 +54,53 @@ class YOLOv8MFDModel:
                 batch_preds = self._run_predict(batch, is_batch=True)
                 batch_preds = self._run_predict(batch, is_batch=True)
                 results.extend(batch_preds)
                 results.extend(batch_preds)
                 pbar.update(len(batch))
                 pbar.update(len(batch))
-        return results
+        return results
+
+    def visualize(
+        self,
+        image: Union[np.ndarray, Image.Image],
+        results: List
+    ) -> Image.Image:
+
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+
+        formula_list = []
+        for xyxy, conf, cla in zip(
+                results.boxes.xyxy.cpu(), results.boxes.conf.cpu(), results.boxes.cls.cpu()
+        ):
+            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+            new_item = {
+                "category_id": 13 + int(cla.item()),
+                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                "score": round(float(conf.item()), 2),
+            }
+            formula_list.append(new_item)
+
+        draw = ImageDraw.Draw(image)
+        for res in formula_list:
+            poly = res['poly']
+            xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
+            print(
+                f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category ID: {res['category_id']}, Score: {res['score']}")
+            # 使用PIL在图像上画框
+            draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
+            # 在框旁边画置信度
+            draw.text((xmax + 10, ymin + 10), f"{res['score']:.2f}", fill="red", font_size=22)
+        return image
+
+if __name__ == '__main__':
+    image_path = r"C:\Users\zhaoxiaomeng\Downloads\screenshot-20250821-192948.png"
+    yolo_v8_mfd_weights = os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd),
+                                          ModelPath.yolo_v8_mfd)
+    device = 'cuda'
+    model = YOLOv8MFDModel(
+        weight=yolo_v8_mfd_weights,
+        device=device,
+    )
+    image = Image.open(image_path)
+    results = model.predict(image)
+
+    image = model.visualize(image, results)
+
+    image.show()  # 显示图像

+ 18 - 27
mineru/utils/model_utils.py

@@ -201,6 +201,10 @@ def filter_nested_tables(table_res_list, overlap_threshold=0.8, area_threshold=0
 
 
 
 
 def remove_overlaps_min_blocks(res_list):
 def remove_overlaps_min_blocks(res_list):
+
+    for res in res_list:
+        res['bbox'] = [int(res['poly'][0]), int(res['poly'][1]), int(res['poly'][4]), int(res['poly'][5])]
+
     # 重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。
     # 重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。
     # 删除重叠blocks中较小的那些
     # 删除重叠blocks中较小的那些
     need_remove = []
     need_remove = []
@@ -248,6 +252,14 @@ def remove_overlaps_min_blocks(res_list):
     # 从列表中移除标记的元素
     # 从列表中移除标记的元素
     for res in need_remove:
     for res in need_remove:
         res_list.remove(res)
         res_list.remove(res)
+        del res['bbox']  # 删除bbox字段
+
+    for res in res_list:
+        # 将res的poly使用bbox重构
+        res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
+                       res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
+        # 删除res的bbox
+        del res['bbox']
 
 
     return res_list, need_remove
     return res_list, need_remove
 
 
@@ -352,7 +364,6 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
             table_res_list.append(res)
             table_res_list.append(res)
             table_indices.append(i)
             table_indices.append(i)
         elif category_id in [1]:  # Text regions
         elif category_id in [1]:  # Text regions
-            res['bbox'] = [int(res['poly'][0]), int(res['poly'][1]), int(res['poly'][4]), int(res['poly'][5])]
             text_res_list.append(res)
             text_res_list.append(res)
 
 
     # Process tables: merge high IoU tables first, then filter nested tables
     # Process tables: merge high IoU tables first, then filter nested tables
@@ -362,23 +373,11 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
     filtered_table_res_list = filter_nested_tables(
     filtered_table_res_list = filter_nested_tables(
         table_res_list, overlap_threshold, area_threshold)
         table_res_list, overlap_threshold, area_threshold)
 
 
-    for table_res in filtered_table_res_list:
-        table_res['bbox'] = [int(table_res['poly'][0]), int(table_res['poly'][1]), int(table_res['poly'][4]), int(table_res['poly'][5])]
-
     filtered_table_res_list, table_need_remove = remove_overlaps_min_blocks(filtered_table_res_list)
     filtered_table_res_list, table_need_remove = remove_overlaps_min_blocks(filtered_table_res_list)
 
 
-    for res in filtered_table_res_list:
-        # 将res的poly使用bbox重构
-        res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
-                       res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
-        # 删除res的bbox
-        del res['bbox']
-
-    if len(table_need_remove) > 0:
-        for res in table_need_remove:
-            del res['bbox']
-            if res in layout_res:
-                layout_res.remove(res)
+    for res in table_need_remove:
+        if res in layout_res:
+            layout_res.remove(res)
 
 
     # Remove filtered out tables from layout_res
     # Remove filtered out tables from layout_res
     if len(filtered_table_res_list) < len(table_res_list):
     if len(filtered_table_res_list) < len(table_res_list):
@@ -390,20 +389,12 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
 
 
     # Remove overlaps in OCR and text regions
     # Remove overlaps in OCR and text regions
     text_res_list, need_remove = remove_overlaps_min_blocks(text_res_list)
     text_res_list, need_remove = remove_overlaps_min_blocks(text_res_list)
-    for res in text_res_list:
-        # 将res的poly使用bbox重构
-        res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
-                       res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
-        # 删除res的bbox
-        del res['bbox']
 
 
     ocr_res_list.extend(text_res_list)
     ocr_res_list.extend(text_res_list)
 
 
-    if len(need_remove) > 0:
-        for res in need_remove:
-            del res['bbox']
-            if res in layout_res:
-                layout_res.remove(res)
+    for res in need_remove:
+        if res in layout_res:
+            layout_res.remove(res)
 
 
     # 检测大block内部是否包含多个小block, 合并ocr和table列表进行检测
     # 检测大block内部是否包含多个小block, 合并ocr和table列表进行检测
     combined_res_list = ocr_res_list + filtered_table_res_list
     combined_res_list = ocr_res_list + filtered_table_res_list