Bläddra i källkod

feat: add batch predict for slanet_plus

Sidney233 2 månader sedan
förälder
incheckning
512f40fdfb

+ 2 - 2
mineru/backend/pipeline/model_init.py

@@ -33,7 +33,7 @@ def table_cls_model_init():
     return PaddleTableClsModel()
 
 
-def wired_table_model_init(lang=None):
+def wired_table_model_init(lang="ch"):
     atom_model_manager = AtomModelSingleton()
     ocr_engine = atom_model_manager.get_atom_model(
         atom_model_name=AtomicModel.OCR,
@@ -46,7 +46,7 @@ def wired_table_model_init(lang=None):
     return table_model
 
 
-def wireless_table_model_init(lang=None):
+def wireless_table_model_init(lang="ch"):
     atom_model_manager = AtomModelSingleton()
     ocr_engine = atom_model_manager.get_atom_model(
         atom_model_name=AtomicModel.OCR,

+ 88 - 23
mineru/model/table/rec/slanet_plus/main.py

@@ -12,6 +12,8 @@ from typing import Dict, List, Optional, Tuple, Union
 import cv2
 import numpy as np
 from loguru import logger
+from tqdm import tqdm
+
 from .matcher import TableMatch
 from .table_structure import TableStructurer
 from mineru.utils.enum_class import ModelPath
@@ -72,40 +74,25 @@ class RapidTable:
             raise ValueError(f"{self.model_type} is not supported.")
         self.table_matcher = TableMatch()
 
-        try:
-            self.ocr_engine = importlib.import_module("rapidocr").RapidOCR()
-        except ModuleNotFoundError:
-            self.ocr_engine = None
-
-    def __call__(
+    def predict(
         self,
         img: np.ndarray,
         ocr_result: List[Union[List[List[float]], str, str]] = None,
     ) -> RapidTableOutput:
-        if self.ocr_engine is None and ocr_result is None:
-            raise ValueError(
-                "One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
-            )
+        if ocr_result is None:
+            raise ValueError("OCR result is None")
 
         s = time.perf_counter()
         h, w = img.shape[:2]
 
-        if ocr_result is None:
-            ocr_result = self.ocr_engine(img)
-            ocr_result = list(
-                zip(
-                    ocr_result.boxes,
-                    ocr_result.txts,
-                    ocr_result.scores,
-                )
-            )
         dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
 
-        pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img))
+        pred_structures, cell_bboxes, _ = self.table_structure.process(
+            copy.deepcopy(img)
+        )
 
         # 适配slanet-plus模型输出的box缩放还原
-        if self.model_type == ModelType.SLANETPLUS.value:
-            cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
+        cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
 
         pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
 
@@ -117,6 +104,50 @@ class RapidTable:
         elapse = time.perf_counter() - s
         return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)
 
+    def batch_predict(
+        self,
+        images: List[np.ndarray],
+        ocr_results: List[List[Union[List[List[float]], str, str]]],
+        batch_size: int = 4,
+    ) -> List[RapidTableOutput]:
+        """批量处理图像"""
+        s = time.perf_counter()
+
+        batch_dt_boxes = []
+        batch_rec_res = []
+
+        for i, img in enumerate(images):
+            h, w = img.shape[:2]
+            dt_boxes, rec_res = self.get_boxes_recs(ocr_results[i], h, w)
+            batch_dt_boxes.append(dt_boxes)
+            batch_rec_res.append(rec_res)
+
+        # 批量表格结构识别
+        batch_results = self.table_structure.batch_process(images)
+
+        output_results = []
+        for i, (img, ocr_result, (pred_structures, cell_bboxes, _)) in enumerate(
+            zip(images, ocr_results, batch_results)
+        ):
+            # 适配slanet-plus模型输出的box缩放还原
+            cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
+            pred_html = self.table_matcher(
+                pred_structures, cell_bboxes, batch_dt_boxes[i], batch_rec_res[i]
+            )
+            # 过滤掉占位的bbox
+            mask = ~np.all(cell_bboxes == 0, axis=1)
+            cell_bboxes = cell_bboxes[mask]
+
+            logic_points = self.table_matcher.decode_logic_points(pred_structures)
+            result = RapidTableOutput(pred_html, cell_bboxes, logic_points, 0)
+            output_results.append(result)
+
+        total_elapse = time.perf_counter() - s
+        for result in output_results:
+            result.elapse = total_elapse / len(output_results)
+
+        return output_results
+
     def get_boxes_recs(
         self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
     ) -> Tuple[np.ndarray, Tuple[str, str]]:
@@ -201,7 +232,7 @@ class RapidTableModel(object):
 
         if ocr_result:
             try:
-                table_results = self.table_model(np.asarray(image), ocr_result)
+                table_results = self.table_model.predict(np.asarray(image), ocr_result)
                 html_code = table_results.pred_html
                 table_cell_bboxes = table_results.cell_bboxes
                 logic_points = table_results.logic_points
@@ -211,3 +242,37 @@ class RapidTableModel(object):
                 logger.exception(e)
 
         return None, None, None, None
