瀏覽代碼

Merge remote-tracking branch 'origin/dev' into dev

myhloli 7 月之前
父節點
當前提交
15467730cf

+ 59 - 48
magic_pdf/data/batch_build_dataset.py

@@ -103,54 +103,65 @@ def batch_build_dataset(pdf_paths, k, lang=None):
     all_images : list
         List of all processed images
     """
-    # Get page counts for each PDF
-    pdf_info = []
-    total_pages = 0
 
+    results = []
     for pdf_path in pdf_paths:
-        try:
-            doc = fitz.open(pdf_path)
-            num_pages = len(doc)
-            pdf_info.append((pdf_path, num_pages))
-            total_pages += num_pages
-            doc.close()
-        except Exception as e:
-            print(f'Error opening {pdf_path}: {e}')
-
-    # Partition the jobs based on page countEach job has 1 page
-    partitions = partition_array_greedy(pdf_info, k)
-
-    # Process each partition in parallel
-    all_images_h = {}
-
-    with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
-        # Submit one task per partition
-        futures = []
-        for sn, partition in enumerate(partitions):
-            # Get the jobs for this partition
-            partition_jobs = [pdf_info[idx] for idx in partition]
-
-            # Submit the task
-            future = executor.submit(
-                process_pdf_batch,
-                partition_jobs,
-                sn
-            )
-            futures.append(future)
-        # Process results as they complete
-        for i, future in enumerate(concurrent.futures.as_completed(futures)):
-            try:
-                idx, images = future.result()
-                all_images_h[idx] = images
-            except Exception as e:
-                print(f'Error processing partition: {e}')
-    results = [None] * len(pdf_paths)
-    for i in range(len(partitions)):
-        partition = partitions[i]
-        for j in range(len(partition)):
-            with open(pdf_info[partition[j]][0], 'rb') as f:
-                pdf_bytes = f.read()
-            dataset = PymuDocDataset(pdf_bytes, lang=lang)
-            dataset.set_images(all_images_h[i][j])
-            results[partition[j]] = dataset
+        with open(pdf_path, 'rb') as f:
+            pdf_bytes = f.read()
+        dataset = PymuDocDataset(pdf_bytes, lang=lang)
+        results.append(dataset)
     return results
+
+
+    #
+    # # Get page counts for each PDF
+    # pdf_info = []
+    # total_pages = 0
+    #
+    # for pdf_path in pdf_paths:
+    #     try:
+    #         doc = fitz.open(pdf_path)
+    #         num_pages = len(doc)
+    #         pdf_info.append((pdf_path, num_pages))
+    #         total_pages += num_pages
+    #         doc.close()
+    #     except Exception as e:
+    #         print(f'Error opening {pdf_path}: {e}')
+    #
+    # # Partition the jobs based on page countEach job has 1 page
+    # partitions = partition_array_greedy(pdf_info, k)
+    #
+    # # Process each partition in parallel
+    # all_images_h = {}
+    #
+    # with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
+    #     # Submit one task per partition
+    #     futures = []
+    #     for sn, partition in enumerate(partitions):
+    #         # Get the jobs for this partition
+    #         partition_jobs = [pdf_info[idx] for idx in partition]
+    #
+    #         # Submit the task
+    #         future = executor.submit(
+    #             process_pdf_batch,
+    #             partition_jobs,
+    #             sn
+    #         )
+    #         futures.append(future)
+    #     # Process results as they complete
+    #     for i, future in enumerate(concurrent.futures.as_completed(futures)):
+    #         try:
+    #             idx, images = future.result()
+    #             all_images_h[idx] = images
+    #         except Exception as e:
+    #             print(f'Error processing partition: {e}')
+    # results = [None] * len(pdf_paths)
+    # for i in range(len(partitions)):
+    #     partition = partitions[i]
+    #     for j in range(len(partition)):
+    #         with open(pdf_info[partition[j]][0], 'rb') as f:
+    #             pdf_bytes = f.read()
+    #         dataset = PymuDocDataset(pdf_bytes, lang=lang)
+    #         dataset.set_images(all_images_h[i][j])
+    #         results[partition[j]] = dataset
+    # return results

+ 2 - 2
magic_pdf/data/dataset.py

@@ -150,7 +150,7 @@ class PymuDocDataset(Dataset):
         elif lang == 'auto':
             from magic_pdf.model.sub_modules.language_detection.utils import \
                 auto_detect_lang
-            self._lang = auto_detect_lang(bits)
+            self._lang = auto_detect_lang(self._data_bits)
             logger.info(f'lang: {lang}, detect_lang: {self._lang}')
         else:
             self._lang = lang
@@ -405,4 +405,4 @@ class Doc(PageableData):
             fontsize (int): font size of the text
             color (list[float] | None):  three element tuple which describe the RGB of the board line, None will use the default font color!
         """
