ソースを参照

fix: batch methods in DocLayoutYOLO and YOLOv8 models

Suven 11ヶ月前
コミット
4fd1e41e0e

+ 4 - 2
magic_pdf/model/batch_analyze.py

@@ -34,15 +34,15 @@ class BatchAnalyze:
         self.batch_ratio = batch_ratio
 
     def __call__(self, images: list) -> list:
+        images_layout_res = []
         if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
             # layoutlmv3
-            images_layout_res = []
             for image in images:
                 layout_res = self.model.layout_model(image, ignore_catids=[])
                 images_layout_res.append(layout_res)
         elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
-            images_layout_res = self.model.layout_model.batch_predict(
+            images_layout_res += self.model.layout_model.batch_predict(
                 images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
             )
 
@@ -148,6 +148,8 @@ class BatchAnalyze:
                         )
                 logger.info(f"table time: {round(time.time() - table_start, 2)}")
 
+        return images_layout_res
+
 
 def doc_batch_analyze(
     dataset: Dataset,

+ 11 - 8
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -28,14 +28,17 @@ class DocLayoutYOLOModel(object):
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_layout_res = []
         for index in range(0, len(images), batch_size):
-            doclayout_yolo_res = self.model.predict(
-                images[index : index + batch_size],
-                imgsz=1024,
-                conf=0.25,
-                iou=0.45,
-                verbose=True,
-                device=self.device,
-            ).cpu()
+            doclayout_yolo_res = [
+                image_res.cpu()
+                for image_res in self.model.predict(
+                    images[index : index + batch_size],
+                    imgsz=1024,
+                    conf=0.25,
+                    iou=0.45,
+                    verbose=True,
+                    device=self.device,
+                )
+            ]
             for image_res in doclayout_yolo_res:
                 layout_res = []
                 for xyxy, conf, cla in zip(

+ 11 - 8
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py

@@ -15,14 +15,17 @@ class YOLOv8MFDModel(object):
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_mfd_res = []
         for index in range(0, len(images), batch_size):
-            mfd_res = self.mfd_model.predict(
-                images[index : index + batch_size],
-                imgsz=1888,
-                conf=0.25,
-                iou=0.45,
-                verbose=True,
-                device=self.device,
-            ).cpu()
+            mfd_res = [
+                image_res.cpu()
+                for image_res in self.mfd_model.predict(
+                    images[index : index + batch_size],
+                    imgsz=1888,
+                    conf=0.25,
+                    iou=0.45,
+                    verbose=True,
+                    device=self.device,
+                )
+            ]
             for image_res in mfd_res:
                 images_mfd_res.append(image_res)
         return images_mfd_res