+
+    def batch_predict(self, table_res_list: List[Dict], batch_size: int = 4) -> None:
+        """对传入的字典列表进行批量预测,无返回值"""
+        for index in tqdm(
+            range(0, len(table_res_list), batch_size),
+            desc=f"Table Batch Predict, total={len(table_res_list)}, batch_size={batch_size}",
+        ):
+            batch_imgs = [
+                cv2.cvtColor(np.asarray(table_res_list[i]["table_img"]), cv2.COLOR_RGB2BGR)
+                for i in range(index, min(index + batch_size, len(table_res_list)))
+            ]
+            batch_ocrs = [
+                table_res_list[i]["ocr_result"]
+                for i in range(index, min(index + batch_size, len(table_res_list)))
+            ]
+            results = self.table_model.batch_predict(
+                batch_imgs, batch_ocrs, batch_size=batch_size
+            )
+            for i, result in enumerate(results):
+                if result.pred_html:
+                    # 检查html_code是否包含'<table>'和'</table>'
+                    if '<table>' in result.pred_html and '</table>' in result.pred_html:
+                        # 选用<table>到</table>的内容,放入table_res_dict['table_res']['html']
+                        start_index = result.pred_html.find('<table>')
+                        end_index = result.pred_html.rfind('</table>') + len('</table>')
+                        table_res_list[index + i]['table_res']['html'] = result.pred_html[start_index:end_index]
+                    else:
+                        logger.warning(
+                            'table recognition processing fails, not found expected HTML table end'
+                        )
+                else:
+                    logger.warning(
+                        "table recognition processing fails, not get html return"
+                    )

+ 54 - 3
mineru/model/table/rec/slanet_plus/table_structure.py

@@ -12,23 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import time
-from typing import Any, Dict
+from typing import Any, Dict, List, Tuple
 
 import numpy as np
 
-from .table_stucture_utils import OrtInferSession, TableLabelDecode, TablePreprocess
+from .table_stucture_utils import (
+    OrtInferSession,
+    TableLabelDecode,
+    TablePreprocess,
+    BatchTablePreprocess,
+)
 
 
 class TableStructurer:
     def __init__(self, config: Dict[str, Any]):
         self.preprocess_op = TablePreprocess()
+        self.batch_preprocess_op = BatchTablePreprocess()
 
         self.session = OrtInferSession(config)
 
         self.character = self.session.get_metadata()
         self.postprocess_op = TableLabelDecode(self.character)
 
-    def __call__(self, img):
+    def process(self, img):
         starttime = time.time()
         data = {"image": img}
         data = self.preprocess_op(data)
@@ -56,3 +62,48 @@ class TableStructurer:
         )
         elapse = time.time() - starttime
         return structure_str_list, bbox_list, elapse
+
+    def batch_process(
+        self, img_list: List[np.ndarray]
+    ) -> List[Tuple[List[str], np.ndarray, float]]:
+        """批量处理图像列表
+        Args:
+            img_list: 图像列表
+        Returns:
+            结果列表,每个元素包含 (table_struct_str, cell_bboxes, elapse)
+        """
+        starttime = time.perf_counter()
+
+        batch_data = self.batch_preprocess_op(img_list)
+        preprocessed_images = batch_data[0]
+        shape_lists = batch_data[1]
+
+        preprocessed_images = np.array(preprocessed_images)
+        bbox_preds, struct_probs = self.session([preprocessed_images])
+
+        batch_size = preprocessed_images.shape[0]
+        results = []
+        for bbox_pred, struct_prob, shape_list in zip(
+            bbox_preds, struct_probs, shape_lists
+        ):
+            preds = {
+                "loc_preds": np.expand_dims(bbox_pred, axis=0),
+                "structure_probs": np.expand_dims(struct_prob, axis=0),
+            }
+            shape_list = np.expand_dims(shape_list, axis=0)
+            post_result = self.postprocess_op(preds, [shape_list])
+            bbox_list = post_result["bbox_batch_list"][0]
+            structure_str_list = post_result["structure_batch_list"][0]
+            structure_str_list = structure_str_list[0]
+            structure_str_list = (
+                ["<html>", "<body>", "<table>"]
+                + structure_str_list
+                + ["</table>", "</body>", "</html>"]
+            )
+            results.append((structure_str_list, bbox_list, 0))
+
+        total_elapse = time.perf_counter() - starttime
+        for i in range(len(results)):
+            results[i] = (results[i][0], results[i][1], total_elapse / batch_size)
+
+        return results

+ 29 - 0
mineru/model/table/rec/slanet_plus/table_stucture_utils.py

@@ -443,6 +443,35 @@ class TablePreprocess:
         ]
 
 
+class BatchTablePreprocess:
+
+    def __init__(self):
+        self.preprocess = TablePreprocess()
+
+    def __call__(
+        self, img_list: List[np.ndarray]
+    ) -> Tuple[List[np.ndarray], List[List[float]]]:
+        """批量处理图像
+
+        Args:
+            img_list: 图像列表
+
+        Returns:
+            预处理后的图像列表和形状信息列表
+        """
+        processed_imgs = []
+        shape_lists = []
+
+        for img in img_list:
+            if img is None:
+                continue
+            data = {"image": img}
+            img_processed, shape_list = self.preprocess(data)
+            processed_imgs.append(img_processed)
+            shape_lists.append(shape_list)
+        return processed_imgs, shape_lists
+
+
 class ResizeTableImage:
     def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
         super(ResizeTableImage, self).__init__()