소스 검색

fix: enhance wired table prediction logic and improve classification criteria

myhloli 2 달 전
부모
커밋
06a158e56b
3개의 변경된 파일25개의 추가작업 그리고 12개의 파일을 삭제
  1. 7 1
      mineru/backend/pipeline/batch_analyze.py
  2. 1 7
      mineru/model/table/cls/paddle_table_cls.py
  3. 17 4
      mineru/model/table/rec/unet_table/main.py

+ 7 - 1
mineru/backend/pipeline/batch_analyze.py

@@ -194,8 +194,14 @@ class BatchAnalyze:
             # 单独拿出有线表格进行预测
             wired_table_res_list = []
             for table_res_dict in table_res_list_all_page:
-                if table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable:
+                # logger.debug(f"Table classification result: {table_res_dict["table_res"]["cls_label"]} with confidence {table_res_dict["table_res"]["cls_score"]}")
+                if (
+                    (table_res_dict["table_res"]["cls_label"] == AtomicModel.WirelessTable and table_res_dict["table_res"]["cls_score"] < 0.9)
+                    or table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable
+                ):
                     wired_table_res_list.append(table_res_dict)
+                del table_res_dict["table_res"]["cls_label"]
+                del table_res_dict["table_res"]["cls_score"]
             if wired_table_res_list:
                 for table_res_dict in tqdm(
                         wired_table_res_list, desc="Table-wired Predict"

+ 1 - 7
mineru/model/table/cls/paddle_table_cls.py

@@ -72,9 +72,6 @@ class PaddleTableClsModel:
         result = self.sess.run(None, {"x": x})
         idx = np.argmax(result)
         conf = float(np.max(result))
-        # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
-        if idx == 0 and conf < 0.8:
-            idx = 1
         return self.labels[idx], conf
 
     def list_2_batch(self, img_list, batch_size=16):
@@ -134,7 +131,7 @@ class PaddleTableClsModel:
         x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
         return x
     def batch_predict(self, img_info_list, batch_size=16):
-        imgs = [item["table_img"] for item in img_info_list]
+        imgs = [item["wired_table_img"] for item in img_info_list]
         imgs = self.list_2_batch(imgs, batch_size=batch_size)
         label_res = []
         with tqdm(total=len(img_info_list), desc="Table-wired/wireless cls predict", disable=True) as pbar:
@@ -144,9 +141,6 @@ class PaddleTableClsModel:
                 for img_res in result[0]:
                     idx = np.argmax(img_res)
                     conf = float(np.max(img_res))
-                    # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
-                    # if idx == 0 and conf < 0.8:
-                    #     idx = 1
                     label_res.append((self.labels[idx],conf))
                 pbar.update(len(img_batch))
             for img_info, (label, conf) in zip(img_info_list, label_res):

+ 17 - 4
mineru/model/table/rec/unet_table/main.py

@@ -184,7 +184,7 @@ class WiredTableRecognition:
             # 从img中截取对应的区域
             x1, y1, x2, y2 = int(box[0][0])+1, int(box[0][1])+1, int(box[2][0])-1, int(box[2][1])-1
             if x1 >= x2 or y1 >= y2:
-                logger.warning(f"Invalid box coordinates: {box}")
+                # logger.warning(f"Invalid box coordinates: {x1, y1, x2, y2}")
                 continue
             # 判断长宽比
             if (x2 - x1) / (y2 - y1) > 20 or (y2 - y1) / (x2 - x1) > 20:
@@ -308,11 +308,24 @@ class UnetTableModel:
             wired_blank_count = sum(1 for cell in wired_soup.find_all(['td', 'th']) if not cell.text.strip())
             # logger.debug(f"wireless table blank cell count: {wireless_blank_count}, wired table blank cell count: {wired_blank_count}")
 
+            # 计算非空单元格数量
+            wireless_non_blank_count = wireless_len - wireless_blank_count
+            wired_non_blank_count = wired_len - wired_blank_count
+            # 无线表非空格数量大于有线表非空格数量时,才考虑切换
+            switch_flag = False
+            if wireless_non_blank_count > wired_non_blank_count:
+                # 假设非空表格是接近正方表,使用非空单元格数量开平方作为表格规模的估计
+                wired_table_scale = round(wired_non_blank_count ** 0.5)
+                # logger.debug(f"wireless non-blank cell count: {wireless_non_blank_count}, wired non-blank cell count: {wired_non_blank_count}, wired table scale: {wired_table_scale}")
+                # 如果无线表非空格的数量比有线表多一列或以上,需要切换到无线表
+                wired_scale_plus_2_cols = wired_non_blank_count + (wired_table_scale * 2)
+                wired_scale_squared_plus_2_rows = wired_table_scale * (wired_table_scale + 2)
+                if (wireless_non_blank_count + 3) >= max(wired_scale_plus_2_cols, wired_scale_squared_plus_2_rows):
+                    switch_flag = True
+
             # 判断是否使用无线表格模型的结果
             if (
-                # (int(wireless_len * 0.04) <= wired_len <= int(wireless_len * 0.62)+1 and wireless_blank_count <= wired_blank_count+50)
-                # or int(wireless_len * 0.04) <= wired_len <= int(wireless_len * 0.55)+1 # 有线模型检测到的单元格数太少(低于无线模型的55%)
-                (int(wireless_len * 0.04) <= (wired_len-wired_blank_count) <= int((wireless_len-wireless_blank_count) * 0.76) and wired_len <= int(wireless_len * 0.5)) # 非空表数量有线表明显少于无线表模型60%
+                switch_flag
                 or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # 两者相差不大但有线模型结果较少
                 or (gap_of_len == 0 and wired_len <= 4)  # 单元格数量完全相等且总量小于等于4
                 or (wired_text_count <= wireless_text_count * 0.6 and  wireless_text_count >=10) # 有线模型填入的文字明显少于无线模型