Explorar o código

refactor(model): remove unused OCR and table models

- Remove OCR utils, modified PaddleOCR, and StructEqTable model
- Delete related import statements and model definitions
- Update dependencies in setup.py to remove paddlepaddle and related OCR packages
myhloli hai 7 meses
pai
achega
d8ebd92f26

+ 0 - 0
magic_pdf/model/sub_modules/ocr/paddleocr/__init__.py


+ 0 - 364
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py

@@ -1,364 +0,0 @@
-import cv2
-import numpy as np
-from loguru import logger
-from io import BytesIO
-from PIL import Image
-import base64
-from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
-from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
-
-from importlib.resources import files
-from paddleocr import PaddleOCR
-from ppocr.utils.utility import check_and_read
-
-
-def img_decode(content: bytes):
-    np_arr = np.frombuffer(content, dtype=np.uint8)
-    return cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
-
-
-def check_img(img):
-    if isinstance(img, bytes):
-        img = img_decode(img)
-    if isinstance(img, str):
-        image_file = img
-        img, flag_gif, flag_pdf = check_and_read(image_file)
-        if not flag_gif and not flag_pdf:
-            with open(image_file, 'rb') as f:
-                img_str = f.read()
-                img = img_decode(img_str)
-            if img is None:
-                try:
-                    buf = BytesIO()
-                    image = BytesIO(img_str)
-                    im = Image.open(image)
-                    rgb = im.convert('RGB')
-                    rgb.save(buf, 'jpeg')
-                    buf.seek(0)
-                    image_bytes = buf.read()
-                    data_base64 = str(base64.b64encode(image_bytes),
-                                      encoding="utf-8")
-                    image_decode = base64.b64decode(data_base64)
-                    img_array = np.frombuffer(image_decode, np.uint8)
-                    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
-                except:
-                    logger.error("error in loading image:{}".format(image_file))
-                    return None
-        if img is None:
-            logger.error("error in loading image:{}".format(image_file))
-            return None
-    if isinstance(img, np.ndarray) and len(img.shape) == 2:
-        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-
-    return img
-
-
-def bbox_to_points(bbox):
-    """ 将bbox格式转换为四个顶点的数组 """
-    x0, y0, x1, y1 = bbox
-    return np.array([[x0, y0], [x1, y0], [x1, y1], [x0, y1]]).astype('float32')
-
-
-def points_to_bbox(points):
-    """ 将四个顶点的数组转换为bbox格式 """
-    x0, y0 = points[0]
-    x1, _ = points[1]
-    _, y1 = points[2]
-    return [x0, y0, x1, y1]
-
-
-def merge_intervals(intervals):
-    # Sort the intervals based on the start value
-    intervals.sort(key=lambda x: x[0])
-
-    merged = []
-    for interval in intervals:
-        # If the list of merged intervals is empty or if the current
-        # interval does not overlap with the previous, simply append it.
-        if not merged or merged[-1][1] < interval[0]:
-            merged.append(interval)
-        else:
-            # Otherwise, there is overlap, so we merge the current and previous intervals.
-            merged[-1][1] = max(merged[-1][1], interval[1])
-
-    return merged
-
-
-def remove_intervals(original, masks):
-    # Merge all mask intervals
-    merged_masks = merge_intervals(masks)
-
-    result = []
-    original_start, original_end = original
-
-    for mask in merged_masks:
-        mask_start, mask_end = mask
-
-        # If the mask starts after the original range, ignore it
-        if mask_start > original_end:
-            continue
-
-        # If the mask ends before the original range starts, ignore it
-        if mask_end < original_start:
-            continue
-
-        # Remove the masked part from the original range
-        if original_start < mask_start:
-            result.append([original_start, mask_start - 1])
-
-        original_start = max(mask_end + 1, original_start)
-
-    # Add the remaining part of the original range, if any
-    if original_start <= original_end:
-        result.append([original_start, original_end])
-
-    return result
-
-
-def update_det_boxes(dt_boxes, mfd_res):
-    new_dt_boxes = []
-    angle_boxes_list = []
-    for text_box in dt_boxes:
-
-        if calculate_is_angle(text_box):
-            angle_boxes_list.append(text_box)
-            continue
-
-        text_bbox = points_to_bbox(text_box)
-        masks_list = []
-        for mf_box in mfd_res:
-            mf_bbox = mf_box['bbox']
-            if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
-                masks_list.append([mf_bbox[0], mf_bbox[2]])
-        text_x_range = [text_bbox[0], text_bbox[2]]
-        text_remove_mask_range = remove_intervals(text_x_range, masks_list)
-        temp_dt_box = []
-        for text_remove_mask in text_remove_mask_range:
-            temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
-        if len(temp_dt_box) > 0:
-            new_dt_boxes.extend(temp_dt_box)
-
-    new_dt_boxes.extend(angle_boxes_list)
-
-    return new_dt_boxes
-
-
-def merge_overlapping_spans(spans):
-    """
-    Merges overlapping spans on the same line.
-
-    :param spans: A list of span coordinates [(x1, y1, x2, y2), ...]
-    :return: A list of merged spans
-    """
-    # Return an empty list if the input spans list is empty
-    if not spans:
-        return []
-
-    # Sort spans by their starting x-coordinate
-    spans.sort(key=lambda x: x[0])
-
-    # Initialize the list of merged spans
-    merged = []
-    for span in spans:
-        # Unpack span coordinates
-        x1, y1, x2, y2 = span
-        # If the merged list is empty or there's no horizontal overlap, add the span directly
-        if not merged or merged[-1][2] < x1:
-            merged.append(span)
-        else:
-            # If there is horizontal overlap, merge the current span with the previous one
-            last_span = merged.pop()
-            # Update the merged span's top-left corner to the smaller (x1, y1) and bottom-right to the larger (x2, y2)
-            x1 = min(last_span[0], x1)
-            y1 = min(last_span[1], y1)
-            x2 = max(last_span[2], x2)
-            y2 = max(last_span[3], y2)
-            # Add the merged span back to the list
-            merged.append((x1, y1, x2, y2))
-
-    # Return the list of merged spans
-    return merged
-
-
-def merge_det_boxes(dt_boxes):
-    """
-    Merge detection boxes.
-
-    This function takes a list of detected bounding boxes, each represented by four corner points.
-    The goal is to merge these bounding boxes into larger text regions.
-
-    Parameters:
-    dt_boxes (list): A list containing multiple text detection boxes, where each box is defined by four corner points.
-
-    Returns:
-    list: A list containing the merged text regions, where each region is represented by four corner points.
-    """
-    # Convert the detection boxes into a dictionary format with bounding boxes and type
-    dt_boxes_dict_list = []
-    angle_boxes_list = []
-    for text_box in dt_boxes:
-        text_bbox = points_to_bbox(text_box)
-
-        if calculate_is_angle(text_box):
-            angle_boxes_list.append(text_box)
-            continue
-
-        text_box_dict = {
-            'bbox': text_bbox,
-            'type': 'text',
-        }
-        dt_boxes_dict_list.append(text_box_dict)
-
-    # Merge adjacent text regions into lines
-    lines = merge_spans_to_line(dt_boxes_dict_list)
-
-    # Initialize a new list for storing the merged text regions
-    new_dt_boxes = []
-    for line in lines:
-        line_bbox_list = []
-        for span in line:
-            line_bbox_list.append(span['bbox'])
-
-        # Merge overlapping text regions within the same line
-        merged_spans = merge_overlapping_spans(line_bbox_list)
-
-        # Convert the merged text regions back to point format and add them to the new detection box list
-        for span in merged_spans:
-            new_dt_boxes.append(bbox_to_points(span))
-
-    new_dt_boxes.extend(angle_boxes_list)
-
-    return new_dt_boxes
-
-
-def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
-    paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-    # Adjust the coordinates of the formula area
-    adjusted_mfdetrec_res = []
-    for mf_res in single_page_mfdetrec_res:
-        mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
-        # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
-        x0 = mf_xmin - xmin + paste_x
-        y0 = mf_ymin - ymin + paste_y
-        x1 = mf_xmax - xmin + paste_x
-        y1 = mf_ymax - ymin + paste_y
-        # Filter formula blocks outside the graph
-        if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
-            continue
-        else:
-            adjusted_mfdetrec_res.append({
-                "bbox": [x0, y0, x1, y1],
-            })
-    return adjusted_mfdetrec_res
-
-
-def get_ocr_result_list(ocr_res, useful_list):
-    paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-    ocr_result_list = []
-    for box_ocr_res in ocr_res:
-
-        if len(box_ocr_res) == 2:
-            p1, p2, p3, p4 = box_ocr_res[0]
-            text, score = box_ocr_res[1]
-            # logger.info(f"text: {text}, score: {score}")
-            if score < 0.6:  # 过滤低置信度的结果
-                continue
-        else:
-            p1, p2, p3, p4 = box_ocr_res
-            text, score = "", 1
-        # average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
-        # if average_angle_degrees > 0.5:
-        poly = [p1, p2, p3, p4]
-        if calculate_is_angle(poly):
-            # logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
-            # 与x轴的夹角超过0.5度,对边界做一下矫正
-            # 计算几何中心
-            x_center = sum(point[0] for point in poly) / 4
-            y_center = sum(point[1] for point in poly) / 4
-            new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
-            new_width = p3[0] - p1[0]
-            p1 = [x_center - new_width / 2, y_center - new_height / 2]
-            p2 = [x_center + new_width / 2, y_center - new_height / 2]
-            p3 = [x_center + new_width / 2, y_center + new_height / 2]
-            p4 = [x_center - new_width / 2, y_center + new_height / 2]
-
-        # Convert the coordinates back to the original coordinate system
-        p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
-        p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
-        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
-        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
-
-        ocr_result_list.append({
-            'category_id': 15,
-            'poly': p1 + p2 + p3 + p4,
-            'score': float(round(score, 2)),
-            'text': text,
-        })
-
-    return ocr_result_list
-
-
-def calculate_is_angle(poly):
-    p1, p2, p3, p4 = poly
-    height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
-    if 0.8 * height <= (p3[1] - p1[1]) <= 1.2 * height:
-        return False
-    else:
-        # logger.info((p3[1] - p1[1])/height)
-        return True
-
-
-class ONNXModelSingleton:
-    _instance = None
-    _models = {}
-
-    def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance
-
-    def get_onnx_model(self, **kwargs):
-
-        lang = kwargs.get('lang', None)
-        det_db_box_thresh = kwargs.get('det_db_box_thresh', 0.3)
-        use_dilation = kwargs.get('use_dilation', True)
-        det_db_unclip_ratio = kwargs.get('det_db_unclip_ratio', 1.8)
-        key = (lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio)
-        if key not in self._models:
-            self._models[key] = onnx_model_init(key)
-        return self._models[key]
-
-
-def onnx_model_init(key):
-    if len(key) < 4:
-        logger.error('Invalid key length, expected at least 4 elements')
-        exit(1)
-
-    try:
-        resource_path = files("rapidocr_onnxruntime") / "models"
-        additional_ocr_params = {
-            "use_onnx": True,
-            "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
-            "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
-            "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
-            "det_db_box_thresh": key[1],
-            "use_dilation": key[2],
-            "det_db_unclip_ratio": key[3],
-        }
-
-        if key[0] is not None:
-            additional_ocr_params["lang"] = key[0]
-
-        # logger.info(f"additional_ocr_params: {additional_ocr_params}")
-
-        onnx_model = PaddleOCR(**additional_ocr_params)
-
-        if onnx_model is None:
-            logger.error('model init failed')
-            exit(1)
-        else:
-            return onnx_model
-
-    except Exception as e:
-        logger.exception(f'Error initializing model: {e}')
-        exit(1)

