Эх сурвалжийг харах

Merge pull request #3405 from myhloli/dev

Dev
Xiaomeng Zhao 2 сар өмнө
parent
commit
45a8ca81e8

+ 18 - 11
mineru/backend/pipeline/batch_analyze.py

@@ -93,18 +93,19 @@ class BatchAnalyze:
                                           })
 
             for table_res in table_res_list:
-                # table_img, _ = crop_img(table_res, pil_img)
-                # bbox = (241, 208, 1475, 2019)
-                scale = 10/3
-                # scale = 1
-                crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
-                crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
-                bbox = (int(crop_xmin/scale), int(crop_ymin/scale), int(crop_xmax/scale), int(crop_ymax/scale))
-                table_img = get_crop_np_img(bbox, np_img, scale=scale)
+                def get_crop_table_img(scale):
+                    crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
+                    crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
+                    bbox = (int(crop_xmin / scale), int(crop_ymin / scale), int(crop_xmax / scale), int(crop_ymax / scale))
+                    return get_crop_np_img(bbox, np_img, scale=scale)
+
+                wireless_table_img = get_crop_table_img(scale = 1)
+                wired_table_img = get_crop_table_img(scale = 10/3)
 
                 table_res_list_all_page.append({'table_res':table_res,
                                                 'lang':_lang,
-                                                'table_img':table_img,
+                                                'table_img':wireless_table_img,
+                                                'wired_table_img':wired_table_img,
                                               })
 
         # 表格识别 table recognition
@@ -193,8 +194,14 @@ class BatchAnalyze:
             # 单独拿出有线表格进行预测
             wired_table_res_list = []
             for table_res_dict in table_res_list_all_page:
-                if table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable:
+                # logger.debug(f"Table classification result: {table_res_dict["table_res"]["cls_label"]} with confidence {table_res_dict["table_res"]["cls_score"]}")
+                if (
+                    (table_res_dict["table_res"]["cls_label"] == AtomicModel.WirelessTable and table_res_dict["table_res"]["cls_score"] < 0.9)
+                    or table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable
+                ):
                     wired_table_res_list.append(table_res_dict)
+                del table_res_dict["table_res"]["cls_label"]
+                del table_res_dict["table_res"]["cls_score"]
             if wired_table_res_list:
                 for table_res_dict in tqdm(
                         wired_table_res_list, desc="Table-wired Predict"
@@ -207,7 +214,7 @@ class BatchAnalyze:
                         lang=table_res_dict["lang"],
                     )
                     table_res_dict["table_res"]["html"] = wired_table_model.predict(
-                        table_res_dict["table_img"],
+                        table_res_dict["wired_table_img"],
                         table_res_dict["ocr_result"],
                         table_res_dict["table_res"].get("html", None)
                     )

+ 14 - 0
mineru/model/ori_cls/paddle_ori_cls.py

@@ -174,6 +174,12 @@ class PaddleOrientationClsModel:
     def batch_predict(
         self, imgs: List[Dict], det_batch_size: int, batch_size: int = 16
     ) -> None:
+
+        import torch
+        from packaging import version
+        if version.parse(torch.__version__) >= version.parse("2.8.0"):
+            return None
+
         """
         批量预测传入的包含图片信息列表的旋转信息,并且将旋转过的图片正确地旋转回来
         """
@@ -254,11 +260,19 @@ class PaddleOrientationClsModel:
                                 np.asarray(img_info["table_img"]),
                                 cv2.ROTATE_90_CLOCKWISE,
                             )
+                            img_info["wired_table_img"] = cv2.rotate(
+                                np.asarray(img_info["wired_table_img"]),
+                                cv2.ROTATE_90_CLOCKWISE,
+                            )
                         elif label == "90":
                             img_info["table_img"] = cv2.rotate(
                                 np.asarray(img_info["table_img"]),
                                 cv2.ROTATE_90_COUNTERCLOCKWISE,
                             )
+                            img_info["wired_table_img"] = cv2.rotate(
+                                np.asarray(img_info["wired_table_img"]),
+                                cv2.ROTATE_90_COUNTERCLOCKWISE,
+                            )
                         else:
                             # 180度和0度不做处理
                             pass