-        self._doc.insert_text(coord, content, fontsize=fontsize, color=color)
+        self._doc.insert_text(coord, content, fontsize=fontsize, color=color)

+ 8 - 2
magic_pdf/model/batch_analyze.py

@@ -30,8 +30,14 @@ class BatchAnalyze:
     
         images_layout_res = []
         layout_start_time = time.time()
-        _, fst_ocr, fst_lang = images_with_extra_info[0]
-        self.model = self.model_manager.get_model(fst_ocr, self.show_log, fst_lang, self.layout_model, self.formula_enable, self.table_enable)
+        self.model = self.model_manager.get_model(
+            ocr=True,
+            show_log=self.show_log,
+            lang = None,
+            layout_model = self.layout_model,
+            formula_enable = self.formula_enable,
+            table_enable = self.table_enable,
+        )
 
         images = [image for image, _, _ in images_with_extra_info]
 

+ 23 - 23
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -146,11 +146,8 @@ def doc_analyze(
             img_dict = page_data.get_image()
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
-    
-    if lang is None or lang == 'auto':
-        images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(images))]
-    else:
-        images_with_extra_info = [(images[index], ocr, lang) for index in range(len(images))]
+
+    images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
 
     if len(images) >= MIN_BATCH_INFERENCE_SIZE:
         batch_size = MIN_BATCH_INFERENCE_SIZE
@@ -159,8 +156,8 @@ def doc_analyze(
         batch_images = [images_with_extra_info]
 
     results = []
-    for sn, batch_image in enumerate(batch_images):
-        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log,layout_model, formula_enable, table_enable)
+    for batch_image in batch_images:
+        result = may_batch_image_analyze(batch_image, ocr, show_log,layout_model, formula_enable, table_enable)
         results.extend(result)
 
     model_json = []