+ 0 - 205
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py

@@ -1,205 +0,0 @@
-import copy
-import platform
-import time
-import cv2
-import numpy as np
-import torch
-
-
-from paddleocr import PaddleOCR
-from ppocr.utils.logging import get_logger
-from ppocr.utils.utility import alpha_to_color, binarize_img
-from tools.infer.predict_system import sorted_boxes
-from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
-
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img, \
-    ONNXModelSingleton
-
-logger = get_logger()
-
-
-class ModifiedPaddleOCR(PaddleOCR):
-    def __init__(self, *args, **kwargs):
-
-        super().__init__(*args, **kwargs)
-        self.lang = kwargs.get('lang', 'ch')
-        # 在cpu架构为arm且不支持cuda时调用onnx、
-        if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
-            self.use_onnx = True
-            onnx_model_manager = ONNXModelSingleton()
-            self.additional_ocr = onnx_model_manager.get_onnx_model(**kwargs)
-        else:
-            self.use_onnx = False
-
-    def ocr(self,
-            img,
-            det=True,
-            rec=True,
-            cls=True,
-            bin=False,
-            inv=False,
-            alpha_color=(255, 255, 255),
-            mfd_res=None,
-            ):
-        """
-        OCR with PaddleOCR
-        args:
-            img: img for OCR, support ndarray, img_path and list or ndarray
-            det: use text detection or not. If False, only rec will be exec. Default is True
-            rec: use text recognition or not. If False, only det will be exec. Default is True
-            cls: use angle classifier or not. Default is True. If True, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False.
-            bin: binarize image to black and white. Default is False.
-            inv: invert image colors. Default is False.
-            alpha_color: set RGB color Tuple for transparent parts replacement. Default is pure white.
-        """
-        assert isinstance(img, (np.ndarray, list, str, bytes))
-        if isinstance(img, list) and det == True:
-            logger.error('When input a list of images, det must be false')
-            exit(0)
-        if cls == True and self.use_angle_cls == False:
-            pass
-            # logger.warning(
-            #     'Since the angle classifier is not initialized, it will not be used during the forward process'
-            # )
-
-        img = check_img(img)
-        # for infer pdf file
-        if isinstance(img, list):
-            if self.page_num > len(img) or self.page_num == 0:
-                self.page_num = len(img)
-            imgs = img[:self.page_num]
-        else:
-            imgs = [img]
-
-        def preprocess_image(_image):
-            _image = alpha_to_color(_image, alpha_color)
-            if inv:
-                _image = cv2.bitwise_not(_image)
-            if bin:
-                _image = binarize_img(_image)
-            return _image
-
-        if det and rec:
-            ocr_res = []
-            for img in imgs:
-                img = preprocess_image(img)
-                dt_boxes, rec_res, _ = self.__call__(img, cls, mfd_res=mfd_res)
-                if not dt_boxes and not rec_res:
-                    ocr_res.append(None)
-                    continue
-                tmp_res = [[box.tolist(), res]
-                           for box, res in zip(dt_boxes, rec_res)]
-                ocr_res.append(tmp_res)
-            return ocr_res
-        elif det and not rec:
-            ocr_res = []
-            for img in imgs:
-                img = preprocess_image(img)
-                if self.lang in ['ch'] and self.use_onnx:
-                    dt_boxes, elapse = self.additional_ocr.text_detector(img)
-                else:
-                    dt_boxes, elapse = self.text_detector(img)
-                if dt_boxes is None:
-                    ocr_res.append(None)
-                    continue
-                dt_boxes = sorted_boxes(dt_boxes)
-                # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
-                dt_boxes = merge_det_boxes(dt_boxes)
-                if mfd_res:
-                    bef = time.time()
-                    dt_boxes = update_det_boxes(dt_boxes, mfd_res)
-                    aft = time.time()
-                    logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
-                        len(dt_boxes), aft - bef))
-                tmp_res = [box.tolist() for box in dt_boxes]
-                ocr_res.append(tmp_res)
-            return ocr_res
-        else:
-            ocr_res = []
-            cls_res = []
-            for img in imgs:
-                if not isinstance(img, list):
-                    img = preprocess_image(img)
-                    img = [img]
-                if self.use_angle_cls and cls:
-                    img, cls_res_tmp, elapse = self.text_classifier(img)
-                    if not rec:
-                        cls_res.append(cls_res_tmp)
-                if self.lang in ['ch'] and self.use_onnx:
-                    rec_res, elapse = self.additional_ocr.text_recognizer(img)
-                else:
-                    rec_res, elapse = self.text_recognizer(img)
-                ocr_res.append(rec_res)
-            if not rec:
-                return cls_res
-            return ocr_res
-
-    def __call__(self, img, cls=True, mfd_res=None):
-        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
-
-        if img is None:
-            logger.debug("no valid image provided")
-            return None, None, time_dict
-
-        start = time.time()
-        ori_im = img.copy()
-        if self.lang in ['ch'] and self.use_onnx:
-            dt_boxes, elapse = self.additional_ocr.text_detector(img)
-        else:
-            dt_boxes, elapse = self.text_detector(img)
-        time_dict['det'] = elapse
-
-        if dt_boxes is None:
-            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
-            end = time.time()
-            time_dict['all'] = end - start
-            return None, None, time_dict
-        else:
-            logger.debug("dt_boxes num : {}, elapsed : {}".format(
-                len(dt_boxes), elapse))
-        img_crop_list = []
-
-        dt_boxes = sorted_boxes(dt_boxes)
-
-        # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
-        dt_boxes = merge_det_boxes(dt_boxes)
-
-        if mfd_res:
-            bef = time.time()
-            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
-            aft = time.time()
-            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
-                len(dt_boxes), aft - bef))
-
-        for bno in range(len(dt_boxes)):
-            tmp_box = copy.deepcopy(dt_boxes[bno])
-            if self.args.det_box_type == "quad":
-                img_crop = get_rotate_crop_image(ori_im, tmp_box)
-            else:
-                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
-            img_crop_list.append(img_crop)
-        if self.use_angle_cls and cls:
-            img_crop_list, angle_list, elapse = self.text_classifier(
-                img_crop_list)
-            time_dict['cls'] = elapse
-            logger.debug("cls num  : {}, elapsed : {}".format(
-                len(img_crop_list), elapse))
-        if self.lang in ['ch'] and self.use_onnx:
-            rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
-        else:
-            rec_res, elapse = self.text_recognizer(img_crop_list)
-        time_dict['rec'] = elapse
-        logger.debug("rec_res num  : {}, elapsed : {}".format(
-            len(rec_res), elapse))
-        if self.args.save_crop_res:
-            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list,
-                                   rec_res)
-        filter_boxes, filter_rec_res = [], []
-        for box, rec_result in zip(dt_boxes, rec_res):
-            text, score = rec_result
-            if score >= self.drop_score:
-                filter_boxes.append(box)
-                filter_rec_res.append(rec_result)
-        end = time.time()
-        time_dict['all'] = end - start
-        return filter_boxes, filter_rec_res, time_dict

