Browse files

Merge branch 'opendatalab:dev' into dev

Sidney233 2 months ago
parent
commit
41017331c6

+ 2 - 2
README.md

@@ -37,7 +37,7 @@
 <!-- join us -->
 
 <p align="center">
-    👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="http://mineru.space/s/V85Yl" target="_blank">WeChat</a>
+    👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="https://mineru.net/community-portal/?aliasId=3c430f94" target="_blank">WeChat</a>
 </p>
 
 </div>
@@ -576,7 +576,7 @@ You can use MinerU for PDF parsing through various methods such as command line,
 
 - If you encounter any issues during usage, you can first check the [FAQ](https://opendatalab.github.io/MinerU/faq/) for solutions.  
 - If your issue remains unresolved, you may also use [DeepWiki](https://deepwiki.com/opendatalab/MinerU) to interact with an AI assistant, which can address most common problems.  
-- If you still cannot resolve the issue, you are welcome to join our community via [Discord](https://discord.gg/Tdedn9GTXq) or [WeChat](http://mineru.space/s/V85Yl) to discuss with other users and developers.
+- If you still cannot resolve the issue, you are welcome to join our community via [Discord](https://discord.gg/Tdedn9GTXq) or [WeChat](https://mineru.net/community-portal/?aliasId=3c430f94) to discuss with other users and developers.
 
 # All Thanks To Our Contributors
 

+ 2 - 2
README_zh-CN.md

@@ -37,7 +37,7 @@
 <!-- join us -->
 
 <p align="center">
-    👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="http://mineru.space/s/V85Yl" target="_blank">WeChat</a>
+    👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="https://mineru.net/community-portal/?aliasId=3c430f94" target="_blank">WeChat</a>
 </p>
 
 </div>
@@ -564,7 +564,7 @@ mineru -p <input_path> -o <output_path>
  
 - 如果您在使用过程中遇到问题,可以先查看[常见问题](https://opendatalab.github.io/MinerU/zh/faq/)是否有解答。  
 - 如果未能解决您的问题,您也可以使用[DeepWiki](https://deepwiki.com/opendatalab/MinerU)与AI助手交流,这可以解决大部分常见问题。  
-- 如果您仍然无法解决问题,您可通过[Discord](https://discord.gg/Tdedn9GTXq)或[WeChat](http://mineru.space/s/V85Yl)加入社区,与其他用户和开发者交流。
+- 如果您仍然无法解决问题,您可通过[Discord](https://discord.gg/Tdedn9GTXq)或[WeChat](https://mineru.net/community-portal/?aliasId=3c430f94)加入社区,与其他用户和开发者交流。
 
 # All Thanks To Our Contributors
 

+ 1 - 1
docs/en/faq/index.md

@@ -2,7 +2,7 @@
 
 If your question is not listed, try using [DeepWiki](https://deepwiki.com/opendatalab/MinerU)'s AI assistant for common issues.
 
-For unresolved problems, join our [Discord](https://discord.gg/Tdedn9GTXq) or [WeChat](http://mineru.space/s/V85Yl) community for support.
+For unresolved problems, join our [Discord](https://discord.gg/Tdedn9GTXq) or [WeChat](https://mineru.net/community-portal/?aliasId=3c430f94) community for support.
 
 ??? question "Encountered the error `ImportError: libGL.so.1: cannot open shared object file: No such file or directory` in Ubuntu 22.04 on WSL2"
 

+ 1 - 1
docs/en/index.md

@@ -34,7 +34,7 @@
 <!-- join us -->
 
 <p align="center">
-    👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="http://mineru.space/s/V85Yl" target="_blank">WeChat</a>
+    👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="https://mineru.net/community-portal/?aliasId=3c430f94" target="_blank">WeChat</a>
 </p>
 </div>
 

+ 1 - 1
docs/zh/faq/index.md

@@ -2,7 +2,7 @@
 
 如果未能列出您的问题,您也可以使用[DeepWiki](https://deepwiki.com/opendatalab/MinerU)与AI助手交流,这可以解决大部分常见问题。
 
-如果您仍然无法解决问题,您可通过[Discord](https://discord.gg/Tdedn9GTXq)或[WeChat](http://mineru.space/s/V85Yl)加入社区,与其他用户和开发者交流。
+如果您仍然无法解决问题,您可通过[Discord](https://discord.gg/Tdedn9GTXq)或[WeChat](https://mineru.net/community-portal/?aliasId=3c430f94)加入社区,与其他用户和开发者交流。
 
 ??? question "在WSL2的Ubuntu22.04中遇到报错`ImportError: libGL.so.1: cannot open shared object file: No such file or directory`"
 

+ 1 - 1
docs/zh/index.md

@@ -33,7 +33,7 @@
 <!-- join us -->
 
 <p align="center">
-    👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="http://mineru.space/s/V85Yl" target="_blank">WeChat</a>
+    👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="https://mineru.net/community-portal/?aliasId=3c430f94" target="_blank">WeChat</a>
 </p>
 </div>
 

+ 34 - 24
mineru/backend/pipeline/batch_analyze.py

@@ -9,7 +9,7 @@ import numpy as np
 from .model_init import AtomModelSingleton
 from .model_list import AtomicModel
 from ...utils.config_reader import get_formula_enable, get_table_enable
-from ...utils.model_utils import crop_img, get_res_list_from_layout_res
+from ...utils.model_utils import crop_img, get_res_list_from_layout_res, clean_vram
 from ...utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
 from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence, get_rotate_crop_image
 from ...utils.pdf_image_tools import get_crop_np_img
@@ -71,7 +71,7 @@ class BatchAnalyze:
                 mfr_count += len(images_formula_list[image_index])
 
         # clean up VRAM
-        # clean_vram(self.model.device, vram_threshold=8)
+        clean_vram(self.model.device, vram_threshold=8)
 
         ocr_res_list_all_page = []
         table_res_list_all_page = []
@@ -93,18 +93,19 @@ class BatchAnalyze:
                                           })
 
             for table_res in table_res_list:
-                # table_img, _ = crop_img(table_res, pil_img)
-                # bbox = (241, 208, 1475, 2019)
-                scale = 10/3
-                # scale = 1
-                crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
-                crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
-                bbox = (int(crop_xmin/scale), int(crop_ymin/scale), int(crop_xmax/scale), int(crop_ymax/scale))
-                table_img = get_crop_np_img(bbox, np_img, scale=scale)
+                def get_crop_table_img(scale):
+                    crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
+                    crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
+                    bbox = (int(crop_xmin / scale), int(crop_ymin / scale), int(crop_xmax / scale), int(crop_ymax / scale))
+                    return get_crop_np_img(bbox, np_img, scale=scale)
+
+                wireless_table_img = get_crop_table_img(scale = 1)
+                wired_table_img = get_crop_table_img(scale = 10/3)
 
                 table_res_list_all_page.append({'table_res':table_res,
                                                 'lang':_lang,
-                                                'table_img':table_img,
+                                                'table_img':wireless_table_img,
+                                                'wired_table_img':wired_table_img,
                                               })
 
         # table recognition
@@ -137,18 +138,17 @@ class BatchAnalyze:
 
             # OCR det stage, executed sequentially
             rec_img_lang_group = defaultdict(list)
+            det_ocr_engine = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.OCR,
+                det_db_box_thresh=0.5,
+                det_db_unclip_ratio=1.6,
+                enable_merge_det_boxes=False,
+            )
             for index, table_res_dict in enumerate(
                     tqdm(table_res_list_all_page, desc="Table-ocr det")
             ):
-                ocr_engine = atom_model_manager.get_atom_model(
-                    atom_model_name=AtomicModel.OCR,
-                    det_db_box_thresh=0.5,
-                    det_db_unclip_ratio=1.6,
-                    # lang= table_res_dict["lang"],
-                    enable_merge_det_boxes=False,
-                )
                 bgr_image = cv2.cvtColor(table_res_dict["table_img"], cv2.COLOR_RGB2BGR)
-                ocr_result = ocr_engine.ocr(bgr_image, rec=False)[0]
+                ocr_result = det_ocr_engine.ocr(bgr_image, rec=False)[0]
                 # build the dicts of images needing OCR (cropped_img, dt_box, table_id), grouped by language
                 for dt_box in ocr_result:
                     rec_img_lang_group[_lang].append(
@@ -171,8 +171,7 @@ class BatchAnalyze:
                     enable_merge_det_boxes=False,
                 )
                 cropped_img_list = [item["cropped_img"] for item in rec_img_list]
-                ocr_res_list = \
-                ocr_engine.ocr(cropped_img_list, det=False, tqdm_enable=True, tqdm_desc="Table-ocr rec")[0]
+                ocr_res_list = ocr_engine.ocr(cropped_img_list, det=False, tqdm_enable=True, tqdm_desc=f"Table-ocr rec {_lang}")[0]
                 # backfill the recognition results by table_id
                 for img_dict, ocr_res in zip(rec_img_list, ocr_res_list):
                     if table_res_list_all_page[img_dict["table_id"]].get("ocr_result"):
@@ -184,6 +183,8 @@ class BatchAnalyze:
                             [img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
                         ]
 
+            clean_vram(self.model.device, vram_threshold=8)
+
             # run the wireless table model on all tables first, then run the wired model on tables classified as wired
             wireless_table_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.WirelessTable,
@@ -193,18 +194,27 @@ class BatchAnalyze:
             # predict the wired tables separately
             wired_table_res_list = []
             for table_res_dict in table_res_list_all_page:
-                if table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable:
+                # logger.debug(f"Table classification result: {table_res_dict["table_res"]["cls_label"]} with confidence {table_res_dict["table_res"]["cls_score"]}")
+                if (
+                    (table_res_dict["table_res"]["cls_label"] == AtomicModel.WirelessTable and table_res_dict["table_res"]["cls_score"] < 0.9)
+                    or table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable
+                ):
                     wired_table_res_list.append(table_res_dict)
+                del table_res_dict["table_res"]["cls_label"]
+                del table_res_dict["table_res"]["cls_score"]
             if wired_table_res_list:
                 for table_res_dict in tqdm(
                         wired_table_res_list, desc="Table-wired Predict"
                 ):
+                    if not table_res_dict.get("ocr_result", None):
+                        continue
+
                     wired_table_model = atom_model_manager.get_atom_model(
                         atom_model_name=AtomicModel.WiredTable,
                         lang=table_res_dict["lang"],
                     )
                     table_res_dict["table_res"]["html"] = wired_table_model.predict(
-                        table_res_dict["table_img"],
+                        table_res_dict["wired_table_img"],
                         table_res_dict["ocr_result"],
                         table_res_dict["table_res"].get("html", None)
                     )
@@ -428,7 +438,7 @@ class BatchAnalyze:
                                                layout_res_item['poly'][4], layout_res_item['poly'][5]]
                             layout_res_width = layout_res_bbox[2] - layout_res_bbox[0]
                             layout_res_height = layout_res_bbox[3] - layout_res_bbox[1]
-                            if ocr_text in ['(204号', '(20', '(2', '(2号', '(20号'] and ocr_score < 0.8 and layout_res_width < layout_res_height:
+                            if ocr_text in ['(204号', '(20', '(2', '(2号', '(20号', '号', '(204'] and ocr_score < 0.8 and layout_res_width < layout_res_height:
                                 layout_res_item['category_id'] = 16
 
                     total_processed += len(img_crop_list)
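The hunk above factors the duplicated crop logic into a local `get_crop_table_img(scale)` helper and calls it twice, once per table model (scale 1 for the wireless model, 10/3 for the wired model). A self-contained sketch of the pattern, with plain slicing standing in for `get_crop_np_img` and the `poly` coordinate convention assumed:

```python
import numpy as np

np_img = np.zeros((200, 300, 3), dtype=np.uint8)  # stand-in page image
poly = [30, 60, 150, 60, 150, 120, 30, 120]       # assumed [x0,y0, x1,y0, x1,y1, x0,y1]

def get_crop_table_img(scale):
    # the detection polygon is divided by `scale` to express the bbox
    # in the coordinate system the crop is taken at
    xmin, ymin = int(poly[0]), int(poly[1])
    xmax, ymax = int(poly[4]), int(poly[5])
    bbox = (int(xmin / scale), int(ymin / scale), int(xmax / scale), int(ymax / scale))
    x0, y0, x1, y1 = bbox
    return np_img[y0:y1, x0:x1]  # stand-in for get_crop_np_img(bbox, np_img, scale=scale)

wireless_table_img = get_crop_table_img(scale=1)     # full-size region
wired_table_img = get_crop_table_img(scale=10 / 3)   # region in rescaled coordinates
```

Both crops come from one detection, so the same table can be fed to each model at its preferred resolution without re-running layout detection.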

+ 1 - 0
mineru/backend/pipeline/model_init.py

@@ -10,6 +10,7 @@ from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
 from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
+# from ...model.table.rec.RapidTable import RapidTableModel
 from ...model.table.rec.slanet_plus.main import RapidTableModel
 from ...model.table.rec.unet_table.main import UnetTableModel
 from ...utils.enum_class import ModelPath

+ 4 - 2
mineru/backend/vlm/hf_predictor.py

@@ -137,12 +137,14 @@ class HuggingfacePredictor(BasePredictor):
         image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
         image_sizes = [[*image_obj.size]]
 
-        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
-        input_ids = input_ids.to(device=self.model.device)
+        encoded_inputs = self.tokenizer(prompt, return_tensors="pt")
+        input_ids = encoded_inputs.input_ids.to(device=self.model.device)
+        attention_mask = encoded_inputs.attention_mask.to(device=self.model.device)
 
         with torch.inference_mode():
             output_ids = self.model.generate(
                 input_ids,
+                attention_mask=attention_mask,
                 images=image_tensor,
                 image_sizes=image_sizes,
                 use_cache=True,
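The fix above passes the tokenizer's `attention_mask` to `generate` alongside `input_ids`, so padded positions are explicitly masked instead of being inferred. A toy illustration of what that mask encodes, with a hypothetical pad token id of 0:

```python
import numpy as np

PAD_ID = 0  # hypothetical pad token id

def make_attention_mask(input_ids):
    """1 for real tokens, 0 for padding - the same-shaped tensor that
    HuggingFace tokenizers return alongside input_ids."""
    ids = np.asarray(input_ids)
    return (ids != PAD_ID).astype(np.int64)

batch = [[101, 7592, 2088, 102, 0, 0],   # padded to length 6
         [101, 7592, 102, 0, 0, 0]]
mask = make_attention_mask(batch)
print(mask.tolist())  # [[1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0]]
```

Without the mask, the model would have to guess which trailing tokens are padding, which can degrade generations for batched or left-padded inputs.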

+ 16 - 2
mineru/model/ori_cls/paddle_ori_cls.py

@@ -174,6 +174,12 @@ class PaddleOrientationClsModel:
     def batch_predict(
         self, imgs: List[Dict], det_batch_size: int, batch_size: int = 16
     ) -> None:
+
+        import torch
+        from packaging import version
+        if version.parse(torch.__version__) >= version.parse("2.8.0"):
+            return None
+
         """
         Batch-predict the orientation of the images in the input info list, and rotate any rotated images back upright
         """
@@ -195,7 +201,7 @@ class PaddleOrientationClsModel:
 
         # batch-process each resolution group
         rotated_imgs = []
-        for group_key, group_imgs in tqdm(resolution_groups.items(), desc="Table-ori cls stage1 predict"):
+        for group_key, group_imgs in tqdm(resolution_groups.items(), desc="Table-ori cls stage1 predict", disable=True):
             # compute the target size (group max, rounded up to a multiple of RESOLUTION_GROUP_STRIDE)
             max_h = max(img["table_img_bgr"].shape[0] for img in group_imgs)
             max_w = max(img["table_img_bgr"].shape[1] for img in group_imgs)
@@ -243,7 +249,7 @@ class PaddleOrientationClsModel:
         # predict the rotation angle of the rotated images
         if len(rotated_imgs) > 0:
             imgs = self.list_2_batch(rotated_imgs, batch_size=batch_size)
-            with tqdm(total=len(rotated_imgs), desc="Table-ori cls stage2 predict") as pbar:
+            with tqdm(total=len(rotated_imgs), desc="Table-ori cls stage2 predict", disable=True) as pbar:
                 for img_batch in imgs:
                     x = self.batch_preprocess(img_batch)
                     results = self.sess.run(None, {"x": x})
@@ -254,11 +260,19 @@ class PaddleOrientationClsModel:
                                 np.asarray(img_info["table_img"]),
                                 cv2.ROTATE_90_CLOCKWISE,
                             )
+                            img_info["wired_table_img"] = cv2.rotate(
+                                np.asarray(img_info["wired_table_img"]),
+                                cv2.ROTATE_90_CLOCKWISE,
+                            )
                         elif label == "90":
                             img_info["table_img"] = cv2.rotate(
                                 np.asarray(img_info["table_img"]),
                                 cv2.ROTATE_90_COUNTERCLOCKWISE,
                             )
+                            img_info["wired_table_img"] = cv2.rotate(
+                                np.asarray(img_info["wired_table_img"]),
+                                cv2.ROTATE_90_COUNTERCLOCKWISE,
+                            )
                         else:
                             # 180° and 0° need no handling
                             pass
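The guard added at the top of `batch_predict` above skips orientation classification on torch >= 2.8.0 using `packaging.version`. A dependency-free sketch of an equivalent check (a plain tuple comparison; `packaging` additionally handles pre-release tags, so this is only an approximation):

```python
def version_tuple(v: str) -> tuple:
    """Keep only the leading numeric release segment, e.g. '2.8.0+cu121' -> (2, 8, 0)."""
    core = v.split("+")[0]
    parts = []
    for p in core.split("."):
        if p.isdigit():
            parts.append(int(p))
        else:
            break
    return tuple(parts)

def should_skip(torch_version: str) -> bool:
    # mirrors: version.parse(torch.__version__) >= version.parse("2.8.0")
    return version_tuple(torch_version) >= (2, 8, 0)

print(should_skip("2.7.1"))        # False
print(should_skip("2.8.0+cu121"))  # True
```

Tuple comparison is lexicographic, so `(2, 10, 0) >= (2, 8, 0)` correctly evaluates to `True`, which naive string comparison would get wrong.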

+ 2 - 9
mineru/model/table/cls/paddle_table_cls.py

@@ -1,5 +1,4 @@
 import os
-from pathlib import Path
 
 from PIL import Image
 import cv2
@@ -73,9 +72,6 @@ class PaddleTableClsModel:
         result = self.sess.run(None, {"x": x})
         idx = np.argmax(result)
         conf = float(np.max(result))
-        # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
-        if idx == 0 and conf < 0.8:
-            idx = 1
         return self.labels[idx], conf
 
     def list_2_batch(self, img_list, batch_size=16):
@@ -135,19 +131,16 @@ class PaddleTableClsModel:
         x = np.stack(res_imgs, axis=0).astype(dtype=np.float32, copy=False)
         return x
     def batch_predict(self, img_info_list, batch_size=16):
-        imgs = [item["table_img"] for item in img_info_list]
+        imgs = [item["wired_table_img"] for item in img_info_list]
         imgs = self.list_2_batch(imgs, batch_size=batch_size)
         label_res = []
-        with tqdm(total=len(img_info_list), desc="Table-wired/wireless cls predict") as pbar:
+        with tqdm(total=len(img_info_list), desc="Table-wired/wireless cls predict", disable=True) as pbar:
             for img_batch in imgs:
                 x = self.batch_preprocess(img_batch)
                 result = self.sess.run(None, {"x": x})
                 for img_res in result[0]:
                     idx = np.argmax(img_res)
                     conf = float(np.max(img_res))
-                    # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
-                    if idx == 0 and conf < 0.8:
-                        idx = 1
                     label_res.append((self.labels[idx],conf))
                 pbar.update(len(img_batch))
             for img_info, (label, conf) in zip(img_info_list, label_res):
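With the low-confidence override removed above, the classifier now returns the raw argmax label plus its confidence, and the wired/wireless routing decision (including the 0.9 threshold) lives in the caller in `batch_analyze.py`. A sketch of that split, with assumed label order and made-up scores:

```python
import numpy as np

LABELS = ["wired_table", "wireless_table"]  # assumed label order

def classify(logits):
    """Mimics batch_predict: argmax label with its confidence, no thresholding."""
    out = []
    for row in np.asarray(logits):
        idx = int(np.argmax(row))
        out.append((LABELS[idx], float(np.max(row))))
    return out

def needs_wired_model(label, score):
    # caller-side rule from batch_analyze.py: wired, or wireless but low-confidence
    return label == "wired_table" or (label == "wireless_table" and score < 0.9)

preds = classify([[0.2, 0.8], [0.95, 0.05], [0.05, 0.95]])
print([needs_wired_model(lbl, s) for lbl, s in preds])  # [True, True, False]
```

Keeping the threshold in one place avoids the classifier and the pipeline silently applying two different cutoffs.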

+ 154 - 0
mineru/model/table/rec/RapidTable.py

@@ -0,0 +1,154 @@
+import html
+import os
+import time
+from pathlib import Path
+from typing import List
+
+import cv2
+import numpy as np
+from loguru import logger
+from rapid_table import ModelType, RapidTable, RapidTableInput
+from rapid_table.utils import RapidTableOutput
+from tqdm import tqdm
+
+from mineru.model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
+
+
+def escape_html(input_string):
+    """Escape HTML Entities."""
+    return html.escape(input_string)
+
+
+class CustomRapidTable(RapidTable):
+    def __init__(self, cfg: RapidTableInput):
+        import logging
+        # suppress INFO-level logging noise from the library
+        logging.disable(logging.INFO)
+        super().__init__(cfg)
+    def __call__(self, img_contents, ocr_results=None, batch_size=1):
+        if not isinstance(img_contents, list):
+            img_contents = [img_contents]
+
+        s = time.perf_counter()
+
+        results = RapidTableOutput()
+
+        total_nums = len(img_contents)
+
+        with tqdm(total=total_nums, desc="Table-wireless Predict") as pbar:
+            for start_i in range(0, total_nums, batch_size):
+                end_i = min(total_nums, start_i + batch_size)
+
+                imgs = self._load_imgs(img_contents[start_i:end_i])
+
+                pred_structures, cell_bboxes = self.table_structure(imgs)
+                logic_points = self.table_matcher.decode_logic_points(pred_structures)
+
+                dt_boxes, rec_res = self.get_ocr_results(imgs, start_i, end_i, ocr_results)
+                pred_htmls = self.table_matcher(
+                    pred_structures, cell_bboxes, dt_boxes, rec_res
+                )
+
+                results.pred_htmls.extend(pred_htmls)
+                # update the progress bar
+                pbar.update(end_i - start_i)
+
+        elapse = time.perf_counter() - s
+        results.elapse = elapse / total_nums
+        return results
+
+
+class RapidTableModel():
+    def __init__(self, ocr_engine):
+        slanet_plus_model_path = os.path.join(
+            auto_download_and_get_model_root_path(ModelPath.slanet_plus),
+            ModelPath.slanet_plus,
+        )
+        input_args = RapidTableInput(
+            model_type=ModelType.SLANETPLUS,
+            model_dir_or_path=slanet_plus_model_path,
+            use_ocr=False
+        )
+        self.table_model = CustomRapidTable(input_args)
+        self.ocr_engine = ocr_engine
+
+    def predict(self, image, ocr_result=None):
+        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+        # Continue with OCR on potentially rotated image
+
+        if not ocr_result:
+            raw_ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+            # split out bounding boxes, texts, and confidence scores
+            boxes = []
+            texts = []
+            scores = []
+            for item in raw_ocr_result:
+                if len(item) == 3:
+                    boxes.append(item[0])
+                    texts.append(escape_html(item[1]))
+                    scores.append(item[2])
+                elif len(item) == 2 and isinstance(item[1], tuple):
+                    boxes.append(item[0])
+                    texts.append(escape_html(item[1][0]))
+                    scores.append(item[1][1])
+            # build ocr_result in the format rapid_table expects
+            ocr_result = [(boxes, texts, scores)]
+
+        if ocr_result:
+            try:
+                table_results = self.table_model(img_contents=np.asarray(image), ocr_results=ocr_result)
+                html_code = table_results.pred_htmls
+                table_cell_bboxes = table_results.cell_bboxes
+                logic_points = table_results.logic_points
+                elapse = table_results.elapse
+                return html_code, table_cell_bboxes, logic_points, elapse
+            except Exception as e:
+                logger.exception(e)
+
+        return None, None, None, None
+
+    def batch_predict(self, table_res_list: List[dict], batch_size: int = 4):
+        not_none_table_res_list = []
+        for table_res in table_res_list:
+            if table_res.get("ocr_result", None):
+                not_none_table_res_list.append(table_res)
+
+        if not_none_table_res_list:
+            img_contents = [table_res["table_img"] for table_res in not_none_table_res_list]
+            ocr_results = []
+            # ocr_results must be built in the format rapid_table expects
+            for table_res in not_none_table_res_list:
+                raw_ocr_result = table_res["ocr_result"]
+                boxes = []
+                texts = []
+                scores = []
+                for item in raw_ocr_result:
+                    if len(item) == 3:
+                        boxes.append(item[0])
+                        texts.append(escape_html(item[1]))
+                        scores.append(item[2])
+                    elif len(item) == 2 and isinstance(item[1], tuple):
+                        boxes.append(item[0])
+                        texts.append(escape_html(item[1][0]))
+                        scores.append(item[1][1])
+                ocr_results.append((boxes, texts, scores))
+            table_results = self.table_model(img_contents=img_contents, ocr_results=ocr_results, batch_size=batch_size)
+
+            for i, result in enumerate(table_results.pred_htmls):
+                if result:
+                    not_none_table_res_list[i]['table_res']['html'] = result
+
+if __name__ == '__main__':
+    ocr_engine= PytorchPaddleOCR(
+            det_db_box_thresh=0.5,
+            det_db_unclip_ratio=1.6,
+            enable_merge_det_boxes=False,
+    )
+    table_model = RapidTableModel(ocr_engine)
+    img_path = Path(r"D:\project\20240729ocrtest\pythonProject\images\601c939cc6dabaf07af763e2f935f54896d0251f37cc47beb7fc6b069353455d.jpg")
+    image = cv2.imread(str(img_path))
+    html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(image)
+    print(html_code)
+
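`predict` in the new `RapidTable.py` above normalizes two OCR result shapes, `[box, text, score]` and `[box, (text, score)]`, into the `(boxes, texts, scores)` triple that rapid_table expects. A self-contained sketch of that normalization (the sample boxes are made up):

```python
import html

def to_rapid_table_ocr(raw_ocr_result):
    """Split mixed-format OCR items into the [(boxes, texts, scores)] shape."""
    boxes, texts, scores = [], [], []
    for item in raw_ocr_result:
        if len(item) == 3:                                   # [box, text, score]
            boxes.append(item[0])
            texts.append(html.escape(item[1]))
            scores.append(item[2])
        elif len(item) == 2 and isinstance(item[1], tuple):  # [box, (text, score)]
            boxes.append(item[0])
            texts.append(html.escape(item[1][0]))
            scores.append(item[1][1])
    return [(boxes, texts, scores)]

raw = [
    [[0, 0, 10, 10], "a<b", 0.98],
    [[0, 12, 10, 22], ("x&y", 0.91)],
]
print(to_rapid_table_ocr(raw))
# [([[0, 0, 10, 10], [0, 12, 10, 22]], ['a&lt;b', 'x&amp;y'], [0.98, 0.91])]
```

Escaping the text here keeps raw `<` and `&` characters from corrupting the HTML table that the structure matcher emits later.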

+ 20 - 74
mineru/model/table/rec/slanet_plus/main.py

@@ -1,11 +1,8 @@
 import os
-import argparse
 import copy
-import importlib
 import time
 import html
 from dataclasses import asdict, dataclass
-from enum import Enum
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 
@@ -19,32 +16,10 @@ from .table_structure import TableStructurer
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
-root_dir = Path(__file__).resolve().parent
-
-
-class ModelType(Enum):
-    PPSTRUCTURE_EN = "ppstructure_en"
-    PPSTRUCTURE_ZH = "ppstructure_zh"
-    SLANETPLUS = "slanet_plus"
-    UNITABLE = "unitable"
-
-
-ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
-KEY_TO_MODEL_URL = {
-    ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx",
-    ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx",
-    ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx",
-    ModelType.UNITABLE.value: {
-        "encoder": f"{ROOT_URL}/unitable/encoder.pth",
-        "decoder": f"{ROOT_URL}/unitable/decoder.pth",
-        "vocab": f"{ROOT_URL}/unitable/vocab.json",
-    },
-}
-
 
 @dataclass
 class RapidTableInput:
-    model_type: Optional[str] = ModelType.SLANETPLUS.value
+    model_type: Optional[str] = "slanet_plus"
     model_path: Union[str, Path, None, Dict[str, str]] = None
     use_cuda: bool = False
     device: str = "cpu"
@@ -60,18 +35,7 @@ class RapidTableOutput:
 
 class RapidTable:
     def __init__(self, config: RapidTableInput):
-        self.model_type = config.model_type
-        if self.model_type not in KEY_TO_MODEL_URL:
-            model_list = ",".join(KEY_TO_MODEL_URL)
-            raise ValueError(
-                f"{self.model_type} is not supported. The currently supported models are {model_list}."
-            )
-
-        config.model_path = config.model_path
-        if self.model_type == ModelType.SLANETPLUS.value:
-            self.table_structure = TableStructurer(asdict(config))
-        else:
-            raise ValueError(f"{self.model_type} is not supported.")
+        self.table_structure = TableStructurer(asdict(config))
         self.table_matcher = TableMatch()
 
     def predict(
@@ -177,29 +141,6 @@ class RapidTable:
         return cell_bboxes
 
 
-def parse_args(arg_list: Optional[List[str]] = None):
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-v",
-        "--vis",
-        action="store_true",
-        default=False,
-        help="Wheter to visualize the layout results.",
-    )
-    parser.add_argument(
-        "-img", "--img_path", type=str, required=True, help="Path to image for layout."
-    )
-    parser.add_argument(
-        "-m",
-        "--model_type",
-        type=str,
-        default=ModelType.SLANETPLUS.value,
-        choices=list(KEY_TO_MODEL_URL),
-    )
-    args = parser.parse_args(arg_list)
-    return args
-
-
 def escape_html(input_string):
     """Escape HTML Entities."""
     return html.escape(input_string)
@@ -217,18 +158,17 @@ class RapidTableModel(object):
         self.table_model = RapidTable(input_args)
         self.ocr_engine = ocr_engine
 
-    def predict(self, image, table_cls_score):
+    def predict(self, image, ocr_result=None):
         bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
         # Continue with OCR on potentially rotated image
-        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
-        if ocr_result:
+
+        if not ocr_result:
+            ocr_result = self.ocr_engine.ocr(bgr_image)[0]
             ocr_result = [
                 [item[0], escape_html(item[1][0]), item[1][1]]
                 for item in ocr_result
                 if len(item) == 2 and isinstance(item[1], tuple)
             ]
-        else:
-            ocr_result = None
 
         if ocr_result:
             try:
@@ -245,22 +185,28 @@ class RapidTableModel(object):
 
     def batch_predict(self, table_res_list: List[Dict], batch_size: int = 4) -> None:
         """Batch-predict over the given list of dicts in place; no return value"""
-        with tqdm(total=len(table_res_list), desc="Table-wireless Predict") as pbar:
-            for index in range(0, len(table_res_list), batch_size):
+
+        not_none_table_res_list = []
+        for table_res in table_res_list:
+            if table_res.get("ocr_result", None):
+                not_none_table_res_list.append(table_res)
+
+        with tqdm(total=len(not_none_table_res_list), desc="Table-wireless Predict") as pbar:
+            for index in range(0, len(not_none_table_res_list), batch_size):
                 batch_imgs = [
-                    cv2.cvtColor(np.asarray(table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)
-                    for i in range(index, min(index + batch_size, len(table_res_list)))
+                    cv2.cvtColor(np.asarray(not_none_table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)
+                    for i in range(index, min(index + batch_size, len(not_none_table_res_list)))
                 ]
                 batch_ocrs = [
-                    table_res_list[i]["ocr_result"]
-                    for i in range(index, min(index + batch_size, len(table_res_list)))
+                    not_none_table_res_list[i]["ocr_result"]
+                    for i in range(index, min(index + batch_size, len(not_none_table_res_list)))
                 ]
                 results = self.table_model.batch_predict(
                     batch_imgs, batch_ocrs, batch_size=batch_size
                 )
                 for i, result in enumerate(results):
                     if result.pred_html:
-                        table_res_list[index + i]['table_res']['html'] = result.pred_html
+                        not_none_table_res_list[index + i]['table_res']['html'] = result.pred_html
 
                 # update the progress bar
-                pbar.update(len(results))
+                pbar.update(len(results))
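The reworked `batch_predict` above first drops entries without an `ocr_result`, then writes HTML back into the surviving dicts in place, so result indices align with the filtered list rather than the original one. A minimal sketch of that filter-then-backfill pattern (the `table_res`/`ocr_result`/`html` keys mirror the real dicts; the fake model just returns a fixed table):

```python
def batch_predict(table_res_list, predict_fn):
    # keep only entries that have OCR text to match against the structure
    kept = [t for t in table_res_list if t.get("ocr_result")]
    results = [predict_fn(t) for t in kept]  # stands in for the batched model call
    for t, html_code in zip(kept, results):
        if html_code:
            t["table_res"]["html"] = html_code

tables = [
    {"table_res": {}, "ocr_result": [["box", "cell", 0.9]]},
    {"table_res": {}, "ocr_result": None},  # skipped: nothing to fill in
]
batch_predict(tables, lambda t: "<table><tr><td>cell</td></tr></table>")
print([t["table_res"].get("html") for t in tables])
# ['<table><tr><td>cell</td></tr></table>', None]
```

Indexing into the filtered list (instead of the original, as the old code did) is what prevents HTML from being written onto the wrong table when some entries are skipped.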

+ 51 - 9
mineru/model/table/rec/unet_table/main.py

@@ -10,6 +10,7 @@ import numpy as np
 import cv2
 from PIL import Image
 from loguru import logger
+from bs4 import BeautifulSoup
 
 from .table_structure_unet import TSRUnet
 
@@ -64,7 +65,7 @@ class WiredTableRecognition:
         img = self.load_img(img)
         polygons, rotated_polygons = self.table_structure(img, **kwargs)
         if polygons is None:
-            logging.warning("polygons is None.")
+            # logging.warning("polygons is None.")
             return WiredTableOutput("", None, None, 0.0)
 
         try:
@@ -181,9 +182,9 @@ class WiredTableRecognition:
                 logger.warning(f"No OCR engine provided for box {i}: {box}")
                 continue
             # crop the corresponding region from img
-            x1, y1, x2, y2 = int(box[0][0]), int(box[0][1]), int(box[2][0]), int(box[2][1])
+            x1, y1, x2, y2 = int(box[0][0])+1, int(box[0][1])+1, int(box[2][0])-1, int(box[2][1])-1
             if x1 >= x2 or y1 >= y2:
-                logger.warning(f"Invalid box coordinates: {box}")
+                # logger.warning(f"Invalid box coordinates: {x1, y1, x2, y2}")
                 continue
             # check the aspect ratio
             if (x2 - x1) / (y2 - y1) > 20 or (y2 - y1) / (x2 - x1) > 20:
@@ -196,6 +197,14 @@ class WiredTableRecognition:
         if len(img_crop_list) > 0:
             # run OCR recognition
             ocr_result = self.ocr_engine.ocr(img_crop_list, det=False)
+            # ocr_result = [[]]
+            # for crop_img in img_crop_list:
+            #     tmp_ocr_result = self.ocr_engine.ocr(crop_img)
+            #     if tmp_ocr_result[0] and len(tmp_ocr_result[0]) > 0 and isinstance(tmp_ocr_result[0], list) and len(tmp_ocr_result[0][0]) == 2:
+            #         ocr_result[0].append(tmp_ocr_result[0][0][1])
+            #     else:
+            #         ocr_result[0].append(("", 0.0))
+
             if not ocr_result or not isinstance(ocr_result, list) or len(ocr_result) == 0:
                 logger.warning("OCR engine returned no results or invalid result for image crops.")
                 return cell_box_map
@@ -210,10 +219,10 @@ class WiredTableRecognition:
                 # Process the OCR result
                 ocr_text, ocr_score = ocr_res
                 # logger.debug(f"OCR result for box {i}: {ocr_text} with score {ocr_score}")
-                if ocr_score < 0.9 or ocr_text in ['1']:
+                if ocr_score < 0.6 or ocr_text in ['1','口','■','(204号', '(20', '(2', '(2号', '(20号', '号', '(204']:
                     # logger.warning(f"Low confidence OCR result for box {i}: {ocr_text} with score {ocr_score}")
                     box = sorted_polygons[i]
-                    cell_box_map[i] = [[box, "", 0.5]]
+                    cell_box_map[i] = [[box, "", 0.1]]
                     continue
                 cell_box_map[i] = [[box, ocr_text, ocr_score]]
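The hunk above lowers the confidence threshold from 0.9 to 0.6 and blacklists tokens the OCR engine frequently hallucinates in near-empty cells. The filtering decision can be sketched in isolation (function name is illustrative; threshold and token list mirror the diff):

```python
# Tokens the diff treats as OCR noise in near-empty table cells.
NOISE_TOKENS = {'1', '口', '■', '(204号', '(20', '(2', '(2号', '(20号', '号', '(204'}

def keep_ocr_result(text: str, score: float, threshold: float = 0.6) -> bool:
    """Return True if an OCR result should be written into the cell.

    Low-confidence results and known noise tokens are discarded; the caller
    then stores an empty string with a low placeholder score instead.
    """
    return score >= threshold and text not in NOISE_TOKENS
```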
 
@@ -275,18 +284,51 @@ class UnetTableModel:
             # )
 
             wired_html_code = wired_table_results.pred_html
-
             wired_len = count_table_cells_physical(wired_html_code)
             wireless_len = count_table_cells_physical(wireless_html_code)
-
-            # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
             # Compute the difference in cell counts detected by the two models
             gap_of_len = wireless_len - wired_len
+            # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
+
+            # Use the OCR results to count how many recognized texts each model filled in
+            wireless_text_count = 0
+            wired_text_count = 0
+            for ocr_res in ocr_result:
+                if ocr_res[1] in wireless_html_code:
+                    wireless_text_count += 1
+                if ocr_res[1] in wired_html_code:
+                    wired_text_count += 1
+            # logger.debug(f"wireless table ocr text count: {wireless_text_count}, wired table ocr text count: {wired_text_count}")
+
+            # Count empty cells with an HTML parser
+            wireless_soup = BeautifulSoup(wireless_html_code, 'html.parser') if wireless_html_code else BeautifulSoup("", 'html.parser')
+            wired_soup = BeautifulSoup(wired_html_code, 'html.parser') if wired_html_code else BeautifulSoup("", 'html.parser')
+            # Count empty cells (no text content, or whitespace only)
+            wireless_blank_count = sum(1 for cell in wireless_soup.find_all(['td', 'th']) if not cell.text.strip())
+            wired_blank_count = sum(1 for cell in wired_soup.find_all(['td', 'th']) if not cell.text.strip())
+            # logger.debug(f"wireless table blank cell count: {wireless_blank_count}, wired table blank cell count: {wired_blank_count}")
+
+            # Count non-empty cells
+            wireless_non_blank_count = wireless_len - wireless_blank_count
+            wired_non_blank_count = wired_len - wired_blank_count
+            # Only consider switching when the wireless table has more non-empty cells than the wired table
+            switch_flag = False
+            if wireless_non_blank_count > wired_non_blank_count:
+                # Assume the non-empty region is roughly square and use the square root of the non-empty cell count as an estimate of the table scale
+                wired_table_scale = round(wired_non_blank_count ** 0.5)
+                # logger.debug(f"wireless non-blank cell count: {wireless_non_blank_count}, wired non-blank cell count: {wired_non_blank_count}, wired table scale: {wired_table_scale}")
+                # Switch to the wireless table if it has at least one extra column's worth of non-empty cells
+                wired_scale_plus_2_cols = wired_non_blank_count + (wired_table_scale * 2)
+                wired_scale_squared_plus_2_rows = wired_table_scale * (wired_table_scale + 2)
+                if (wireless_non_blank_count + 3) >= max(wired_scale_plus_2_cols, wired_scale_squared_plus_2_rows):
+                    switch_flag = True
+
             # Decide whether to use the wireless table model's result
             if (
-                wired_len <= int(wireless_len * 0.55)+1  # the wired model detected too few cells (below ~55% of the wireless count)
+                switch_flag
                 or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # counts are close but the wired model found fewer cells
                 or (gap_of_len == 0 and wired_len <= 4)  # cell counts are exactly equal and the total is at most 4
+                or (wired_text_count <= wireless_text_count * 0.6 and wireless_text_count >= 10)  # the wired model filled in noticeably fewer texts than the wireless model
             ):
                 # logger.debug("fall back to wireless table model")
                 html_code = wireless_html_code
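The switch heuristic added in this file combines blank-cell counting (done with BeautifulSoup in the diff) with a square-table scale estimate. A self-contained sketch of that logic, using only the standard-library HTML parser in place of BeautifulSoup (function and class names are substitutions for illustration; the thresholds mirror the diff):

```python
from html.parser import HTMLParser

class BlankCellCounter(HTMLParser):
    """Count <td>/<th> cells whose text content is empty or whitespace-only."""
    def __init__(self):
        super().__init__()
        self.total = 0
        self.blank = 0
        self._in_cell = False
        self._buf = ""

    def handle_starttag(self, tag, attrs):
        if tag in ("td", "th"):
            self._in_cell = True
            self._buf = ""

    def handle_data(self, data):
        if self._in_cell:
            self._buf += data

    def handle_endtag(self, tag):
        if tag in ("td", "th") and self._in_cell:
            self._in_cell = False
            self.total += 1
            if not self._buf.strip():
                self.blank += 1

def count_cells(html: str):
    """Return (total cells, blank cells) for an HTML table string."""
    parser = BlankCellCounter()
    parser.feed(html or "")
    return parser.total, parser.blank

def should_switch_to_wireless(wired_non_blank: int, wireless_non_blank: int) -> bool:
    """Switch when the wireless model gains at least ~one extra row/column of content."""
    if wireless_non_blank <= wired_non_blank:
        return False
    # Treat the wired table as roughly square: scale ≈ sqrt(non-blank cell count).
    scale = round(wired_non_blank ** 0.5)
    plus_two_cols = wired_non_blank + scale * 2
    plus_two_rows = scale * (scale + 2)
    return wireless_non_blank + 3 >= max(plus_two_cols, plus_two_rows)
```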

+ 1 - 1
mineru/model/table/rec/unet_table/utils_table_line_rec.py

@@ -152,7 +152,7 @@ def calculate_center_rotate_angle(box):
     ) / 2
     # x = cx-w/2
     # y = cy-h/2
-    sinA = (h * (x1 - cx) - w * (y1 - cy)) * 1.0 / (h * h + w * w) * 2
+    sinA = (h * (x1 - cx) - w * (y1 - cy)) * 1.0 / (h * h + w * w + 1e-10) * 2
     angle = np.arcsin(sinA)
     return angle, w, h, cx, cy
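The one-line change above adds a small epsilon to the denominator so a degenerate (zero-area) box no longer triggers a division by zero. The same expression in isolation (argument names follow the surrounding code; the helper itself is illustrative):

```python
EPS = 1e-10  # guards against division by zero for degenerate (w == h == 0) boxes

def rotation_sin(x1, y1, cx, cy, w, h):
    """Same expression as the patched sinA line: stays finite for zero-area boxes."""
    return (h * (x1 - cx) - w * (y1 - cy)) * 1.0 / (h * h + w * w + EPS) * 2
```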
 

+ 4 - 1
mineru/utils/model_utils.py

@@ -1,3 +1,4 @@
+import os
 import time
 import gc
 from PIL import Image
@@ -427,11 +428,13 @@ def clean_memory(device='cuda'):
 
 def clean_vram(device, vram_threshold=8):
     total_memory = get_vram(device)
+    if total_memory is not None:
+        total_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(total_memory)))
     if total_memory and total_memory <= vram_threshold:
         gc_start = time.time()
         clean_memory(device)
         gc_time = round(time.time() - gc_start, 2)
-        logger.info(f"gc time: {gc_time}")
+        # logger.info(f"gc time: {gc_time}")
 
 
 def get_vram(device):
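The `clean_vram` change lets an environment variable override the detected VRAM size, which is useful for testing low-memory code paths on large GPUs. The override pattern in isolation (helper name is illustrative; the variable name `MINERU_VIRTUAL_VRAM_SIZE` comes from the diff):

```python
import os

def effective_vram_gb(detected_gb):
    """Apply the MINERU_VIRTUAL_VRAM_SIZE override introduced by the patch.

    Falls back to the rounded detected size when the variable is unset, and
    returns None when detection failed, matching the guard in clean_vram().
    """
    if detected_gb is None:
        return None
    return int(os.getenv("MINERU_VIRTUAL_VRAM_SIZE", round(detected_gb)))
```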

+ 1 - 1
mineru/version.py

@@ -1 +1 @@
-__version__ = "2.1.10"
+__version__ = "2.1.11"

+ 1 - 1
mkdocs.yml

@@ -64,7 +64,7 @@ extra:
       link: https://discord.gg/Tdedn9GTXq
       name: Discord
     - icon: fontawesome/brands/weixin
-      link: http://mineru.space/s/V85Yl
+      link: https://mineru.net/community-portal/?aliasId=3c430f94
       name: WeChat
     - icon: material/email
       link: mailto:OpenDataLab@pjlab.org.cn

+ 1 - 0
pyproject.toml

@@ -37,6 +37,7 @@ dependencies = [
     "fast-langdetect>=0.2.3,<0.3.0",
     "scikit-image>=0.25.0,<1.0.0",
     "openai>=1.70.0,<2",
+    "beautifulsoup4>=4.13.5,<5",
 ]
 
 [project.optional-dependencies]