
Update libs/config_reader, delete spark/s3.py
Move pipeline_ocr.py, pipeline_txt.py, and pipeline.py to code_clean and fix some dependencies

赵小蒙 1 year ago
parent
commit
c81f699e68

+ 2 - 2
demo/demo_commons.py

@@ -1,6 +1,6 @@
 import json
 
-from magic_pdf.spark.s3 import get_s3_config
+from magic_pdf.libs.config_reader import get_s3_config_dict
 from magic_pdf.libs.commons import join_path, read_file, json_dump_path
 
 
@@ -16,7 +16,7 @@ def get_json_from_local_or_s3(book_name=None):
         # error_log_path & json_dump_path
         # the source json can be fetched from either of the two paths above
         json_path = join_path(json_dump_path, book_name + ".json")
-        s3_config = get_s3_config(json_path)
+        s3_config = get_s3_config_dict(json_path)
         file_content = read_file(json_path, s3_config)
         json_str = file_content.decode("utf-8")
         # logger.info(json_str)
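
For reference, a minimal sketch of the replacement helper's return shape; the path and values below are placeholders, not from this commit:

s3_config = get_s3_config_dict("s3://llm-pdf-text/unittest/json/book.json")  # placeholder path
# -> {"ak": "<access_key>", "sk": "<secret_key>", "endpoint": "<storage_endpoint>"}
file_content = read_file("s3://llm-pdf-text/unittest/json/book.json", s3_config)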

+ 33 - 4
demo/ocr_demo.py

@@ -1,17 +1,19 @@
 import json
 import os
+import sys
+import time
 
 from loguru import logger
 from pathlib import Path
 
-from magic_pdf.pipeline_ocr import ocr_parse_pdf_core
-from magic_pdf.spark.s3 import get_s3_config
+from magic_pdf.libs.config_reader import get_s3_config_dict
+from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
 from demo.demo_commons import get_json_from_local_or_s3
 from magic_pdf.dict2md.ocr_mkcontent import (
     ocr_mk_mm_markdown_with_para,
     make_standard_format_with_para
 )
-from magic_pdf.libs.commons import join_path, read_file
+from magic_pdf.libs.commons import join_path, read_file, formatted_time
 
 
 def save_markdown(markdown_text, input_filepath):
@@ -50,7 +52,7 @@ def ocr_online_parse(book_name, start_page_id=0, debug_mode=True):
         json_object = get_json_from_local_or_s3(book_name)
         # logger.info(json_object)
         s3_pdf_path = json_object["file_location"]
-        s3_config = get_s3_config(s3_pdf_path)
+        s3_config = get_s3_config_dict(s3_pdf_path)
         pdf_bytes = read_file(s3_pdf_path, s3_config)
         ocr_pdf_model_info = json_object.get("doc_layout_result")
         ocr_parse_core(book_name, pdf_bytes, ocr_pdf_model_info)
@@ -83,6 +85,33 @@ def ocr_parse_core(book_name, pdf_bytes, ocr_pdf_model_info, start_page_id=0):
         f.write(json.dumps(standard_format, ensure_ascii=False))
 
 
+def ocr_parse_pdf_core(pdf_bytes, model_output_json_list, book_name, start_page_id=0, debug_mode=False):
+    start_time = time.time()  # record the start time
+    # log the book_name and the parse start time first
+    logger.info(
+        f"book_name is:{book_name},start_time is:{formatted_time(start_time)}",
+        file=sys.stderr,
+    )
+    pdf_info_dict = parse_pdf_by_ocr(
+        pdf_bytes,
+        model_output_json_list,
+        "",
+        book_name,
+        pdf_model_profile=None,
+        start_page_id=start_page_id,
+        debug_mode=debug_mode,
+    )
+    end_time = time.time()  # record the end time
+    parse_time = int(end_time - start_time)  # compute the elapsed seconds
+    # after parsing, log the book_name and the time cost
+    logger.info(
+        f"book_name is:{book_name},end_time is:{formatted_time(end_time)},cost_time is:{parse_time}",
+        file=sys.stderr,
+    )
+
+    return pdf_info_dict, parse_time
+
+
 if __name__ == '__main__':
     pdf_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.pdf"
     json_file_path = r"/home/cxu/workspace/Magic-PDF/ocr_demo/j.1540-627x.2006.00176.x.json"
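
A minimal usage sketch of the new ocr_parse_pdf_core, reusing the local paths above; that the JSON file holds the per-page model output list is an assumption:

with open(pdf_path, "rb") as f:
    pdf_bytes = f.read()
with open(json_file_path, "r", encoding="utf-8") as f:
    model_output_json_list = json.load(f)  # assumed: per-page layout results
pdf_info_dict, parse_time = ocr_parse_pdf_core(
    pdf_bytes, model_output_json_list, "j.1540-627x.2006.00176.x", debug_mode=True
)
logger.info(f"cost_time is:{parse_time}")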

+ 63 - 55
demo/text_demo.py

@@ -6,84 +6,93 @@ from pathlib import Path
 import click
 
 from demo.demo_commons import get_json_from_local_or_s3, write_json_to_local, local_jsonl_path, local_json_path
-from magic_pdf.dict2md.mkcontent import mk_mm_markdown
-from magic_pdf.pipeline import (
-    meta_scan,
-    classify_by_type,
-    parse_pdf,
-    pdf_intermediate_dict_to_markdown,
-    save_tables_to_s3,
-)
-from magic_pdf.libs.commons import join_path
-from loguru import logger
+from magic_pdf.dict2md.mkcontent import mk_mm_markdown, mk_universal_format
+from magic_pdf.filter.pdf_classify_by_type import classify
 
+from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
+from magic_pdf.libs.commons import join_path, read_file
+from loguru import logger
 
+from magic_pdf.libs.config_reader import get_s3_config_dict
+from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
+from magic_pdf.spark.base import get_data_source
 
 
 def demo_parse_pdf(book_name=None, start_page_id=0, debug_mode=True):
     json_object = get_json_from_local_or_s3(book_name)
 
-    jso = parse_pdf(json_object, start_page_id=start_page_id, debug_mode=debug_mode)
-    logger.info(f"pdf_parse_time: {jso['parse_time']}")
-
-    write_json_to_local(jso, book_name)
-
-    jso_md = pdf_intermediate_dict_to_markdown(jso, debug_mode=debug_mode)
-    content = jso_md.get("content_list")
-    markdown_content = mk_mm_markdown(content)
+    s3_pdf_path = json_object.get("file_location")
+    s3_config = get_s3_config_dict(s3_pdf_path)
+    pdf_bytes = read_file(s3_pdf_path, s3_config)
+    model_output_json_list = json_object.get("doc_layout_result")
+    data_source = get_data_source(json_object)
+    file_id = json_object.get("file_id")
+    junk_img_bojids = json_object["pdf_meta"]["junk_img_bojids"]
+    save_path = ""
+    pdf_info_dict = parse_pdf_by_txt(
+        pdf_bytes,
+        model_output_json_list,
+        save_path,
+        f"{data_source}/{file_id}",
+        pdf_model_profile=None,
+        start_page_id=start_page_id,
+        junk_img_bojids=junk_img_bojids,
+        debug_mode=debug_mode,
+    )
+
+    write_json_to_local(pdf_info_dict, book_name)
+    content_list = mk_universal_format(pdf_info_dict)
+    markdown_content = mk_mm_markdown(content_list)
     if book_name is not None:
         save_tmp_path = os.path.join(os.path.dirname(__file__), "../..", "tmp", "unittest", "md", book_name)
-        uni_format_save_path = join_path(save_tmp_path,  "book" + ".json")
-        markdown_save_path = join_path(save_tmp_path,  "book" + ".md")
+        uni_format_save_path = join_path(save_tmp_path, "book" + ".json")
+        markdown_save_path = join_path(save_tmp_path, "book" + ".md")
         with open(uni_format_save_path, "w", encoding="utf-8") as f:
-            f.write(json.dumps(content, ensure_ascii=False, indent=4))
+            f.write(json.dumps(content_list, ensure_ascii=False, indent=4))
         with open(markdown_save_path, "w", encoding="utf-8") as f:
             f.write(markdown_content)
-            
-    else:
-        logger.info(json.dumps(content, ensure_ascii=False))
-
-
-def demo_save_tables(book_name=None, start_page_id=0, debug_mode=True):
-    json_object = get_json_from_local_or_s3(book_name)
 
-    jso = parse_pdf(json_object, start_page_id=start_page_id, debug_mode=debug_mode)
-    logger.info(f"pdf_parse_time: {jso['parse_time']}")
-
-    write_json_to_local(jso, book_name)
-
-    save_tables_to_s3(jso, debug_mode=debug_mode)
+    else:
+        logger.info(json.dumps(content_list, ensure_ascii=False))
 
 
 def demo_classify_by_type(book_name=None, debug_mode=True):
     json_object = get_json_from_local_or_s3(book_name)
 
-    jso = classify_by_type(json_object, debug_mode=debug_mode)
-
-    logger.info(json.dumps(jso, ensure_ascii=False))
-    logger.info(f"classify_time: {jso['classify_time']}")
-    write_json_to_local(jso, book_name)
+    pdf_meta = json_object.get("pdf_meta")
+    total_page = pdf_meta["total_page"]
+    page_width = pdf_meta["page_width_pts"]
+    page_height = pdf_meta["page_height_pts"]
+    img_sz_list = pdf_meta["image_info_per_page"]
+    img_num_list = pdf_meta["imgs_per_page"]
+    text_len_list = pdf_meta["text_len_per_page"]
+    text_layout_list = pdf_meta["text_layout_per_page"]
+    pdf_path = json_object.get("file_location")
+    is_text_pdf, results = classify(
+        pdf_path,
+        total_page,
+        page_width,
+        page_height,
+        img_sz_list,
+        text_len_list,
+        img_num_list,
+        text_layout_list,
+    )
+    logger.info(f"is_text_pdf: {is_text_pdf}")
+    logger.info(json.dumps(results, ensure_ascii=False))
+    write_json_to_local(results, book_name)
 
 
 def demo_meta_scan(book_name=None, debug_mode=True):
     json_object = get_json_from_local_or_s3(book_name)
 
-    # doc_layout_check=False
-    jso = meta_scan(json_object, doc_layout_check=True)
+    s3_pdf_path = json_object.get("file_location")
+    s3_config = get_s3_config_dict(s3_pdf_path)
+    pdf_bytes = read_file(s3_pdf_path, s3_config)
+    res = pdf_meta_scan(s3_pdf_path, pdf_bytes)
 
-    logger.info(json.dumps(jso, ensure_ascii=False))
-    logger.info(f"meta_scan_time: {jso['meta_scan_time']}")
-    write_json_to_local(jso, book_name)
-
-
-def demo_meta_scan_from_jsonl():
-    with open(local_jsonl_path, "r", encoding="utf-8") as jsonl_file:
-        for line in jsonl_file:
-            jso = json.loads(line)
-            jso = meta_scan(jso)
-            logger.info(f"pdf_path: {jso['content']['pdf_path']}")
-            logger.info(f"read_file_time: {jso['read_file_time']}")
-            logger.info(f"meta_scan_time: {jso['meta_scan_time']}")
+    logger.info(json.dumps(res, ensure_ascii=False))
+    write_json_to_local(res, book_name)
 
 
 def demo_test5():
@@ -94,7 +103,6 @@ def demo_test5():
     logger.info(f"img_list_len: {img_list_len}")
 
 
-
 def read_more_para_test_samples(type="scihub"):
     # read the multi-paragraph test samples
     curr_dir = Path(__file__).parent

+ 15 - 3
magic_pdf/libs/config_reader.py

@@ -7,6 +7,8 @@ import os
 
 from loguru import logger
 
+from magic_pdf.libs.commons import parse_bucket_key
+
 
 def get_s3_config(bucket_name: str):
     """
@@ -25,9 +27,9 @@ def get_s3_config(bucket_name: str):
 
     bucket_info = config.get("bucket_info")
     if bucket_name not in bucket_info:
-        raise Exception("bucket_name not found in magic-pdf.json")
-
-    access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
+        access_key, secret_key, storage_endpoint = bucket_info["[default]"]
+    else:
+        access_key, secret_key, storage_endpoint = bucket_info[bucket_name]
 
     if access_key is None or secret_key is None or storage_endpoint is None:
         raise Exception("ak, sk or endpoint not found in magic-pdf.json")
@@ -37,5 +39,15 @@ def get_s3_config(bucket_name: str):
     return access_key, secret_key, storage_endpoint
 
 
+def get_s3_config_dict(path: str):
+    access_key, secret_key, storage_endpoint = get_s3_config(get_bucket_name(path))
+    return {"ak": access_key, "sk": secret_key, "endpoint": storage_endpoint}
+
+
+def get_bucket_name(path):
+    bucket, key = parse_bucket_key(path)
+    return bucket
+
+
 if __name__ == '__main__':
     ak, sk, endpoint = get_s3_config("llm-raw")
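
A hedged sketch of the new fallback behavior: a bucket missing from bucket_info in magic-pdf.json now resolves to the "[default]" credentials instead of raising. The bucket name below is made up:

# unknown bucket -> falls back to bucket_info["[default]"] (bucket name is a placeholder)
ak, sk, endpoint = get_s3_config("some-unlisted-bucket")
# the path-based variant added in this commit returns a dict instead of a tuple
s3_config = get_s3_config_dict("s3://some-unlisted-bucket/a/b.json")
# -> {"ak": ak, "sk": sk, "endpoint": endpoint}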

+ 1 - 81
magic_pdf/pipeline.py → magic_pdf/pipeline.bak

@@ -21,7 +21,7 @@ from loguru import logger
 from magic_pdf.pdf_parse_for_train import parse_pdf_for_train
 from magic_pdf.spark.base import exception_handler, get_data_source
 from magic_pdf.train_utils.convert_to_train_format import convert_to_train_format
-from magic_pdf.spark.s3 import get_s3_config, get_s3_client
+from magic_pdf.spark.s3 import get_s3_config
 
 
 
@@ -161,86 +161,6 @@ def classify_by_type(jso: dict, debug_mode=False) -> dict:
     return jso
 
 
-def save_tables_to_s3(jso: dict, debug_mode=False) -> dict:
-
-    if debug_mode:
-        pass
-    else:  # if debug mode is off, check whether the need_drop field is set
-        if jso.get("need_drop", False):
-            logger.info(
-                f"book_name is:{get_data_source(jso)}/{jso['file_id']} need drop",
-                file=sys.stderr,
-            )
-            jso["dropped"] = True
-            return jso
-    try:
-        data_source = get_data_source(jso)
-        file_id = jso.get("file_id")
-        book_name = f"{data_source}/{file_id}"
-        title = jso.get("title")
-        url_encode_title = quote(title, safe="")
-        if data_source != "scihub":
-            return jso
-        pdf_intermediate_dict = jso["pdf_intermediate_dict"]
-        # decompress pdf_intermediate_dict
-        pdf_intermediate_dict = JsonCompressor.decompress_json(pdf_intermediate_dict)
-        i = 0
-        for page in pdf_intermediate_dict.values():
-            if page.get("tables"):
-                if len(page["tables"]) > 0:
-                    j = 0
-                    for table in page["tables"]:
-                        if debug_mode:
-                            image_path = join_path(
-                                "s3://mllm-raw-media/pdf2md_img/",
-                                book_name,
-                                table["image_path"],
-                            )
-                        else:
-                            image_path = join_path(
-                                "s3://mllm-raw-media/pdf2md_img/", table["image_path"]
-                            )
-
-                        if image_path.endswith(".jpg"):
-                            j += 1
-                            s3_client = get_s3_client(image_path)
-                            bucket_name, bucket_key = parse_bucket_key(image_path)
-                            # fetch the image into memory via s3_client
-                            image_bytes = s3_client.get_object(
-                                Bucket=bucket_name, Key=bucket_key
-                            )["Body"].read()
-                            # save the image to a new location
-                            if debug_mode:
-                                new_image_path = join_path(
-                                    "s3://mllm-raw-media/pdf2md_img/table_new/",
-                                    url_encode_title
-                                    + "_"
-                                    + table["image_path"].lstrip("tables/"),
-                                )
-                            else:
-                                new_image_path = join_path(
-                                    "s3://mllm-raw-media/pdf2md_img/table_new/",
-                                    url_encode_title + f"_page{i}_{j}.jpg",
-                                )
-
-                            logger.info(new_image_path, file=sys.stderr)
-                            bucket_name, bucket_key = parse_bucket_key(new_image_path)
-                            s3_client.put_object(
-                                Bucket=bucket_name, Key=bucket_key, Body=image_bytes
-                            )
-                        else:
-                            continue
-            i += 1
-
-        # clear fields that are no longer needed
-        jso["doc_layout_result"] = ""
-        jso["pdf_intermediate_dict"] = ""
-        jso["pdf_meta"] = ""
-    except Exception as e:
-        jso = exception_handler(jso, e)
-    return jso
-
-
 def drop_needdrop_pdf(jso: dict) -> dict:
     if jso.get("need_drop", False):
         logger.info(

+ 0 - 0
magic_pdf/pipeline_ocr.py → magic_pdf/pipeline_ocr.bak


+ 0 - 0
magic_pdf/pipeline_txt.py → magic_pdf/pipeline_txt.bak


+ 0 - 11
magic_pdf/spark/base.py

@@ -1,11 +1,7 @@
-
 from loguru import logger
 
-from magic_pdf.libs.commons import read_file
 from magic_pdf.libs.drop_reason import DropReason
 
-from magic_pdf.spark.s3 import get_s3_config
-
 
 def get_data_source(jso: dict):
     data_source = jso.get("data_source")
@@ -41,10 +37,3 @@ def get_bookname(jso: dict):
     file_id = jso.get("file_id")
     book_name = f"{data_source}/{file_id}"
     return book_name
-
-
-def get_pdf_bytes(jso: dict):
-    pdf_s3_path = jso.get("file_location")
-    s3_config = get_s3_config(pdf_s3_path)
-    pdf_bytes = read_file(pdf_s3_path, s3_config)
-    return pdf_bytes

+ 17 - 0
magic_pdf/spark/s3.bak

@@ -0,0 +1,17 @@
+import re
+from magic_pdf.libs.config_reader import get_s3_config_dict
+
+__re_s3_path = re.compile("^s3a?://([^/]+)(?:/(.*))?$")
+
+
+def get_s3_config(path):
+    bucket_name = split_s3_path(path)[0] if path else ""
+    return get_s3_config_dict(bucket_name)
+
+
+def split_s3_path(path: str):
+    "split bucket and key from path"
+    m = __re_s3_path.match(path)
+    if m is None:
+        return "", ""
+    return m.group(1), (m.group(2) or "")
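
A few quick checks of split_s3_path against the regex above; the paths are made up:

# bucket/key split with a key, without a key, and for a non-s3 path
assert split_s3_path("s3://my-bucket/a/b.pdf") == ("my-bucket", "a/b.pdf")
assert split_s3_path("s3a://my-bucket") == ("my-bucket", "")
assert split_s3_path("not-an-s3-path") == ("", "")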

+ 0 - 86
magic_pdf/spark/s3.py

@@ -1,86 +0,0 @@
-# from app.common import s3
-import boto3
-from botocore.client import Config
-
-import re
-import random
-from typing import List, Union
-try:
-    from app.config import s3_buckets, s3_clusters, s3_users # TODO delete: circular dependency
-    from app.common.runtime import get_cluster_name
-except ImportError:
-    from magic_pdf.config import s3_buckets, s3_clusters, get_cluster_name, s3_users
-
-__re_s3_path = re.compile("^s3a?://([^/]+)(?:/(.*))?$")
-def get_s3_config(path: Union[str, List[str]], outside=False):
-    paths = [path] if type(path) == str else path
-    bucket_config = None
-    for p in paths:
-        bc = __get_s3_bucket_config(p)
-        if bucket_config in [bc, None]:
-            bucket_config = bc
-            continue
-        raise Exception(f"{paths} have different s3 config, cannot read together.")
-    if not bucket_config:
-        raise Exception("path is empty.")
-    return __get_s3_config(bucket_config, outside, prefer_ip=True)
-
-def __get_s3_config(
-    bucket_config: tuple,
-    outside: bool,
-    prefer_ip=False,
-    prefer_auto=False,
-):
-    cluster, user = bucket_config
-    cluster_config = s3_clusters[cluster]
-
-    if outside:
-        endpoint_key = "outside"
-    elif prefer_auto and "auto" in cluster_config:
-        endpoint_key = "auto"
-    elif cluster_config.get("cluster") == get_cluster_name():
-        endpoint_key = "inside"
-    else:
-        endpoint_key = "outside"
-
-    if prefer_ip and f"{endpoint_key}_ips" in cluster_config:
-        endpoint_key = f"{endpoint_key}_ips"
-
-    endpoints = cluster_config[endpoint_key]
-    endpoint = random.choice(endpoints)
-    return {"endpoint": endpoint, **s3_users[user]}
-
-def split_s3_path(path: str):
-    "split bucket and key from path"
-    m = __re_s3_path.match(path)
-    if m is None:
-        return "", ""
-    return m.group(1), (m.group(2) or "")
-
-def __get_s3_bucket_config(path: str):
-    bucket = split_s3_path(path)[0] if path else ""
-    bucket_config = s3_buckets.get(bucket)
-    if not bucket_config:
-        bucket_config = s3_buckets.get("[default]")
-        assert bucket_config is not None
-    return bucket_config
-
-def get_s3_client(path: Union[str, List[str]], outside=False):
-    s3_config = get_s3_config(path, outside)
-    try:
-        return boto3.client(
-            "s3",
-            aws_access_key_id=s3_config["ak"],
-            aws_secret_access_key=s3_config["sk"],
-            endpoint_url=s3_config["endpoint"],
-            config=Config(s3={"addressing_style": "path"}, retries={"max_attempts": 8, "mode": "standard"}),
-        )
-    except:
-        # older boto3 do not support retries.mode param.
-        return boto3.client(
-            "s3",
-            aws_access_key_id=s3_config["ak"],
-            aws_secret_access_key=s3_config["sk"],
-            endpoint_url=s3_config["endpoint"],
-            config=Config(s3={"addressing_style": "path"}, retries={"max_attempts": 8}),
-        )

+ 29 - 4
tests/test_commons.py

@@ -1,9 +1,13 @@
 import io
 import json
 import os
+
+import boto3
+from botocore.config import Config
+
 from magic_pdf.libs.commons import fitz
+from magic_pdf.libs.config_reader import get_s3_config_dict
 
-from magic_pdf.spark.s3 import get_s3_config, get_s3_client
 from magic_pdf.libs.commons import join_path, json_dump_path, read_file, parse_bucket_key
 from loguru import logger
 
@@ -12,7 +16,7 @@ test_pdf_dir_path = "s3://llm-pdf-text/unittest/pdf/"
 
 def get_test_pdf_json(book_name):
     json_path = join_path(json_dump_path, book_name + ".json")
-    s3_config = get_s3_config(json_path)
+    s3_config = get_s3_config_dict(json_path)
     file_content = read_file(json_path, s3_config)
     json_str = file_content.decode('utf-8')
     json_object = json.loads(json_str)
@@ -21,7 +25,7 @@ def get_test_pdf_json(book_name):
 
 def read_test_file(book_name):
     test_pdf_path = join_path(test_pdf_dir_path, book_name + ".pdf")
-    s3_config = get_s3_config(test_pdf_path)
+    s3_config = get_s3_config_dict(test_pdf_path)
     try:
         file_content = read_file(test_pdf_path, s3_config)
         return file_content
@@ -31,7 +35,7 @@ def read_test_file(book_name):
             try:
                 json_object = get_test_pdf_json(book_name)
                 orig_s3_pdf_path = json_object.get('file_location')
-                s3_config = get_s3_config(orig_s3_pdf_path)
+                s3_config = get_s3_config_dict(orig_s3_pdf_path)
                 file_content = read_file(orig_s3_pdf_path, s3_config)
                 s3_client = get_s3_client(test_pdf_path)
                 bucket_name, bucket_key = parse_bucket_key(test_pdf_path)
@@ -53,3 +57,24 @@ def get_test_json_data(directory_path, json_file_name):
     with open(os.path.join(directory_path, json_file_name), "r", encoding='utf-8') as f:
         test_data = json.load(f)
     return test_data
+
+
+def get_s3_client(path):
+    s3_config = get_s3_config_dict(path)
+    try:
+        return boto3.client(
+            "s3",
+            aws_access_key_id=s3_config["ak"],
+            aws_secret_access_key=s3_config["sk"],
+            endpoint_url=s3_config["endpoint"],
+            config=Config(s3={"addressing_style": "path"}, retries={"max_attempts": 8, "mode": "standard"}),
+        )
+    except:
+        # older boto3 do not support retries.mode param.
+        return boto3.client(
+            "s3",
+            aws_access_key_id=s3_config["ak"],
+            aws_secret_access_key=s3_config["sk"],
+            endpoint_url=s3_config["endpoint"],
+            config=Config(s3={"addressing_style": "path"}, retries={"max_attempts": 8}),
+        )
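
A minimal usage sketch of the relocated get_s3_client helper, reusing test_pdf_dir_path from above; the object key is a placeholder:

s3_client = get_s3_client(test_pdf_dir_path)
pdf_path = join_path(test_pdf_dir_path, "example.pdf")  # placeholder key
bucket_name, bucket_key = parse_bucket_key(pdf_path)
pdf_bytes = s3_client.get_object(Bucket=bucket_name, Key=bucket_key)["Body"].read()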