+ 0 - 213
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_291_mod.py

@@ -1,213 +0,0 @@
-import copy
-import time
-
-
-import cv2
-import numpy as np
-from paddleocr import PaddleOCR
-from paddleocr.paddleocr import check_img, logger
-from paddleocr.ppocr.utils.utility import alpha_to_color, binarize_img
-from paddleocr.tools.infer.predict_system import sorted_boxes
-from paddleocr.tools.infer.utility import slice_generator, merge_fragmented, get_rotate_crop_image, \
-    get_minarea_rect_crop
-
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes
-
-
-class ModifiedPaddleOCR(PaddleOCR):
-
-    def ocr(
-        self,
-        img,
-        det=True,
-        rec=True,
-        cls=True,
-        bin=False,
-        inv=False,
-        alpha_color=(255, 255, 255),
-        slice={},
-        mfd_res=None,
-    ):
-        """
-        OCR with PaddleOCR
-
-        Args:
-            img: Image for OCR. It can be an ndarray, img_path, or a list of ndarrays.
-            det: Use text detection or not. If False, only text recognition will be executed. Default is True.
-            rec: Use text recognition or not. If False, only text detection will be executed. Default is True.
-            cls: Use angle classifier or not. Default is True. If True, the text with a rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance.
-            bin: Binarize image to black and white. Default is False.
-            inv: Invert image colors. Default is False.
-            alpha_color: Set RGB color Tuple for transparent parts replacement. Default is pure white.
-            slice: Use sliding window inference for large images. Both det and rec must be True. Requires int values for slice["horizontal_stride"], slice["vertical_stride"], slice["merge_x_thres"], slice["merge_y_thres"] (See doc/doc_en/slice_en.md). Default is {}.
-
-        Returns:
-            If both det and rec are True, returns a list of OCR results for each image. Each OCR result is a list of bounding boxes and recognized text for each detected text region.
-            If det is True and rec is False, returns a list of detected bounding boxes for each image.
-            If det is False and rec is True, returns a list of recognized text for each image.
-            If both det and rec are False, returns a list of angle classification results for each image.
-
-        Raises:
-            AssertionError: If the input image is not of type ndarray, list, str, or bytes.
-            SystemExit: If det is True and the input is a list of images.
-
-        Note:
-            - If the angle classifier is not initialized (use_angle_cls=False), it will not be used during the forward process.
-            - For PDF files, if the input is a list of images and the page_num is specified, only the first page_num images will be processed.
-            - The preprocess_image function is used to preprocess the input image by applying alpha color replacement, inversion, and binarization if specified.
-        """
-        assert isinstance(img, (np.ndarray, list, str, bytes))
-        if isinstance(img, list) and det == True:
-            logger.error("When input a list of images, det must be false")
-            exit(0)
-        if cls == True and self.use_angle_cls == False:
-            logger.warning(
-                "Since the angle classifier is not initialized, it will not be used during the forward process"
-            )
-
-        img, flag_gif, flag_pdf = check_img(img, alpha_color)
-        # for infer pdf file
-        if isinstance(img, list) and flag_pdf:
-            if self.page_num > len(img) or self.page_num == 0:
-                imgs = img
-            else:
-                imgs = img[: self.page_num]
-        else:
-            imgs = [img]
-
-        def preprocess_image(_image):
-            _image = alpha_to_color(_image, alpha_color)
-            if inv:
-                _image = cv2.bitwise_not(_image)
-            if bin:
-                _image = binarize_img(_image)
-            return _image
-
-        if det and rec:
-            ocr_res = []
-            for img in imgs:
-                img = preprocess_image(img)
-                dt_boxes, rec_res, _ = self.__call__(img, cls, slice, mfd_res=mfd_res)
-                if not dt_boxes and not rec_res:
-                    ocr_res.append(None)
-                    continue
-                tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
-                ocr_res.append(tmp_res)
-            return ocr_res
-        elif det and not rec:
-            ocr_res = []
-            for img in imgs:
-                img = preprocess_image(img)
-                dt_boxes, elapse = self.text_detector(img)
-                if dt_boxes.size == 0:
-                    ocr_res.append(None)
-                    continue
-                tmp_res = [box.tolist() for box in dt_boxes]
-                ocr_res.append(tmp_res)
-            return ocr_res
-        else:
-            ocr_res = []
-            cls_res = []
-            for img in imgs:
-                if not isinstance(img, list):
-                    img = preprocess_image(img)
-                    img = [img]
-                if self.use_angle_cls and cls:
-                    img, cls_res_tmp, elapse = self.text_classifier(img)
-                    if not rec:
-                        cls_res.append(cls_res_tmp)
-                rec_res, elapse = self.text_recognizer(img)
-                ocr_res.append(rec_res)
-            if not rec:
-                return cls_res
-            return ocr_res
-
-    def __call__(self, img, cls=True, slice={}, mfd_res=None):
-        time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}
-
-        if img is None:
-            logger.debug("no valid image provided")
-            return None, None, time_dict
-
-        start = time.time()
-        ori_im = img.copy()
-        if slice:
-            slice_gen = slice_generator(
-                img,
-                horizontal_stride=slice["horizontal_stride"],
-                vertical_stride=slice["vertical_stride"],
-            )
-            elapsed = []
-            dt_slice_boxes = []
-            for slice_crop, v_start, h_start in slice_gen:
-                dt_boxes, elapse = self.text_detector(slice_crop, use_slice=True)
-                if dt_boxes.size:
-                    dt_boxes[:, :, 0] += h_start
-                    dt_boxes[:, :, 1] += v_start
-                    dt_slice_boxes.append(dt_boxes)
-                    elapsed.append(elapse)
-            dt_boxes = np.concatenate(dt_slice_boxes)
-
-            dt_boxes = merge_fragmented(
-                boxes=dt_boxes,
-                x_threshold=slice["merge_x_thres"],
-                y_threshold=slice["merge_y_thres"],
-            )
-            elapse = sum(elapsed)
-        else:
-            dt_boxes, elapse = self.text_detector(img)
-
-        time_dict["det"] = elapse
-
-        if dt_boxes is None:
-            logger.debug("no dt_boxes found, elapsed : {}".format(elapse))
-            end = time.time()
-            time_dict["all"] = end - start
-            return None, None, time_dict
-        else:
-            logger.debug(
-                "dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse)
-            )
-        img_crop_list = []
-
-        dt_boxes = sorted_boxes(dt_boxes)
-
-        if mfd_res:
-            bef = time.time()
-            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
-            aft = time.time()
-            logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
-                len(dt_boxes), aft - bef))
-
-        for bno in range(len(dt_boxes)):
-            tmp_box = copy.deepcopy(dt_boxes[bno])
-            if self.args.det_box_type == "quad":
-                img_crop = get_rotate_crop_image(ori_im, tmp_box)
-            else:
-                img_crop = get_minarea_rect_crop(ori_im, tmp_box)
-            img_crop_list.append(img_crop)
-        if self.use_angle_cls and cls:
-            img_crop_list, angle_list, elapse = self.text_classifier(img_crop_list)
-            time_dict["cls"] = elapse
-            logger.debug(
-                "cls num  : {}, elapsed : {}".format(len(img_crop_list), elapse)
-            )
-        if len(img_crop_list) > 1000:
-            logger.debug(
-                f"rec crops num: {len(img_crop_list)}, time and memory cost may be large."
-            )
-
-        rec_res, elapse = self.text_recognizer(img_crop_list)
-        time_dict["rec"] = elapse
-        logger.debug("rec_res num  : {}, elapsed : {}".format(len(rec_res), elapse))
-        if self.args.save_crop_res:
-            self.draw_crop_rec_res(self.args.crop_res_save_dir, img_crop_list, rec_res)
-        filter_boxes, filter_rec_res = [], []
-        for box, rec_result in zip(dt_boxes, rec_res):
-            text, score = rec_result[0], rec_result[1]
-            if score >= self.drop_score:
-                filter_boxes.append(box)
-                filter_rec_res.append(rec_result)
-        end = time.time()
-        time_dict["all"] = end - start
-        return filter_boxes, filter_rec_res, time_dict