@@ -182,7 +179,7 @@ def doc_analyze(
 
 def batch_doc_analyze(
     datasets: list[Dataset],
-    parse_method: str,
+    parse_method: str = 'auto',
     show_log: bool = False,
     lang=None,
     layout_model=None,
@@ -191,30 +188,34 @@ def batch_doc_analyze(
 ):
     MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
     batch_size = MIN_BATCH_INFERENCE_SIZE
-    images = []
     page_wh_list = []
 
     images_with_extra_info = []
     for dataset in datasets:
-        for index in range(len(dataset)):
-            if lang is None or lang == 'auto':
-                _lang = dataset._lang
-            else:
-                _lang = lang
 
+        ocr = False
+        if parse_method == 'auto':
+            if dataset.classify() == SupportedPdfParseMethod.TXT:
+                ocr = False
+            elif dataset.classify() == SupportedPdfParseMethod.OCR:
+                ocr = True
+        elif parse_method == 'ocr':
+            ocr = True
+        elif parse_method == 'txt':
+            ocr = False
+
+        _lang = dataset._lang
+
+        for index in range(len(dataset)):
             page_data = dataset.get_page(index)
             img_dict = page_data.get_image()
-            images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
-            if parse_method == 'auto':
-                images_with_extra_info.append((images[-1], dataset.classify() == SupportedPdfParseMethod.OCR, _lang))
-            else:
-                images_with_extra_info.append((images[-1], parse_method == 'ocr', _lang))
+            images_with_extra_info.append((img_dict['img'], ocr, _lang))
 
     batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
     results = []
-    for sn, batch_image in enumerate(batch_images):
-        _, result = may_batch_image_analyze(batch_image, sn, True, show_log, layout_model, formula_enable, table_enable)
+    for batch_image in batch_images:
+        result = may_batch_image_analyze(batch_image, True, show_log, layout_model, formula_enable, table_enable)
         results.extend(result)
 
     infer_results = []
@@ -234,7 +235,6 @@ def batch_doc_analyze(
 
 def may_batch_image_analyze(
         images_with_extra_info: list[(np.ndarray, bool, str)],
-        idx: int,
         ocr: bool,
         show_log: bool = False,
         layout_model=None,
@@ -292,4 +292,4 @@ def may_batch_image_analyze(
     #     f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
     #     f' speed: {doc_analyze_speed} pages/second'
     # )
-    return idx, results
+    return results

+ 21 - 4
magic_pdf/tools/common.py

@@ -109,9 +109,7 @@ def _do_parse(
     pdf_bytes = ds._raw_data
     local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
 
-    image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
-        local_md_dir
-    )
+    image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
     image_dir = str(os.path.basename(local_image_dir))
 
     if len(model_list) == 0:
@@ -317,7 +315,26 @@ def batch_do_parse(
 
     infer_results = batch_doc_analyze(dss, parse_method, lang=lang, layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
     for idx, infer_result in enumerate(infer_results):
-        _do_parse(output_dir, pdf_file_names[idx], dss[idx], infer_result.get_infer_res(), parse_method, debug_able, f_draw_span_bbox=f_draw_span_bbox, f_draw_layout_bbox=f_draw_layout_bbox, f_dump_md=f_dump_md, f_dump_middle_json=f_dump_middle_json, f_dump_model_json=f_dump_model_json, f_dump_orig_pdf=f_dump_orig_pdf, f_dump_content_list=f_dump_content_list, f_make_md_mode=f_make_md_mode, f_draw_model_bbox=f_draw_model_bbox, f_draw_line_sort_bbox=f_draw_line_sort_bbox, f_draw_char_bbox=f_draw_char_bbox, lang=lang)
+        _do_parse(
+            output_dir = output_dir,
+            pdf_file_name = pdf_file_names[idx],
+            pdf_bytes_or_dataset = dss[idx],
+            model_list = infer_result.get_infer_res(),
+            parse_method = parse_method,
+            debug_able = debug_able,
+            f_draw_span_bbox = f_draw_span_bbox,
+            f_draw_layout_bbox = f_draw_layout_bbox,
+            f_dump_md=f_dump_md,
+            f_dump_middle_json=f_dump_middle_json,
+            f_dump_model_json=f_dump_model_json,
+            f_dump_orig_pdf=f_dump_orig_pdf,
+            f_dump_content_list=f_dump_content_list,
+            f_make_md_mode=MakeMode.MM_MD,
+            f_draw_model_bbox=f_draw_model_bbox,
+            f_draw_line_sort_bbox=f_draw_line_sort_bbox,
+            f_draw_char_bbox=f_draw_char_bbox,
+            lang=lang,
+        )
 
 
 parse_pdf_methods = click.Choice(['ocr', 'txt', 'auto'])

+ 7 - 4
projects/gradio_app/app.py

@@ -159,9 +159,12 @@ devanagari_lang = [
         'sa', 'bgc'
 ]
 other_lang = ['ch', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
+add_lang = ['latin', 'arabic', 'cyrillic', 'devanagari']
 
-all_lang = ['', 'auto']
-all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
+# all_lang = ['', 'auto']
+all_lang = []
+# all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
+all_lang.extend([*other_lang, *add_lang])
 
 
 def to_pdf(file_path):
@@ -192,8 +195,8 @@ if __name__ == '__main__':
                 file = gr.File(label='Please upload a PDF or image', file_types=['.pdf', '.png', '.jpeg', '.jpg'])
                 max_pages = gr.Slider(1, 20, 10, step=1, label='Max convert pages')
                 with gr.Row():
-                    layout_mode = gr.Dropdown(['layoutlmv3', 'doclayout_yolo'], label='Layout model', value='doclayout_yolo')
-                    language = gr.Dropdown(all_lang, label='Language', value='auto')
+                    layout_mode = gr.Dropdown(['doclayout_yolo'], label='Layout model', value='doclayout_yolo')
+                    language = gr.Dropdown(all_lang, label='Language', value='ch')
                 with gr.Row():
                     formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
                     is_ocr = gr.Checkbox(label='Force enable OCR', value=False)

+ 1 - 1
requirements.txt

@@ -7,7 +7,7 @@ numpy>=1.21.6
 pydantic>=2.7.2,<2.11
 PyMuPDF>=1.24.9,<1.25.0
 scikit-learn>=1.0.2
-torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
+torch>=2.2.2,!=2.5.0,!=2.5.1
 torchvision
 transformers>=4.49.0,!=4.51.0,<5.0.0
 pdfminer.six==20231228