소스 검색

fix: rename doclayout_yolo.py to doclayoutyolo.py and add visualization method for bounding box results

myhloli 2 달 전
부모
커밋
c8a17c5f98
3개의 변경된 파일, 100개의 추가 및 5개의 삭제
  1. 1 1
      mineru/backend/pipeline/model_init.py
  2. 44 2
      mineru/model/layout/doclayoutyolo.py
  3. 55 2
      mineru/model/mfd/yolo_v8.py

+ 1 - 1
mineru/backend/pipeline/model_init.py

@@ -4,7 +4,7 @@ import torch
 from loguru import logger
 
 from .model_list import AtomicModel
-from ...model.layout.doclayout_yolo import DocLayoutYOLOModel
+from ...model.layout.doclayoutyolo import DocLayoutYOLOModel
 from ...model.mfd.yolo_v8 import YOLOv8MFDModel
 from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR

+ 44 - 2
mineru/model/layout/doclayout_yolo.py → mineru/model/layout/doclayoutyolo.py

@@ -1,8 +1,13 @@
+import os
 from typing import List, Dict, Union
+
 from doclayout_yolo import YOLOv10
 from tqdm import tqdm
 import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw
+
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 
 class DocLayoutYOLOModel:
@@ -74,4 +79,41 @@ class DocLayoutYOLOModel:
                 for pred in predictions:
                     results.append(self._parse_prediction(pred))
                 pbar.update(len(batch))
-        return results
+        return results
+
+    def visualize(
+            self,
+            image: Union[np.ndarray, Image.Image],
+            results: List
+    ) -> Image.Image:
+
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+
+        draw = ImageDraw.Draw(image)
+        for res in results:
+            poly = res['poly']
+            xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
+            print(
+                f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category ID: {res['category_id']}, Score: {res['score']}")
+            # 使用PIL在图像上画框
+            draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
+            # 在框旁边画置信度
+            draw.text((xmax + 10, ymin + 10), f"{res['score']:.2f}", fill="red", font_size=22)
+        return image
+
+
+if __name__ == '__main__':
+    image_path = r"C:\Users\zhaoxiaomeng\Downloads\下载1.jpg"
+    doclayout_yolo_weights = os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
+    device = 'cuda'
+    model = DocLayoutYOLOModel(
+        weight=doclayout_yolo_weights,
+        device=device,
+    )
+    image = Image.open(image_path)
+    results = model.predict(image)
+
+    image = model.visualize(image, results)
+
+    image.show()  # 显示图像

+ 55 - 2
mineru/model/mfd/yolo_v8.py

@@ -1,8 +1,12 @@
+import os
 from typing import List, Union
 from tqdm import tqdm
 from ultralytics import YOLO
 import numpy as np
-from PIL import Image
+from PIL import Image, ImageDraw
+
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 
 class YOLOv8MFDModel:
@@ -50,4 +54,53 @@ class YOLOv8MFDModel:
                 batch_preds = self._run_predict(batch, is_batch=True)
                 results.extend(batch_preds)
                 pbar.update(len(batch))
-        return results
+        return results
+
+    def visualize(
+        self,
+        image: Union[np.ndarray, Image.Image],
+        results: List
+    ) -> Image.Image:
+
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+
+        formula_list = []
+        for xyxy, conf, cla in zip(
+                results.boxes.xyxy.cpu(), results.boxes.conf.cpu(), results.boxes.cls.cpu()
+        ):
+            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+            new_item = {
+                "category_id": 13 + int(cla.item()),
+                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                "score": round(float(conf.item()), 2),
+            }
+            formula_list.append(new_item)
+
+        draw = ImageDraw.Draw(image)
+        for res in formula_list:
+            poly = res['poly']
+            xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
+            print(
+                f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category ID: {res['category_id']}, Score: {res['score']}")
+            # 使用PIL在图像上画框
+            draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
+            # 在框旁边画置信度
+            draw.text((xmax + 10, ymin + 10), f"{res['score']:.2f}", fill="red")
+        return image
+
+if __name__ == '__main__':
+    image_path = r"C:\Users\zhaoxiaomeng\Downloads\下载1.jpg"
+    yolo_v8_mfd_weights = os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd),
+                                          ModelPath.yolo_v8_mfd)
+    device = 'cuda'
+    model = YOLOv8MFDModel(
+        weight=yolo_v8_mfd_weights,
+        device=device,
+    )
+    image = Image.open(image_path)
+    results = model.predict(image)
+
+    image = model.visualize(image, results)
+
+    image.show()  # 显示图像