+ 0 - 8
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_cls.py

@@ -21,12 +21,8 @@ class TextClassifier(BaseOCRV20):
         }
         self.postprocess_op = build_post_process(postprocess_params)
 
-        # use_gpu = args.use_gpu
-        # self.use_gpu = torch.cuda.is_available() and use_gpu
-
         self.weights_path = args.cls_model_path
         self.yaml_path = args.cls_yaml_path
-        # network_config = utility.AnalysisConfig(self.weights_path, self.yaml_path)
         network_config = utility.get_arch_config(self.weights_path)
         super(TextClassifier, self).__init__(network_config, **kwargs)
 
@@ -37,8 +33,6 @@ class TextClassifier(BaseOCRV20):
 
         self.load_pytorch_weights(self.weights_path)
         self.net.eval()
-        # if self.use_gpu:
-        #     self.net.cuda()
         self.net.to(self.device)
 
     def resize_norm_img(self, img):
@@ -97,8 +91,6 @@ class TextClassifier(BaseOCRV20):
 
             with torch.no_grad():
                 inp = torch.from_numpy(norm_img_batch)
-                # if self.use_gpu:
-                #     inp = inp.cuda()
                 inp = inp.to(self.device)
                 prob_out = self.net(inp)
             prob_out = prob_out.cpu().numpy()

