
refactor: refactor code

icecraft 11 months ago
parent
commit
b2887ca0aa

+ 0 - 101
magic_pdf/model/__init__.py

@@ -1,101 +0,0 @@
-from typing import Callable
-
-from abc import ABC, abstractmethod
-
-from magic_pdf.data.data_reader_writer import DataWriter
-from magic_pdf.data.dataset import Dataset
-from magic_pdf.pipe.operators import PipeResult
-
-
-__use_inside_model__ = True
-__model_mode__ = "full"
-
-
-class InferenceResultBase(ABC):
-
-    @abstractmethod
-    def __init__(self, inference_results: list, dataset: Dataset):
-        """Initialized method.
-
-        Args:
-            inference_results (list): the inference result generated by model
-            dataset (Dataset): the dataset related with model inference result
-        """
-        self._infer_res = inference_results
-        self._dataset = dataset
-
-    @abstractmethod
-    def draw_model(self, file_path: str) -> None:
-        """Draw model inference result.
-
-        Args:
-            file_path (str): the output file path
-        """
-        pass
-
-    @abstractmethod
-    def dump_model(self, writer: DataWriter, file_path: str):
-        """Dump model inference result to file.
-
-        Args:
-            writer (DataWriter): writer handle
-            file_path (str): the location of target file
-        """
-        pass
-
-    @abstractmethod
-    def get_infer_res(self):
-        """Get the inference result.
-
-        Returns:
-            list: the inference result generated by model
-        """
-        pass
-
-    @abstractmethod
-    def apply(self, proc: Callable, *args, **kwargs):
-        """Apply callable method which.
-
-        Args:
-            proc (Callable): invoke proc as follows:
-                proc(inference_result, *args, **kwargs)
-
-        Returns:
-            Any: return the result generated by proc
-        """
-        pass
-
-    @abstractmethod
-    def pipe_txt_mode(
-        self,
-        imageWriter: DataWriter,
-        start_page_id=0,
-        end_page_id=None,
-        debug_mode=False,
-        lang=None,
-    ) -> PipeResult:
-        """Post-proc the model inference result, Extract the text using the
-        third library, such as `pymupdf`
-
-        Args:
-            imageWriter (DataWriter): the image writer handle
-            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
-            end_page_id (int, optional):  Defaults to the last page index of dataset. Let user select some pages He/She want to process
-            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
-            lang (str, optional): Defaults to None.
-
-        Returns:
-            PipeResult: the result
-        """
-        pass
-
-    @abstractmethod
-    def pipe_ocr_mode(
-        self,
-        imageWriter: DataWriter,
-        start_page_id=0,
-        end_page_id=None,
-        debug_mode=False,
-        lang=None,
-    ) -> PipeResult:
-        pass
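
The abstract base class deleted above documented the interface that the concrete InferenceResult class (relocated by this commit to magic_pdf.operators.models, as the import changes below show) continues to expose: draw_model, dump_model, get_infer_res, apply, pipe_txt_mode and pipe_ocr_mode. A minimal, illustrative sketch of that call pattern, assuming a prepared Dataset and a per-page model result list; the file paths and the empty model_list are placeholders, not part of the commit:

    # Sketch only: exercises the interface documented by the removed docstrings,
    # using the post-refactor import path introduced later in this diff.
    from magic_pdf.data.data_reader_writer import FileBasedDataWriter
    from magic_pdf.data.dataset import PymuDocDataset
    from magic_pdf.operators.models import InferenceResult

    ds = PymuDocDataset(open('example.pdf', 'rb').read())   # assumed constructor usage
    model_list = []  # per-page inference results produced by the model (placeholder)

    infer_result = InferenceResult(model_list, ds)
    infer_result.draw_model('output/model.pdf')              # visualize detections
    infer_result.dump_model(FileBasedDataWriter('output'), 'model.json')
    pipe_result = infer_result.pipe_txt_mode(
        FileBasedDataWriter('output/images'), start_page_id=0, lang=None
    )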

+ 31 - 37
magic_pdf/model/batch_analyze.py

@@ -11,17 +11,12 @@ from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
 from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
-from magic_pdf.model.operators import InferenceResult
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
-    clean_vram,
-    crop_img,
-    get_res_list_from_layout_res,
-)
+    clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
-    get_adjusted_mfdetrec_res,
-    get_ocr_result_list,
-)
+    get_adjusted_mfdetrec_res, get_ocr_result_list)
+from magic_pdf.operators.models import InferenceResult
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 4
 MFD_BASE_BATCH_SIZE = 1
@@ -50,7 +45,7 @@ class BatchAnalyze:
                 pil_img = Image.fromarray(image)
                 width, height = pil_img.size
                 if height > width:
-                    input_res = {"poly": [0, 0, width, 0, width, height, 0, height]}
+                    input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
                     new_image, useful_list = crop_img(
                         input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
                     )
@@ -65,17 +60,17 @@ class BatchAnalyze:
 
             for image_index, useful_list in modified_images:
                 for res in images_layout_res[image_index]:
-                    for i in range(len(res["poly"])):
+                    for i in range(len(res['poly'])):
                         if i % 2 == 0:
-                            res["poly"][i] = (
-                                res["poly"][i] - useful_list[0] + useful_list[2]
+                            res['poly'][i] = (
+                                res['poly'][i] - useful_list[0] + useful_list[2]
                             )
                         else:
-                            res["poly"][i] = (
-                                res["poly"][i] - useful_list[1] + useful_list[3]
+                            res['poly'][i] = (
+                                res['poly'][i] - useful_list[1] + useful_list[3]
                             )
         logger.info(
-            f"layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}"
+            f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
         )
 
         if self.model.apply_formula:
@@ -85,7 +80,7 @@ class BatchAnalyze:
                 images, self.batch_ratio * MFD_BASE_BATCH_SIZE
             )
             logger.info(
-                f"mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}"
+                f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
             )
 
             # formula recognition
@@ -98,7 +93,7 @@ class BatchAnalyze:
             for image_index in range(len(images)):
                 images_layout_res[image_index] += images_formula_list[image_index]
             logger.info(
-                f"mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}"
+                f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
             )
 
         # clean up GPU memory
@@ -156,7 +151,7 @@ class BatchAnalyze:
                     if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                         with torch.no_grad():
                             table_result = self.model.table_model.predict(
-                                new_image, "html"
+                                new_image, 'html'
                             )
                             if len(table_result) > 0:
                                 html_code = table_result[0]
@@ -169,32 +164,32 @@ class BatchAnalyze:
                     run_time = time.time() - single_table_start_time
                     if run_time > self.model.table_max_time:
                         logger.warning(
-                            f"table recognition processing exceeds max time {self.model.table_max_time}s"
+                            f'table recognition processing exceeds max time {self.model.table_max_time}s'
                         )
                     # check whether a valid result was returned
                     if html_code:
                         expected_ending = html_code.strip().endswith(
-                            "</html>"
-                        ) or html_code.strip().endswith("</table>")
+                            '</html>'
+                        ) or html_code.strip().endswith('</table>')
                         if expected_ending:
-                            res["html"] = html_code
+                            res['html'] = html_code
                         else:
                             logger.warning(
-                                "table recognition processing fails, not found expected HTML table end"
+                                'table recognition processing fails, not found expected HTML table end'
                             )
                     else:
                         logger.warning(
-                            "table recognition processing fails, not get html return"
+                            'table recognition processing fails, not get html return'
                         )
                 table_time += time.time() - table_start
                 table_count += len(table_res_list)
 
         if self.model.apply_ocr:
-            logger.info(f"ocr time: {round(ocr_time, 2)}, image num: {ocr_count}")
+            logger.info(f'ocr time: {round(ocr_time, 2)}, image num: {ocr_count}')
         else:
-            logger.info(f"det time: {round(ocr_time, 2)}, image num: {ocr_count}")
+            logger.info(f'det time: {round(ocr_time, 2)}, image num: {ocr_count}')
         if self.model.apply_table:
-            logger.info(f"table time: {round(table_time, 2)}, image num: {table_count}")
+            logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
 
         return images_layout_res
 
@@ -211,8 +206,7 @@ def doc_batch_analyze(
     table_enable=None,
     batch_ratio: int | None = None,
 ) -> InferenceResult:
-    """
-    Perform batch analysis on a document dataset.
+    """Perform batch analysis on a document dataset.
 
     Args:
         dataset (Dataset): The dataset containing document pages to be analyzed.
@@ -234,9 +228,9 @@ def doc_batch_analyze(
     """
 
     if not torch.cuda.is_available():
-        raise CUDA_NOT_AVAILABLE("batch analyze not support in CPU mode")
+        raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
 
-    lang = None if lang == "" else lang
+    lang = None if lang == '' else lang
     # TODO: auto detect batch size
     batch_ratio = 1 if batch_ratio is None else batch_ratio
     end_page_id = end_page_id if end_page_id else len(dataset)
@@ -255,26 +249,26 @@ def doc_batch_analyze(
         if start_page_id <= index <= end_page_id:
             page_data = dataset.get_page(index)
             img_dict = page_data.get_image()
-            images.append(img_dict["img"])
+            images.append(img_dict['img'])
     analyze_result = batch_model(images)
 
     for index in range(len(dataset)):
         page_data = dataset.get_page(index)
         img_dict = page_data.get_image()
-        page_width = img_dict["width"]
-        page_height = img_dict["height"]
+        page_width = img_dict['width']
+        page_height = img_dict['height']
         if start_page_id <= index <= end_page_id:
             result = analyze_result.pop(0)
         else:
             result = []
 
-        page_info = {"page_no": index, "height": page_height, "width": page_width}
-        page_dict = {"layout_dets": result, "page_info": page_info}
+        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+        page_dict = {'layout_dets': result, 'page_info': page_info}
         model_json.append(page_dict)
 
     # TODO: clean memory when gpu memory is not enough
     clean_memory_start_time = time.time()
     clean_memory()
-    logger.info(f"clean memory time: {round(time.time() - clean_memory_start_time, 2)}")
+    logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
 
     return InferenceResult(model_json, dataset)
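
doc_batch_analyze above builds an InferenceResult from the per-page model_json list and raises CUDA_NOT_AVAILABLE when no GPU is present. A hedged usage sketch; only the trailing parameters (table_enable, batch_ratio) are visible in this hunk, so treating dataset as the first positional argument is an assumption based on the docstring, and the PDF path is a placeholder:

    # Sketch under assumptions: batch analysis requires CUDA, and batch_ratio
    # falls back to 1 when left as None (see the hunk above).
    import torch
    from magic_pdf.data.dataset import PymuDocDataset
    from magic_pdf.model.batch_analyze import doc_batch_analyze

    assert torch.cuda.is_available(), 'doc_batch_analyze raises CUDA_NOT_AVAILABLE on CPU'

    ds = PymuDocDataset(open('example.pdf', 'rb').read())   # assumed constructor usage
    infer_result = doc_batch_analyze(ds, batch_ratio=1)     # one result entry per page in ds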

+ 3 - 3
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -1,10 +1,10 @@
 import os
 import time
 
-from loguru import logger
-
 # disable paddle's signal handling
 import paddle
+from loguru import logger
+
 paddle.disable_signal_handler()
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # prevent albumentations from checking for updates
@@ -25,7 +25,7 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                           get_local_models_dir,
                                           get_table_recog_config)
 from magic_pdf.model.model_list import MODEL
-from magic_pdf.model.operators import InferenceResult
+from magic_pdf.operators.models import InferenceResult
 
 
 def dict_compare(d1, d2):

+ 0 - 0
magic_pdf/operators/__init__.py


+ 2 - 4
magic_pdf/model/operators.py → magic_pdf/operators/models.py

@@ -7,15 +7,13 @@ from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.data.dataset import Dataset
-from magic_pdf.filter import classify
 from magic_pdf.libs.draw_bbox import draw_model_bbox
 from magic_pdf.libs.version import __version__
-from magic_pdf.model import InferenceResultBase
+from magic_pdf.operators.pipes import PipeResult
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
-from magic_pdf.pipe.operators import PipeResult
 
 
-class InferenceResult(InferenceResultBase):
+class InferenceResult:
     def __init__(self, inference_results: list, dataset: Dataset):
         """Initialized method.
 

+ 0 - 0
magic_pdf/pipe/operators.py → magic_pdf/operators/pipes.py


+ 3 - 3
magic_pdf/tools/common.py

@@ -10,7 +10,7 @@ from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import FileBasedDataWriter
 from magic_pdf.data.dataset import PymuDocDataset
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.model.operators import InferenceResult
+from magic_pdf.operators.models import InferenceResult
 
 # from io import BytesIO
 # from pypdf import PdfReader, PdfWriter
@@ -167,7 +167,7 @@ def do_parse(
             logger.error('need model list input')
             exit(2)
     else:
-        
+
         infer_result = InferenceResult(model_list, ds)
         if parse_method == 'ocr':
             pipe_result = infer_result.pipe_ocr_mode(
@@ -186,7 +186,7 @@ def do_parse(
                 pipe_result = infer_result.pipe_ocr_mode(
                         image_writer, debug_mode=True, lang=ds._lang
                     )
-            
+
 
     if f_draw_model_bbox:
         infer_result.draw_model(

+ 1 - 1
next_docs/en/api/model_operators.rst

@@ -2,7 +2,7 @@
 Model Api
 ==========
 
-.. autoclass:: magic_pdf.model.InferenceResultBase
+.. autoclass:: magic_pdf.operators.models.InferenceResult
    :members:
    :inherited-members:
    :show-inheritance:

+ 2 - 2
next_docs/en/api/pipe_operators.rst

@@ -3,7 +3,7 @@
 Pipeline Api
 =============
 
-.. autoclass:: magic_pdf.pipe.operators.PipeResult
+.. autoclass:: magic_pdf.operators.pipes.PipeResult
    :members:
    :inherited-members:
-   :show-inheritance:
+   :show-inheritance:

+ 7 - 8
next_docs/en/user_guide/inference_result.rst

@@ -1,5 +1,5 @@
 
-Inference Result 
+Inference Result
 ==================
 
 .. admonition:: Tip
@@ -7,7 +7,7 @@ Inference Result
 
     Please first navigate to :doc:`tutorial/pipeline` to get an initial understanding of how the pipeline works; this will help in understanding the content of this section.
 
-The **InferenceResult** class is a container for storing model inference results and implements a series of methods related to these results, such as draw_model, dump_model. 
+The **InferenceResult** class is a container for storing model inference results and implements a series of methods related to these results, such as draw_model, dump_model.
 Checkout :doc:`../api/model_operators` for more details about **InferenceResult**
 
 
@@ -56,7 +56,7 @@ Structure Definition
             page_info: PageInfo = Field(description="Page metadata")
 
 
-Example 
+Example
 ^^^^^^^^^^^
 
 .. code:: json
@@ -116,15 +116,15 @@ and bottom-left points respectively. |Poly Coordinate Diagram|
 
 
 
-Inference Result 
+Inference Result
 -------------------------
 
 
 .. code:: python
 
-    from magic_pdf.model.operators import InferenceResult
-    from magic_pdf.data.dataset import Dataset 
-    
+    from magic_pdf.operators.models import InferenceResult
+    from magic_pdf.data.dataset import Dataset
+
     dataset : Dataset = some_data_set    # not real dataset
 
     # The inference results of all pages, ordered by page number, are stored in a list as the inference results of MinerU
@@ -142,4 +142,3 @@ some_model.pdf
 
 
 .. |Poly Coordinate Diagram| image:: ../_static/image/poly.png
-

+ 7 - 7
next_docs/en/user_guide/pipe_result.rst

@@ -1,6 +1,6 @@
 
 
-Pipe Result 
+Pipe Result
 ==============
 
 .. admonition:: Tip
@@ -9,7 +9,7 @@ Pipe Result
     Please first navigate to :doc:`tutorial/pipeline` to get an initial understanding of how the pipeline works; this will help in understanding the content of this section.
 
 
-The **PipeResult** class is a container for storing pipeline processing results and implements a series of methods related to these results, such as draw_layout, draw_span. 
+The **PipeResult** class is a container for storing pipeline processing results and implements a series of methods related to these results, such as draw_layout, draw_span.
 Checkout :doc:`../api/pipe_operators` for more details about **PipeResult**
 
 
@@ -288,14 +288,14 @@ example
    }
 
 
-Pipeline Result 
+Pipeline Result
 ------------------
 
-.. code:: python 
+.. code:: python
 
     from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
-    from magic_pdf.pipe.operators import PipeResult
-    from magic_pdf.data.dataset import Dataset 
+    from magic_pdf.operators.pipes import PipeResult
+    from magic_pdf.data.dataset import Dataset
 
     res = pdf_parse_union(*args, **kwargs)
     res['_parse_type'] = PARSE_TYPE_OCR
@@ -332,4 +332,4 @@ unrecognized inline formulas.
 .. figure:: ../_static/image/spans_example.png
    :alt: spans example
 
-   spans example
+   spans example