Преглед на файлове

Merge pull request #11 from magicpdf/dev-xm

fix logic
drunkpig преди 1 година
родител
ревизия
b75ee676fe
променени са 6 файла, в които са добавени 135 реда и са изтрити 31 реда
  1. 12 8
      demo/pdf2md.py
  2. 33 13
      magic_pdf/dict2md/ocr_mkcontent.py
  3. 32 5
      magic_pdf/libs/config_reader.py
  4. 1 1
      magic_pdf/pdf_parse_by_txt.py
  5. 6 4
      magic_pdf/pipeline.py
  6. 51 0
      utils/config_init_to_json.py

+ 12 - 8
demo/pdf2md.py

@@ -1,3 +1,4 @@
+import json
 import os
 import sys
 from pathlib import Path
@@ -6,8 +7,8 @@ import click
 from loguru import logger
 
 from magic_pdf.libs.commons import join_path, read_file
-from magic_pdf.dict2md.mkcontent import mk_mm_markdown
-from magic_pdf.pipeline import parse_pdf_by_model
+from magic_pdf.dict2md.mkcontent import mk_mm_markdown, mk_universal_format
+from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
 
 
 
@@ -24,7 +25,7 @@ def main(s3_pdf_path: str, s3_pdf_profile: str, pdf_model_path: str, pdf_model_p
     pdf_bytes = read_file(s3_pdf_path, s3_pdf_profile)
 
     try:
-        paras_dict = parse_pdf_by_model(
+        paras_dict = parse_pdf_by_txt(
             pdf_bytes, pdf_model_path, save_path, book_name, pdf_model_profile, start_page_num, debug_mode=debug_mode
         )
         parent_dir = os.path.dirname(text_content_save_path)
@@ -32,7 +33,8 @@ def main(s3_pdf_path: str, s3_pdf_profile: str, pdf_model_path: str, pdf_model_p
             os.makedirs(parent_dir)
                 
         if not paras_dict.get('need_drop'):
-            markdown_content = mk_mm_markdown(paras_dict)
+            content_list = mk_universal_format(paras_dict)
+            markdown_content = mk_mm_markdown(content_list)
         else:
             markdown_content = paras_dict['drop_reason']
             
@@ -70,8 +72,8 @@ def main_shell(pdf_file_path: str, save_path: str):
 
 
 @click.command()
-@click.option("--pdf-dir", help="s3上pdf文件的路径")
-@click.option("--model-dir", help="s3上pdf文件的路径")
+@click.option("--pdf-dir", help="本地pdf文件的路径")
+@click.option("--model-dir", help="本地模型文件的路径")
 @click.option("--start-page-num", default=0, help="从第几页开始解析")
 def main_shell2(pdf_dir: str, model_dir: str,start_page_num: int):
     # 先扫描所有的pdf目录里的文件名字
@@ -86,8 +88,10 @@ def main_shell2(pdf_dir: str, model_dir: str,start_page_num: int):
 
     for pdf_file in pdf_file_names:
         pdf_file_path = os.path.join(pdf_dir, pdf_file)
-        model_file_path = os.path.join(model_dir, pdf_file)
-        main(pdf_file_path, None, model_file_path, None, start_page_num)
+        model_file_path = os.path.join(model_dir, pdf_file).rstrip(".pdf") + ".json"
+        with open(model_file_path, "r") as json_file:
+            model_list = json.load(json_file)
+        main(pdf_file_path, None, model_list, None, start_page_num)
 
 
 

+ 33 - 13
magic_pdf/dict2md/ocr_mkcontent.py

@@ -1,4 +1,5 @@
 from magic_pdf.libs.commons import s3_image_save_path, join_path
+from magic_pdf.libs.language import detect_lang
 from magic_pdf.libs.markdown_utils import ocr_escape_special_markdown_char
 from magic_pdf.libs.ocr_content_type import ContentType
 import wordninja
@@ -72,7 +73,7 @@ def ocr_mk_mm_markdown_with_para(pdf_info_dict: dict):
     markdown = []
     for _, page_info in pdf_info_dict.items():
         paras_of_layout = page_info.get("para_blocks")
-        page_markdown = ocr_mk_mm_markdown_with_para_core(paras_of_layout, "mm")
+        page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "mm")
         markdown.extend(page_markdown)
     return '\n\n'.join(markdown)
 
@@ -81,7 +82,7 @@ def ocr_mk_nlp_markdown_with_para(pdf_info_dict: dict):
     markdown = []
     for _, page_info in pdf_info_dict.items():
         paras_of_layout = page_info.get("para_blocks")
-        page_markdown = ocr_mk_mm_markdown_with_para_core(paras_of_layout, "nlp")
+        page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "nlp")
         markdown.extend(page_markdown)
     return '\n\n'.join(markdown)
 
@@ -91,7 +92,7 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: dict):
         paras_of_layout = page_info.get("para_blocks")
         if not paras_of_layout:
             continue