+ 0 - 8
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py

@@ -109,18 +109,12 @@ class TextDetector(BaseOCRV20):
         self.preprocess_op = create_operators(pre_process_list)
         self.postprocess_op = build_post_process(postprocess_params)
 
-        # use_gpu = args.use_gpu
-        # self.use_gpu = torch.cuda.is_available() and use_gpu
-
         self.weights_path = args.det_model_path
         self.yaml_path = args.det_yaml_path
-        # network_config = utility.AnalysisConfig(self.weights_path, self.yaml_path)
         network_config = utility.get_arch_config(self.weights_path)
         super(TextDetector, self).__init__(network_config, **kwargs)
         self.load_pytorch_weights(self.weights_path)
         self.net.eval()
-        # if self.use_gpu:
-        #     self.net.cuda()
         self.net.to(self.device)
 
     def order_points_clockwise(self, pts):
@@ -190,8 +184,6 @@ class TextDetector(BaseOCRV20):
 
         with torch.no_grad():
             inp = torch.from_numpy(img)
-            # if self.use_gpu:
-            #     inp = inp.cuda()
             inp = inp.to(self.device)
             outputs = self.net(inp)
 

+ 0 - 11
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py

@@ -70,17 +70,12 @@ class TextRecognizer(BaseOCRV20):
             }
         self.postprocess_op = build_post_process(postprocess_params)
 
