Bläddra i källkod

perf(mfr): improve Math Formula Recognition by sorting images by area

- Sort detected images by area before processing to enhance MFR accuracy
- Implement stable sorting to maintain original order of images with equal
myhloli 8 månader sedan
förälder
incheckning
59fc80d473
1 ändrade filer med 74 tillägg och 9 borttagningar
  1. 74 9
      magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

+ 74 - 9
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

@@ -100,20 +100,61 @@ class UnimernetModel(object):
             res["latex"] = latex_rm_whitespace(latex)
         return formula_list
 
-    def batch_predict(
-        self, images_mfd_res: list, images: list, batch_size: int = 64
-    ) -> list:
+    # def batch_predict(
+    #     self, images_mfd_res: list, images: list, batch_size: int = 64
+    # ) -> list:
+    #     images_formula_list = []
+    #     mf_image_list = []
+    #     backfill_list = []
+    #     for image_index in range(len(images_mfd_res)):
+    #         mfd_res = images_mfd_res[image_index]
+    #         pil_img = Image.fromarray(images[image_index])
+    #         formula_list = []
+    #
+    #         for xyxy, conf, cla in zip(
+    #             mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
+    #         ):
+    #             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+    #             new_item = {
+    #                 "category_id": 13 + int(cla.item()),
+    #                 "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+    #                 "score": round(float(conf.item()), 2),
+    #                 "latex": "",
+    #             }
+    #             formula_list.append(new_item)
+    #             bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+    #             mf_image_list.append(bbox_img)
+    #
+    #         images_formula_list.append(formula_list)
+    #         backfill_list += formula_list
+    #
+    #     dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+    #     dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
+    #     mfr_res = []
+    #     for mf_img in dataloader:
+    #         mf_img = mf_img.to(self.device)
+    #         with torch.no_grad():
+    #             output = self.model.generate({"image": mf_img})
+    #         mfr_res.extend(output["pred_str"])
+    #     for res, latex in zip(backfill_list, mfr_res):
+    #         res["latex"] = latex_rm_whitespace(latex)
+    #     return images_formula_list
+
+    def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
         images_formula_list = []
         mf_image_list = []
         backfill_list = []
+        image_info = []  # Store (area, original_index, image) tuples
+
+        # Collect images with their original indices
         for image_index in range(len(images_mfd_res)):
             mfd_res = images_mfd_res[image_index]
             pil_img = Image.fromarray(images[image_index])
             formula_list = []
 
-            for xyxy, conf, cla in zip(
-                mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
-            ):
+            for idx, (xyxy, conf, cla) in enumerate(zip(
+                    mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
+            )):
                 xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                 new_item = {
                     "category_id": 13 + int(cla.item()),
@@ -123,19 +164,43 @@ class UnimernetModel(object):
                 }
                 formula_list.append(new_item)
                 bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                area = (xmax - xmin) * (ymax - ymin)
+
+                curr_idx = len(mf_image_list)
+                image_info.append((area, curr_idx, bbox_img))
                 mf_image_list.append(bbox_img)
 
             images_formula_list.append(formula_list)
             backfill_list += formula_list
 
-        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        # Stable sort by area
+        image_info.sort(key=lambda x: x[0])  # sort by area
+        sorted_indices = [x[1] for x in image_info]
+        sorted_images = [x[2] for x in image_info]
+
+        # Create mapping for results
+        index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
+
+        # Create dataset with sorted images
+        dataset = MathDataset(sorted_images, transform=self.mfr_transform)
         dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
+
+        # Process batches and store results
         mfr_res = []
         for mf_img in dataloader:
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
             mfr_res.extend(output["pred_str"])
-        for res, latex in zip(backfill_list, mfr_res):
-            res["latex"] = latex_rm_whitespace(latex)
+
+        # Restore original order
+        unsorted_results = [""] * len(mfr_res)
+        for new_idx, latex in enumerate(mfr_res):
+            original_idx = index_mapping[new_idx]
+            unsorted_results[original_idx] = latex_rm_whitespace(latex)
+
+        # Fill results back
+        for res, latex in zip(backfill_list, unsorted_results):
+            res["latex"] = latex
+
         return images_formula_list