Эх сурвалжийг харах

feat(ocr): pass language parameter for custom model init

Pass the `lang` parameter to `custom_model_init` in `doc_analyze` to support language-specific OCR configurations. This enhancement allows the use of language information to improve OCR accuracy when processing PDFs.
myhloli 1 жил өмнө
parent
commit
4b372f3f7e

+ 10 - 8
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -57,14 +57,14 @@ class ModelSingleton:
             cls._instance = super().__new__(cls)
         return cls._instance
 
-    def get_model(self, ocr: bool, show_log: bool):
-        key = (ocr, show_log)
+    def get_model(self, ocr: bool, show_log: bool, lang):
+        key = (ocr, show_log, lang)
         if key not in self._models:
-            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log)
+            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang)
         return self._models[key]
 
 
-def custom_model_init(ocr: bool = False, show_log: bool = False):
+def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
     model = None
 
     if model_config.__model_mode__ == "lite":
@@ -78,7 +78,7 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
         model_init_start = time.time()
         if model == MODEL.Paddle:
             from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
-            custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log)
+            custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
         elif model == MODEL.PEK:
             from magic_pdf.model.pdf_extract_kit import CustomPEKModel
             # 从配置文件读取model-dir和device
@@ -89,7 +89,9 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
                            "show_log": show_log,
                            "models_dir": local_models_dir,
                            "device": device,
-                           "table_config": table_config}
+                           "table_config": table_config,
+                           "lang": lang,
+                           }
             custom_model = CustomPEKModel(**model_input)
         else:
             logger.error("Not allow model_name!")
@@ -104,10 +106,10 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
 
 
 def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
-                start_page_id=0, end_page_id=None):
+                start_page_id=0, end_page_id=None, lang=None):
 
     model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(ocr, show_log)
+    custom_model = model_manager.get_model(ocr, show_log, lang)
 
     images = load_images_from_pdf(pdf_bytes)
 

+ 12 - 6
magic_pdf/model/pdf_extract_kit.py

@@ -74,8 +74,11 @@ def layout_model_init(weight, config_file, device):
     return model
 
 
-def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
-    model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
+def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None):
+    if lang is not None:
+        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang)
+    else:
+        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
     return model
 
 
@@ -134,7 +137,8 @@ def atom_model_init(model_name: str, **kwargs):
     elif model_name == AtomicModel.OCR:
         atom_model = ocr_model_init(
             kwargs.get("ocr_show_log"),
-            kwargs.get("det_db_box_thresh")
+            kwargs.get("det_db_box_thresh"),
+            kwargs.get("lang")
         )
     elif model_name == AtomicModel.Table:
         atom_model = table_model_init(
@@ -177,9 +181,10 @@ class CustomPEKModel:
         self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
         self.table_model_type = self.table_config.get("model", TABLE_MASTER)
         self.apply_ocr = ocr
+        self.lang = kwargs.get("lang", None)
         logger.info(
-            "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
-                self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
+            "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}".format(
+                self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table, self.lang
             )
         )
         assert self.apply_layout, "DocAnalysis must contain layout model."
@@ -229,7 +234,8 @@ class CustomPEKModel:
             self.ocr_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.OCR,
                 ocr_show_log=show_log,
-                det_db_box_thresh=0.3
+                det_db_box_thresh=0.3,
+                lang=self.lang
             )
         # init table model
         if self.apply_table:

+ 5 - 2
magic_pdf/model/pp_structure_v2.py

@@ -18,8 +18,11 @@ def region_to_bbox(region):
 
 
 class CustomPaddleModel:
-    def __init__(self, ocr: bool = False, show_log: bool = False):
-        self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
+    def __init__(self, ocr: bool = False, show_log: bool = False, lang=None):
+        if lang is not None:
+            self.model = PPStructure(table=False, ocr=ocr, show_log=show_log, lang=lang)
+        else:
+            self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
 
     def __call__(self, img):
         try:

+ 2 - 1
magic_pdf/pipe/AbsPipe.py

@@ -17,7 +17,7 @@ class AbsPipe(ABC):
     PIP_TXT = "txt"
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None):
+                 start_page_id=0, end_page_id=None, lang=None):
         self.pdf_bytes = pdf_bytes
         self.model_list = model_list
         self.image_writer = image_writer
@@ -25,6 +25,7 @@ class AbsPipe(ABC):
         self.is_debug = is_debug
         self.start_page_id = start_page_id
         self.end_page_id = end_page_id
+        self.lang = lang
     
     def get_compress_pdf_mid_data(self):
         return JsonCompressor.compress_json(self.pdf_mid_data)

+ 4 - 3
magic_pdf/pipe/OCRPipe.py

@@ -10,15 +10,16 @@ from magic_pdf.user_api import parse_ocr_pdf
 class OCRPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
+                 start_page_id=0, end_page_id=None, lang=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
 
     def pipe_classify(self):
         pass
 
     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
-                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                      lang=self.lang)
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,

+ 4 - 3
magic_pdf/pipe/TXTPipe.py

@@ -11,15 +11,16 @@ from magic_pdf.user_api import parse_txt_pdf
 class TXTPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
+                 start_page_id=0, end_page_id=None, lang=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
 
     def pipe_classify(self):
         pass
 
     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
-                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                      lang=self.lang)
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,

+ 6 - 4
magic_pdf/pipe/UNIPipe.py

@@ -14,9 +14,9 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
 class UNIPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None):
+                 start_page_id=0, end_page_id=None, lang=None):
         self.pdf_type = jso_useful_key["_pdf_type"]
-        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id)
+        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id, lang)
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
         else:
@@ -28,10 +28,12 @@ class UNIPipe(AbsPipe):
     def pipe_analyze(self):
         if self.pdf_type == self.PIP_TXT:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                          lang=self.lang)
         elif self.pdf_type == self.PIP_OCR:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
+                                          lang=self.lang)
 
     def pipe_parse(self):
         if self.pdf_type == self.PIP_TXT:

+ 14 - 1
magic_pdf/tools/cli.py

@@ -45,6 +45,18 @@ without method specified, auto will be used by default.""",
     default='auto',
 )
 @click.option(
+    '-l',
+    '--lang',
+    'lang',
+    type=str,
+    help="""
+    Input the languages in the pdf (if known) to improve OCR accuracy.  Optional.
+    You should input "Abbreviation" with language form url:
+    https://paddlepaddle.github.io/PaddleOCR/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
+    """,
+    default=None,
+)
+@click.option(
     '-d',
     '--debug',
     'debug_able',
@@ -68,7 +80,7 @@ without method specified, auto will be used by default.""",
     help='The ending page for PDF parsing, beginning from 0.',
     default=None,
 )
-def cli(path, output_dir, method, debug_able, start_page_id, end_page_id):
+def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
     model_config.__use_inside_model__ = True
     model_config.__model_mode__ = 'full'
     os.makedirs(output_dir, exist_ok=True)
@@ -90,6 +102,7 @@ def cli(path, output_dir, method, debug_able, start_page_id, end_page_id):
                 debug_able,
                 start_page_id=start_page_id,
                 end_page_id=end_page_id,
+                lang=lang
             )
 
         except Exception as e:

+ 4 - 3
magic_pdf/tools/common.py

@@ -44,6 +44,7 @@ def do_parse(
     f_draw_model_bbox=False,
     start_page_id=0,
     end_page_id=None,
+    lang=None,
 ):
     if debug_able:
         logger.warning("debug mode is on")
@@ -61,13 +62,13 @@ def do_parse(
     if parse_method == 'auto':
         jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
         pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
     elif parse_method == 'txt':
         pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
     elif parse_method == 'ocr':
         pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       start_page_id=start_page_id, end_page_id=end_page_id)
+                       start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
     else:
         logger.error('unknown parse method')
         exit(1)