+ 1 - 7
mineru/model/table/cls/paddle_table_cls.py

@@ -72,9 +72,6 @@ class PaddleTableClsModel:
         result = self.sess.run(None, {"x": x})
         idx = np.argmax(result)
         conf = float(np.max(result))
-        # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
-        if idx == 0 and conf < 0.8:
-            idx = 1
         return self.labels[idx], conf
 
     def list_2_batch(self, img_list, batch_size=16):
@@ -134,7 +131,7 @@ class PaddleTableClsModel:
         x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
         return x
     def batch_predict(self, img_info_list, batch_size=16):
-        imgs = [item["table_img"] for item in img_info_list]
+        imgs = [item["wired_table_img"] for item in img_info_list]
         imgs = self.list_2_batch(imgs, batch_size=batch_size)
         label_res = []
         with tqdm(total=len(img_info_list), desc="Table-wired/wireless cls predict", disable=True) as pbar:
@@ -144,9 +141,6 @@ class PaddleTableClsModel:
                 for img_res in result[0]:
                     idx = np.argmax(img_res)
                     conf = float(np.max(img_res))
-                    # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
-                    # if idx == 0 and conf < 0.8:
-                    #     idx = 1
                     label_res.append((self.labels[idx],conf))
                 pbar.update(len(img_batch))
             for img_info, (label, conf) in zip(img_info_list, label_res):

+ 17 - 3
mineru/model/table/rec/unet_table/main.py

@@ -184,7 +184,7 @@ class WiredTableRecognition:
             # 从img中截取对应的区域
             x1, y1, x2, y2 = int(box[0][0])+1, int(box[0][1])+1, int(box[2][0])-1, int(box[2][1])-1
             if x1 >= x2 or y1 >= y2:
-                logger.warning(f"Invalid box coordinates: {box}")
+                # logger.warning(f"Invalid box coordinates: {x1, y1, x2, y2}")
                 continue
             # 判断长宽比
             if (x2 - x1) / (y2 - y1) > 20 or (y2 - y1) / (x2 - x1) > 20:
@@ -308,10 +308,24 @@ class UnetTableModel:
             wired_blank_count = sum(1 for cell in wired_soup.find_all(['td', 'th']) if not cell.text.strip())
             # logger.debug(f"wireless table blank cell count: {wireless_blank_count}, wired table blank cell count: {wired_blank_count}")
 
+            # 计算非空单元格数量
+            wireless_non_blank_count = wireless_len - wireless_blank_count
+            wired_non_blank_count = wired_len - wired_blank_count
+            # 无线表非空格数量大于有线表非空格数量时,才考虑切换
+            switch_flag = False
+            if wireless_non_blank_count > wired_non_blank_count:
+                # 假设非空表格是接近正方表,使用非空单元格数量开平方作为表格规模的估计
+                wired_table_scale = round(wired_non_blank_count ** 0.5)
+                # logger.debug(f"wireless non-blank cell count: {wireless_non_blank_count}, wired non-blank cell count: {wired_non_blank_count}, wired table scale: {wired_table_scale}")
+                # 如果无线表非空格的数量比有线表多一列或以上,需要切换到无线表
+                wired_scale_plus_2_cols = wired_non_blank_count + (wired_table_scale * 2)
+                wired_scale_squared_plus_2_rows = wired_table_scale * (wired_table_scale + 2)
+                if (wireless_non_blank_count + 3) >= max(wired_scale_plus_2_cols, wired_scale_squared_plus_2_rows):
+                    switch_flag = True
+
             # 判断是否使用无线表格模型的结果
             if (
-                (int(wireless_len * 0.04) <= wired_len <= int(wireless_len * 0.62)+1 and wireless_blank_count <= wired_blank_count+50)
-                or int(wireless_len * 0.04) <= wired_len <= int(wireless_len * 0.55)+1 # 有线模型检测到的单元格数太少(低于无线模型的55%)
+                switch_flag
                 or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # 两者相差不大但有线模型结果较少
                 or (gap_of_len == 0 and wired_len <= 4)  # 单元格数量完全相等且总量小于等于4
                 or (wired_text_count <= wireless_text_count * 0.6 and  wireless_text_count >=10) # 有线模型填入的文字明显少于无线模型