-        # use_gpu = args.use_gpu
-        # self.use_gpu = torch.cuda.is_available() and use_gpu
-
         self.limited_max_width = args.limited_max_width
         self.limited_min_width = args.limited_min_width
 
         self.weights_path = args.rec_model_path
         self.yaml_path = args.rec_yaml_path
 
-        char_num = len(getattr(self.postprocess_op, 'character'))
-        # network_config = utility.AnalysisConfig(self.weights_path, self.yaml_path, char_num)
         network_config = utility.get_arch_config(self.weights_path)
         weights = self.read_pytorch_weights(self.weights_path)
 
@@ -95,8 +90,6 @@ class TextRecognizer(BaseOCRV20):
 
         self.load_state_dict(weights)
         self.net.eval()
-        # if self.use_gpu:
-        #     self.net.cuda()
         self.net.to(self.device)
 
     def resize_norm_img(self, img, max_wh_ratio):
@@ -417,8 +410,6 @@ class TextRecognizer(BaseOCRV20):
                 inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
 
                 inp = [torch.from_numpy(e_i) for e_i in inputs]
-                # if self.use_gpu:
-                #     inp = [e_i.cuda() for e_i in inp]
                 inp = [e_i.to(self.device) for e_i in inp]
                 with torch.no_grad():
                     outputs = self.net(inp)
