Forráskód Böngészése

Merge pull request #3365 from myhloli/dev

Dev
Xiaomeng Zhao 2 hónapja
szülő
commit
532cfd20f8

+ 1 - 1
mineru/backend/pipeline/model_init.py

@@ -4,7 +4,7 @@ import torch
 from loguru import logger
 
 from .model_list import AtomicModel
-from ...model.layout.doclayout_yolo import DocLayoutYOLOModel
+from ...model.layout.doclayoutyolo import DocLayoutYOLOModel
 from ...model.mfd.yolo_v8 import YOLOv8MFDModel
 from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR

+ 44 - 2
mineru/model/layout/doclayout_yolo.py → mineru/model/layout/doclayoutyolo.py

@@ -1,8 +1,13 @@
+import os
 from typing import List, Dict, Union
+
 from doclayout_yolo import YOLOv10
 from tqdm import tqdm
 import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw
+
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 
 class DocLayoutYOLOModel:
@@ -74,4 +79,41 @@ class DocLayoutYOLOModel:
                 for pred in predictions:
                     results.append(self._parse_prediction(pred))
                 pbar.update(len(batch))
-        return results
+        return results
+
+    def visualize(
+            self,
+            image: Union[np.ndarray, Image.Image],
+            results: List
+    ) -> Image.Image:
+
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+
+        draw = ImageDraw.Draw(image)
+        for res in results:
+            poly = res['poly']
+            xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
+            print(
+                f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category ID: {res['category_id']}, Score: {res['score']}")
+            # 使用PIL在图像上画框
+            draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
+            # 在框旁边画置信度
+            draw.text((xmax + 10, ymin + 10), f"{res['score']:.2f}", fill="red", font_size=22)
+        return image
+
+
+if __name__ == '__main__':
+    image_path = r"C:\Users\zhaoxiaomeng\Downloads\下载1.jpg"
+    doclayout_yolo_weights = os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
+    device = 'cuda'
+    model = DocLayoutYOLOModel(
+        weight=doclayout_yolo_weights,
+        device=device,
+    )
+    image = Image.open(image_path)
+    results = model.predict(image)
+
+    image = model.visualize(image, results)
+
+    image.show()  # 显示图像

+ 55 - 2
mineru/model/mfd/yolo_v8.py

@@ -1,8 +1,12 @@
+import os
 from typing import List, Union
 from tqdm import tqdm
 from ultralytics import YOLO
 import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw
+
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 
 class YOLOv8MFDModel:
@@ -50,4 +54,53 @@ class YOLOv8MFDModel:
                 batch_preds = self._run_predict(batch, is_batch=True)
                 results.extend(batch_preds)
                 pbar.update(len(batch))
-        return results
+        return results
+
+    def visualize(
+        self,
+        image: Union[np.ndarray, Image.Image],
+        results: List
+    ) -> Image.Image:
+
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+
+        formula_list = []
+        for xyxy, conf, cla in zip(
+                results.boxes.xyxy.cpu(), results.boxes.conf.cpu(), results.boxes.cls.cpu()
+        ):
+            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+            new_item = {
+                "category_id": 13 + int(cla.item()),
+                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                "score": round(float(conf.item()), 2),
+            }
+            formula_list.append(new_item)
+
+        draw = ImageDraw.Draw(image)
+        for res in formula_list:
+            poly = res['poly']
+            xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
+            print(
+                f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category ID: {res['category_id']}, Score: {res['score']}")
+            # 使用PIL在图像上画框
+            draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
+            # 在框旁边画置信度
+            draw.text((xmax + 10, ymin + 10), f"{res['score']:.2f}", fill="red", font_size=22)
+        return image
+
+if __name__ == '__main__':
+    image_path = r"C:\Users\zhaoxiaomeng\Downloads\screenshot-20250821-192948.png"
+    yolo_v8_mfd_weights = os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd),
+                                          ModelPath.yolo_v8_mfd)
+    device = 'cuda'
+    model = YOLOv8MFDModel(
+        weight=yolo_v8_mfd_weights,
+        device=device,
+    )
+    image = Image.open(image_path)
+    results = model.predict(image)
+
+    image = model.visualize(image, results)
+
+    image.show()  # 显示图像

+ 18 - 27
mineru/utils/model_utils.py

@@ -201,6 +201,10 @@ def filter_nested_tables(table_res_list, overlap_threshold=0.8, area_threshold=0
 
 
 def remove_overlaps_min_blocks(res_list):
+
+    for res in res_list:
+        res['bbox'] = [int(res['poly'][0]), int(res['poly'][1]), int(res['poly'][4]), int(res['poly'][5])]
+
     # 重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。
     # 删除重叠blocks中较小的那些
     need_remove = []
@@ -248,6 +252,14 @@ def remove_overlaps_min_blocks(res_list):
     # 从列表中移除标记的元素
     for res in need_remove:
         res_list.remove(res)
+        del res['bbox']  # 删除bbox字段
+
+    for res in res_list:
+        # 将res的poly使用bbox重构
+        res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
+                       res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
+        # 删除res的bbox
+        del res['bbox']
 
     return res_list, need_remove
 
@@ -352,7 +364,6 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
             table_res_list.append(res)
             table_indices.append(i)
         elif category_id in [1]:  # Text regions
-            res['bbox'] = [int(res['poly'][0]), int(res['poly'][1]), int(res['poly'][4]), int(res['poly'][5])]
             text_res_list.append(res)
 
     # Process tables: merge high IoU tables first, then filter nested tables
@@ -362,23 +373,11 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
     filtered_table_res_list = filter_nested_tables(
         table_res_list, overlap_threshold, area_threshold)
 
-    for table_res in filtered_table_res_list:
-        table_res['bbox'] = [int(table_res['poly'][0]), int(table_res['poly'][1]), int(table_res['poly'][4]), int(table_res['poly'][5])]
-
     filtered_table_res_list, table_need_remove = remove_overlaps_min_blocks(filtered_table_res_list)
 
-    for res in filtered_table_res_list:
-        # 将res的poly使用bbox重构
-        res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
-                       res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
-        # 删除res的bbox
-        del res['bbox']
-
-    if len(table_need_remove) > 0:
-        for res in table_need_remove:
-            del res['bbox']
-            if res in layout_res:
-                layout_res.remove(res)
+    for res in table_need_remove:
+        if res in layout_res:
+            layout_res.remove(res)
 
     # Remove filtered out tables from layout_res
     if len(filtered_table_res_list) < len(table_res_list):
@@ -390,20 +389,12 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
 
     # Remove overlaps in OCR and text regions
     text_res_list, need_remove = remove_overlaps_min_blocks(text_res_list)
-    for res in text_res_list:
-        # 将res的poly使用bbox重构
-        res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
-                       res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
-        # 删除res的bbox
-        del res['bbox']
 
     ocr_res_list.extend(text_res_list)
 
-    if len(need_remove) > 0:
-        for res in need_remove:
-            del res['bbox']
-            if res in layout_res:
-                layout_res.remove(res)
+    for res in need_remove:
+        if res in layout_res:
+            layout_res.remove(res)
 
     # 检测大block内部是否包含多个小block, 合并ocr和table列表进行检测
     combined_res_list = ocr_res_list + filtered_table_res_list