-        page_markdown = ocr_mk_mm_markdown_with_para_core(paras_of_layout, "mm")
+        page_markdown = ocr_mk_markdown_with_para_core(paras_of_layout, "mm")
         markdown_with_para_and_pagination.append({
             'page_no': page_no,
             'md_content': '\n\n'.join(page_markdown)
@@ -99,7 +100,7 @@ def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: dict):
     return markdown_with_para_and_pagination
 
 
-def ocr_mk_mm_markdown_with_para_core(paras_of_layout, mode):
+def ocr_mk_markdown_with_para_core(paras_of_layout, mode):
     page_markdown = []
     for paras in paras_of_layout:
         for para in paras:
@@ -108,19 +109,28 @@ def ocr_mk_mm_markdown_with_para_core(paras_of_layout, mode):
                 for span in line['spans']:
                     span_type = span.get('type')
                     content = ''
+                    language = ''
                     if span_type == ContentType.Text:
-                        content = ocr_escape_special_markdown_char(split_long_words(span['content']))
+                        content = span['content']
+                        language = detect_lang(content)
+                        if language == 'en':  # 只对英文长词进行分词处理,中文分词会丢失文本
+                            content = ocr_escape_special_markdown_char(split_long_words(content))
+                        else:
+                            content = ocr_escape_special_markdown_char(content)
                     elif span_type == ContentType.InlineEquation:
-                        content = f"${ocr_escape_special_markdown_char(span['content'])}$"
+                        content = f"${span['content']}$"
                     elif span_type == ContentType.InterlineEquation:
-                        content = f"\n$$\n{ocr_escape_special_markdown_char(span['content'])}\n$$\n"
+                        content = f"\n$$\n{span['content']}\n$$\n"
                     elif span_type in [ContentType.Image, ContentType.Table]:
                         if mode == 'mm':
                             content = f"\n![]({join_path(s3_image_save_path, span['image_path'])})\n"
                         elif mode == 'nlp':
                             pass
                     if content != '':
-                        para_text += content + ' '
+                        if language == 'en':  # 英文语境下 content间需要空格分隔
+                            para_text += content + ' '
+                        else:  # 中文语境下,content间不需要空格分隔
+                            para_text += content
             if para_text.strip() == '':
                 continue
             else:
@@ -137,13 +147,23 @@ def para_to_standard_format(para):
         inline_equation_num = 0
         for line in para:
             for span in line['spans']:
+                language = ''
                 span_type = span.get('type')
                 if span_type == ContentType.Text:
-                    content = ocr_escape_special_markdown_char(split_long_words(span['content']))
+                    content = span['content']
+                    language = detect_lang(content)
+                    if language == 'en':  # 只对英文长词进行分词处理,中文分词会丢失文本
+                        content = ocr_escape_special_markdown_char(split_long_words(content))
+                    else:
+                        content = ocr_escape_special_markdown_char(content)
                 elif span_type == ContentType.InlineEquation:
-                    content = f"${ocr_escape_special_markdown_char(span['content'])}$"
+                    content = f"${span['content']}$"
                     inline_equation_num += 1
-                para_text += content + ' '
+
+                if language == 'en':  # 英文语境下 content间需要空格分隔
+                    para_text += content + ' '
+                else:  # 中文语境下,content间不需要空格分隔
+                    para_text += content
         para_content = {
             'type': 'text',
             'text': para_text,
@@ -186,14 +206,14 @@ def line_to_standard_format(line):
                     return content
         else:
             if span['type'] == ContentType.InterlineEquation:
-                interline_equation = ocr_escape_special_markdown_char(span['content'])  # 转义特殊符号
+                interline_equation = span['content']
                 content = {
                     'type': 'equation',
                     'latex': f"$$\n{interline_equation}\n$$"
                 }
                 return content
             elif span['type'] == ContentType.InlineEquation:
-                inline_equation = ocr_escape_special_markdown_char(span['content'])  # 转义特殊符号
+                inline_equation = span['content']
                 line_text += f"${inline_equation}$"
                 inline_equation_num += 1
             elif span['type'] == ContentType.Text:

+ 32 - 5
magic_pdf/libs/config_reader.py

@@ -1,14 +1,41 @@
-
-
 """
 根据bucket的名字返回对应的s3 AK, SK,endpoint三元组
 
 """
+import json
+import os
+
+from loguru import logger
+
 
 def get_s3_config(bucket_name: str):
     """
     ~/magic-pdf.json 读出来
     """
-    ak , sk, endpoint = "", "", ""
-    # TODO 请实现这个函数
-    return ak, sk, endpoint
+
+    home_dir = os.path.expanduser("~")
+
+    config_file = os.path.join(home_dir, "magic-pdf.json")
+
+    if not os.path.exists(config_file):
+        raise Exception("magic-pdf.json not found")
+
+    with open(config_file, "r") as f:
+        config = json.load(f)
+
+    bucket_info = config.get("bucket_info")
+    if bucket_name not in bucket_info:
+        raise Exception("bucket_name not found in magic-pdf.json")
+
+    access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
+
+    if access_key is None or secret_key is None or storage_endpoint is None:
+        raise Exception("ak, sk or endpoint not found in magic-pdf.json")
+
+    # logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
+
+    return access_key, secret_key, storage_endpoint
+
+
+if __name__ == '__main__':
+    ak, sk, endpoint = get_s3_config("llm-raw")

+ 1 - 1
magic_pdf/pdf_parse_by_model.py → magic_pdf/pdf_parse_by_txt.py

@@ -70,7 +70,7 @@ paraMergeException_msg = ParaMergeException().message
 
 
 
-def parse_pdf_by_model(
+def parse_pdf_by_txt(
     pdf_bytes,
     pdf_model_output,
     save_path,

+ 6 - 4
magic_pdf/pipeline.py

@@ -13,7 +13,7 @@ from magic_pdf.libs.commons import (
 from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.json_compressor import JsonCompressor
 from magic_pdf.dict2md.mkcontent import mk_universal_format
-from magic_pdf.pdf_parse_by_model import parse_pdf_by_model
+from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
 from magic_pdf.filter.pdf_classify_by_type import classify
 from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
 from loguru import logger
@@ -130,6 +130,7 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
             classify_time = int(time.time() - start_time)  # 计算执行时间
             if is_text_pdf:
                 pdf_meta["is_text_pdf"] = is_text_pdf
+                jso["_pdf_type"] = "TXT"
                 jso["pdf_meta"] = pdf_meta
                 jso["classify_time"] = classify_time
                 # print(json.dumps(pdf_meta, ensure_ascii=False))
@@ -144,10 +145,11 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
             else:
                 # 先不drop
                 pdf_meta["is_text_pdf"] = is_text_pdf
+                jso["_pdf_type"] = "OCR"
                 jso["pdf_meta"] = pdf_meta
                 jso["classify_time"] = classify_time
-                jso["need_drop"] = True
-                jso["drop_reason"] = DropReason.NOT_IS_TEXT_PDF
+                # jso["need_drop"] = True
+                # jso["drop_reason"] = DropReason.NOT_IS_TEXT_PDF
                 extra_info = {"classify_rules": []}
                 for condition, result in results.items():
                     if not result:
@@ -310,7 +312,7 @@ def parse_pdf(jso: dict, start_page_id=0, debug_mode=False) -> dict:
                 f"book_name is:{book_name},start_time is:{formatted_time(start_time)}",
                 file=sys.stderr,
             )
-            pdf_info_dict = parse_pdf_by_model(
+            pdf_info_dict = parse_pdf_by_txt(
                 pdf_bytes,
                 model_output_json_list,
                 save_path,

+ 51 - 0
utils/config_init_to_json.py

@@ -0,0 +1,51 @@
+from loguru import logger
+import json
+import os
+from magic_pdf.config import s3_buckets, s3_clusters, s3_users
+
+
+def get_bucket_configs_dict(buckets, clusters, users):
+    bucket_configs = {}
+    for s3_bucket in buckets.items():
+        bucket_name = s3_bucket[0]
+        bucket_config = s3_bucket[1]
+        cluster, user = bucket_config
+        cluster_config = clusters[cluster]
+        endpoint_key = "outside"
+        endpoints = cluster_config[endpoint_key]
+        endpoint = endpoints[0]
+        user_config = users[user]
+        # logger.info(bucket_name)
+        # logger.info(endpoint)
+        # logger.info(user_config)
+        bucket_config = [user_config["ak"], user_config["sk"], endpoint]
+        bucket_configs[bucket_name] = bucket_config
+
+    return bucket_configs
+
+
+def write_json_to_home(my_dict):
+    # Convert dictionary to JSON
+    json_data = json.dumps(my_dict, indent=4, ensure_ascii=False)
+
+    home_dir = os.path.expanduser("~")
+
+    # Define the output file path
+    output_file = os.path.join(home_dir, "magic-pdf.json")
+
+    # Write JSON data to the output file
+    with open(output_file, "w") as f:
+        f.write(json_data)
+
+    # Print a success message
+    print(f"Dictionary converted to JSON and saved to {output_file}")
+
+
+if __name__ == '__main__':
+    bucket_configs_dict = get_bucket_configs_dict(s3_buckets, s3_clusters, s3_users)
+    logger.info(bucket_configs_dict)
+    config_dict = {
+        "bucket_info": bucket_configs_dict,
+        "temp-output-dir": "/tmp"
+    }
+    write_json_to_home(config_dict)