@@ -431,8 +422,6 @@ class TextRecognizer(BaseOCRV20):
 
                 with torch.no_grad():
                     inp = torch.from_numpy(norm_img_batch)
-                    # if self.use_gpu:
-                    #     inp = inp.cuda()
                     inp = inp.to(self.device)
                     prob_out = self.net(inp)
 

+ 0 - 6
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_system.py

@@ -50,12 +50,6 @@ class TextSystem(object):
             dst_img = np.rot90(dst_img)
         return dst_img
 
-    def print_draw_crop_rec_res(self, img_crop_list, rec_res):
-        bbox_num = len(img_crop_list)
-        for bno in range(bbox_num):
-            cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
-            print(bno, rec_res[bno])
-
     def __call__(self, img):
         ori_im = img.copy()
         dt_boxes, elapse = self.text_detector(img)

+ 0 - 0
magic_pdf/model/sub_modules/table/structeqtable/__init__.py


+ 0 - 37
magic_pdf/model/sub_modules/table/structeqtable/struct_eqtable.py

@@ -1,37 +0,0 @@
-import torch
-from struct_eqtable import build_model
-
-from magic_pdf.model.sub_modules.table.table_utils import minify_html
-
-
-class StructTableModel:
-    def __init__(self, model_path, max_new_tokens=1024, max_time=60):
-        # init
-        assert torch.cuda.is_available(), "CUDA must be available for StructEqTable model."
-        self.model = build_model(
-            model_ckpt=model_path,
-            max_new_tokens=max_new_tokens,
-            max_time=max_time,
-            lmdeploy=False,
-            flash_attn=False,
-            batch_size=1,
-        ).cuda()
-        self.default_format = "html"
-
-    def predict(self, images, output_format=None, **kwargs):
-
-        if output_format is None:
-            output_format = self.default_format
-        else:
-            if output_format not in ['latex', 'markdown', 'html']:
-                raise ValueError(f"Output format {output_format} is not supported.")
-
-        results = self.model(
-            images, output_format=output_format
-        )
-
-        if output_format == "html":
-            results = [minify_html(html) for html in results]
-
-        return results
-

+ 0 - 0
magic_pdf/model/sub_modules/table/tablemaster/__init__.py


+ 0 - 72
magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py

@@ -1,72 +0,0 @@
-import os
-
-import cv2
-import numpy as np
-from paddleocr import PaddleOCR
-from ppstructure.table.predict_table import TableSystem
-from ppstructure.utility import init_args
-from PIL import Image
-
-from magic_pdf.config.constants import *  # noqa: F403
-
-
-class TableMasterPaddleModel(object):
-    """This class is responsible for converting image of table into HTML format
-    using a pre-trained model.
-
-    Attributes:
-    - table_sys: An instance of TableSystem initialized with parsed arguments.
-
-    Methods:
-    - __init__(config): Initializes the model with configuration parameters.
-    - img2html(image): Converts a PIL Image or NumPy array to HTML string.
-    - parse_args(**kwargs): Parses configuration arguments.
-    """
-
-    def __init__(self, config):
-        """
-        Parameters:
-        - config (dict): Configuration dictionary containing model_dir and device.
-        """
-        args = self.parse_args(**config)
-        self.table_sys = TableSystem(args)
-
-    def img2html(self, image):
-        """
-        Parameters:
-        - image (PIL.Image or np.ndarray): The image of the table to be converted.
-
-        Return:
-        - HTML (str): A string representing the HTML structure with content of the table.
-        """
-        if isinstance(image, Image.Image):
-            image = np.asarray(image)
-            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-        pred_res, _ = self.table_sys(image)
-        pred_html = pred_res['html']
-        # res = '<td><table  border="1">' + pred_html.replace("<html><body><table>", "").replace(
-        # "</table></body></html>","") + "</table></td>\n"
-        return pred_html
-
-    def parse_args(self, **kwargs):
-        parser = init_args()
-        model_dir = kwargs.get('model_dir')
-        table_model_dir = os.path.join(model_dir, TABLE_MASTER_DIR)  # noqa: F405
-        table_char_dict_path = os.path.join(model_dir, TABLE_MASTER_DICT)  # noqa: F405
-        det_model_dir = os.path.join(model_dir, DETECT_MODEL_DIR)  # noqa: F405
-        rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)  # noqa: F405
-        rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)  # noqa: F405
-        device = kwargs.get('device', 'cpu')
-        use_gpu = True if device.startswith('cuda') else False
-        config = {
-            'use_gpu': use_gpu,
-            'table_max_len': kwargs.get('table_max_len', TABLE_MAX_LEN),  # noqa: F405
-            'table_algorithm': 'TableMaster',
-            'table_model_dir': table_model_dir,
-            'table_char_dict_path': table_char_dict_path,
-            'det_model_dir': det_model_dir,
-            'rec_model_dir': rec_model_dir,
-            'rec_char_dict_path': rec_char_dict_path,
-        }
-        parser.set_defaults(**config)
-        return parser.parse_args([])

+ 0 - 13
setup.py

@@ -40,12 +40,7 @@ if __name__ == '__main__':
                      "matplotlib<=3.9.0;platform_system=='Windows'",  # 3.9.1及之后不提供windows的预编译包,避免一些没有编译环境的windows设备安装失败
                      "matplotlib;platform_system=='Linux' or platform_system=='Darwin'",  # linux 和 macos 不应限制matplotlib的最高版本,以避免无法更新导致的一些bug
                      "ultralytics>=8.3.48",  # yolov8,公式检测
-                     "paddleocr==2.7.3",  # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3
-                     "paddlepaddle==3.0.0rc1;platform_system=='Linux' or platform_system=='Darwin'",  # 解决linux的段异常问题
-                     "paddlepaddle==2.6.1;platform_system=='Windows'",  # windows版本3.0.0效率下降,需锁定2.6.1
                      "doclayout_yolo==0.0.2b1",  # doclayout_yolo
-                     "rapidocr-paddle>=1.4.5,<2.0.0",  # rapidocr-paddle
-                     "rapidocr_onnxruntime>=1.4.4,<2.0.0",
                      "rapid_table>=1.0.3,<2.0.0",  # rapid_table
                      "PyYAML",  # yaml
                      "ftfy"
@@ -54,14 +49,6 @@ if __name__ == '__main__':
             "old_linux":[
                 "albumentations<=1.4.20", # 1.4.21引入的simsimd不支持2019年及更早的linux系统
             ],
-            "layoutlmv3":[
-                "detectron2"
-            ],
-            "struct_eqtable":[
-                "struct-eqtable==0.3.2",  # 表格解析
-                "einops",  # struct-eqtable依赖
-                "accelerate",  # struct-eqtable依赖
-            ],
         },
         description="A practical tool for converting PDF to Markdown",  # 简短描述
         long_description=long_description,  # 详细描述