
feat(draw_bbox): add layout sorting visualization

Implement a new function `draw_layout_sort_bbox` in `draw_bbox.py` to visualize the
layout sorting results using the `LayoutLMv3ForTokenClassification` model. The function
predicts the reading order of layout elements and draws them in the sorted sequence on the
PDF pages. Also rename the misspelled `drow_model_bbox` to `draw_model_bbox` and add the
`magic_pdf/v3` layoutreader code (helpers, eval and train scripts, DeepSpeed config).
myhloli 1 year ago
parent commit 3cbcf2ded0
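
For reference, a minimal, self-contained sketch of driving the new `draw_layout_sort_bbox` end to end on a synthetic one-page document. Everything here is illustrative: the blank page, the fake `pdf_info` (carrying only the `para_blocks`/`bbox` and `page_size` keys the function reads), and the output names are not taken from the repository, and the first run will download the `hantian/layoutreader` weights.

    from magic_pdf.libs.commons import fitz  # PyMuPDF, same import used in draw_bbox.py
    from magic_pdf.libs.draw_bbox import draw_layout_sort_bbox

    # build a blank single-page PDF to draw on
    doc = fitz.open()
    doc.new_page(width=595, height=842)
    pdf_bytes = doc.tobytes()

    # a fake per-page structure with only the keys draw_layout_sort_bbox reads
    pdf_info = [{
        "page_size": [595, 842],
        "para_blocks": [
            {"bbox": [50, 60, 545, 120]},    # title band
            {"bbox": [50, 150, 290, 780]},   # left column
            {"bbox": [305, 150, 545, 780]},  # right column
        ],
    }]

    draw_layout_sort_bbox(pdf_info, pdf_bytes, ".", "demo")
    # expected result: ./demo_layout_sort.pdf with the three boxes numbered in reading order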

+ 59 - 5
magic_pdf/libs/draw_bbox.py

@@ -1,3 +1,5 @@
+import time
+
 from magic_pdf.libs.commons import fitz  # PyMuPDF
 from magic_pdf.libs.Constants import CROSS_PAGE
 from magic_pdf.libs.ocr_content_type import BlockType, CategoryId, ContentType
@@ -211,9 +213,9 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
         # build the remaining useful_list entries
         for block in page['para_blocks']:
             if block['type'] in [
-                    BlockType.Text,
-                    BlockType.Title,
-                    BlockType.InterlineEquation,
+                BlockType.Text,
+                BlockType.Title,
+                BlockType.InterlineEquation,
             ]:
                 for line in block['lines']:
                     for span in line['spans']:
@@ -244,7 +246,7 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
     pdf_docs.save(f'{out_path}/{filename}_spans.pdf')
 
 
-def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
+def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
     dropped_bbox_list = []
     tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
     imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
@@ -279,7 +281,7 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
             elif layout_det['category_id'] == CategoryId.ImageCaption:
                 imgs_caption.append(bbox)
             elif layout_det[
-                    'category_id'] == CategoryId.InterlineEquation_YOLO:
+                'category_id'] == CategoryId.InterlineEquation_YOLO:
                 interequations.append(bbox)
             elif layout_det['category_id'] == CategoryId.Abandon:
                 page_dropped_list.append(bbox)
@@ -316,3 +318,55 @@ def drow_model_bbox(model_list: list, pdf_bytes, out_path, filename):
 
     # Save the PDF
     pdf_docs.save(f'{out_path}/{filename}_model.pdf')
+
+
+from typing import List
+
+
+def do_predict(boxes: List[List[int]]) -> List[int]:
+    from transformers import LayoutLMv3ForTokenClassification
+    from magic_pdf.v3.helpers import prepare_inputs, boxes2inputs, parse_logits
+    model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
+    inputs = boxes2inputs(boxes)
+    inputs = prepare_inputs(inputs, model)
+    logits = model(**inputs).logits.cpu().squeeze(0)
+    return parse_logits(logits, len(boxes))
+
+
+def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
+    layout_bbox_list = []
+
+    from loguru import logger
+    for page in pdf_info:
+        page_layout_list = []
+        for block in page['para_blocks']:
+            bbox = block['bbox']
+            page_layout_list.append(bbox)
+
+        # sort with layoutreader
+        page_size = page['page_size']
+        x_scale = 1000.0 / page_size[0]
+        y_scale = 1000.0 / page_size[1]
+        boxes = []
+        logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_layout_list)}")
+        for left, top, right, bottom in page_layout_list:
+            left = round(left * x_scale)
+            top = round(top * y_scale)
+            right = round(right * x_scale)
+            bottom = round(bottom * y_scale)
+            assert (
+                    1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
+            ), f"Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}"
+            boxes.append([left, top, right, bottom])
+        logger.info("layoutreader start")
+        start = time.time()
+        orders = do_predict(boxes)
+        logger.debug(f"orders: {orders}")
+        logger.info(f"layoutreader end, cost time: {time.time() - start}")
+        sorted_bboxes = [page_layout_list[i] for i in orders]
+        layout_bbox_list.append(sorted_bboxes)
+    pdf_docs = fitz.open('pdf', pdf_bytes)
+    for i, page in enumerate(pdf_docs):
+        draw_bbox_with_number(i, layout_bbox_list, page, [102, 102, 255], False)
+
+    pdf_docs.save(f'{out_path}/{filename}_layout_sort.pdf')
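
A smaller sketch of the prediction step in isolation, calling `do_predict` with hand-made boxes already scaled into layoutreader's 0-1000 coordinate space (the boxes and the expected outcome are illustrative, not repository test data):

    from magic_pdf.libs.draw_bbox import do_predict

    boxes = [
        [100, 150, 480, 800],   # left column body
        [520, 150, 900, 800],   # right column body
        [100, 50, 900, 120],    # page title, deliberately listed last
    ]
    orders = do_predict(boxes)
    print(orders)
    # a permutation of [0, 1, 2]; as used in draw_layout_sort_bbox above, orders[k] is the
    # index of the box drawn in position k, so a sensible prediction starts with 2 (the title)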

+ 4 - 2
magic_pdf/tools/common.py

@@ -7,7 +7,7 @@ from loguru import logger
 
 import magic_pdf.model as model_config
 from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_span_bbox,
-                                      drow_model_bbox)
+                                      draw_model_bbox, draw_layout_sort_bbox)
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
@@ -90,7 +90,9 @@ def do_parse(
     if f_draw_span_bbox:
         draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
     if f_draw_model_bbox:
-        drow_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name)
+        draw_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name)
+
+    draw_layout_sort_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
 
     md_content = pipe.pipe_mk_markdown(image_dir,
                                        drop_mode=DropMode.NONE,

+ 46 - 0
magic_pdf/v3/ds_config.json

@@ -0,0 +1,46 @@
+{
+  "fp16": {
+    "enabled": "auto",
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "initial_scale_power": 16,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "bf16": {
+    "enabled": "auto"
+  },
+  "optimizer": {
+    "type": "AdamW",
+    "params": {
+      "lr": "auto",
+      "betas": "auto",
+      "eps": "auto",
+      "weight_decay": "auto"
+    }
+  },
+  "scheduler": {
+    "type": "WarmupDecayLR",
+    "params": {
+      "warmup_min_lr": "auto",
+      "warmup_max_lr": "auto",
+      "warmup_num_steps": "auto",
+      "total_num_steps": "auto"
+    }
+  },
+  "zero_optimization": {
+    "stage": 2,
+    "allgather_partitions": true,
+    "allgather_bucket_size": 2e8,
+    "overlap_comm": true,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 2e8,
+    "contiguous_gradients": true
+  },
+  "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
+  "steps_per_print": 2000,
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "wall_clock_breakdown": false
+}
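
A brief note on how this config is consumed: every "auto" entry is resolved by the Hugging Face Trainer from the `TrainingArguments` at launch (the `--deepspeed ds_config.json` flag in `train.sh` below does exactly this). A hedged sketch of the same wiring in Python, with illustrative values, assuming `deepspeed` is installed:

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="./checkpoint",                # placeholder
        deepspeed="magic_pdf/v3/ds_config.json",  # this file
        per_device_train_batch_size=32,           # -> train_micro_batch_size_per_gpu: "auto"
        learning_rate=5e-5,                       # -> optimizer.params.lr: "auto"
        warmup_steps=1000,                        # -> scheduler.params.warmup_num_steps: "auto"
    )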

+ 86 - 0
magic_pdf/v3/eval.py

@@ -0,0 +1,86 @@
+import gzip
+import json
+
+import torch
+import typer
+from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
+from tqdm import tqdm
+from transformers import LayoutLMv3ForTokenClassification
+
+from helpers import (
+    DataCollator,
+    check_duplicate,
+    MAX_LEN,
+    parse_logits,
+    prepare_inputs,
+)
+
+app = typer.Typer()
+
+chen_cherry = SmoothingFunction()
+
+
+@app.command()
+def main(
+    input_file: str = typer.Argument(..., help="input file"),
+    model_path: str = typer.Argument(..., help="model path"),
+    batch_size: int = typer.Option(16, help="batch size"),
+):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = (
+        LayoutLMv3ForTokenClassification.from_pretrained(model_path, num_labels=MAX_LEN)
+        .bfloat16()
+        .to(device)
+        .eval()
+    )
+    data_collator = DataCollator()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    datasets = []
+    with gzip.open(input_file, "rt") as f:
+        for line in tqdm(f):
+            datasets.append(json.loads(line))
+    # make batch faster
+    datasets.sort(key=lambda x: len(x["source_boxes"]), reverse=True)
+
+    total = 0
+    total_out_idx = 0.0
+    total_out_token = 0.0
+    for i in tqdm(range(0, len(datasets), batch_size)):
+        batch = datasets[i : i + batch_size]
+        model_inputs = data_collator(batch)
+        model_inputs = prepare_inputs(model_inputs, model)
+        # forward
+        with torch.no_grad():
+            model_outputs = model(**model_inputs)
+        logits = model_outputs.logits.cpu()
+        for data, logit in zip(batch, logits):
+            target_index = data["target_index"][:MAX_LEN]
+            pred_index = parse_logits(logit, len(target_index))
+            assert len(pred_index) == len(target_index)
+            assert not check_duplicate(pred_index)
+            target_texts = data["target_texts"][:MAX_LEN]
+            source_texts = data["source_texts"][:MAX_LEN]
+            pred_texts = []
+            for idx in pred_index:
+                pred_texts.append(source_texts[idx])
+            total += 1
+            total_out_idx += sentence_bleu(
+                [target_index],
+                [i + 1 for i in pred_index],
+                smoothing_function=chen_cherry.method2,
+            )
+            total_out_token += sentence_bleu(
+                [" ".join(target_texts).split()],
+                " ".join(pred_texts).split(),
+                smoothing_function=chen_cherry.method2,
+            )
+
+    print("total: ", total)
+    print("out_idx: ", round(100 * total_out_idx / total, 1))
+    print("out_token: ", round(100 * total_out_token / total, 1))
+
+
+if __name__ == "__main__":
+    app()
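
For orientation, a hedged sketch of the input `eval.py` expects: a gzip file with one JSON object per line carrying the keys read above (`source_boxes`, `source_texts`, `target_texts`, `target_index`). The sample values and file name below are made up:

    import gzip
    import json

    sample = {
        "source_boxes": [[60, 200, 540, 260], [60, 40, 540, 90]],  # paragraph detected first, title second
        "source_texts": ["A body paragraph detected first.", "Document title"],
        "target_texts": ["Document title", "A body paragraph detected first."],
        "target_index": [2, 1],   # 1-indexed, as the BLEU comparison in eval.py assumes
    }
    with gzip.open("dev.jsonl.gz", "wt") as f:
        f.write(json.dumps(sample) + "\n")

    # then, e.g.: python eval.py dev.jsonl.gz hantian/layoutreader --batch-size 16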

+ 125 - 0
magic_pdf/v3/helpers.py

@@ -0,0 +1,125 @@
+from collections import defaultdict
+from typing import List, Dict
+
+import torch
+from transformers import LayoutLMv3ForTokenClassification
+
+MAX_LEN = 510
+CLS_TOKEN_ID = 0
+UNK_TOKEN_ID = 3
+EOS_TOKEN_ID = 2
+
+
+class DataCollator:
+    def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
+        bbox = []
+        labels = []
+        input_ids = []
+        attention_mask = []
+
+        # clip bbox and labels to max length, build input_ids and attention_mask
+        for feature in features:
+            _bbox = feature["source_boxes"]
+            if len(_bbox) > MAX_LEN:
+                _bbox = _bbox[:MAX_LEN]
+            _labels = feature["target_index"]
+            if len(_labels) > MAX_LEN:
+                _labels = _labels[:MAX_LEN]
+            _input_ids = [UNK_TOKEN_ID] * len(_bbox)
+            _attention_mask = [1] * len(_bbox)
+            assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
+            bbox.append(_bbox)
+            labels.append(_labels)
+            input_ids.append(_input_ids)
+            attention_mask.append(_attention_mask)
+
+        # add CLS and EOS tokens
+        for i in range(len(bbox)):
+            bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
+            labels[i] = [-100] + labels[i] + [-100]
+            input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
+            attention_mask[i] = [1] + attention_mask[i] + [1]
+
+        # padding to max length
+        max_len = max(len(x) for x in bbox)
+        for i in range(len(bbox)):
+            bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
+            labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
+            input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
+            attention_mask[i] = attention_mask[i] + [0] * (
+                max_len - len(attention_mask[i])
+            )
+
+        ret = {
+            "bbox": torch.tensor(bbox),
+            "attention_mask": torch.tensor(attention_mask),
+            "labels": torch.tensor(labels),
+            "input_ids": torch.tensor(input_ids),
+        }
+        # set label > MAX_LEN to -100, because original labels may be > MAX_LEN
+        ret["labels"][ret["labels"] > MAX_LEN] = -100
+        # set label > 0 to label-1, because original labels are 1-indexed
+        ret["labels"][ret["labels"] > 0] -= 1
+        return ret
+
+
+def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
+    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
+    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
+    attention_mask = [1] + [1] * len(boxes) + [1]
+    return {
+        "bbox": torch.tensor([bbox]),
+        "attention_mask": torch.tensor([attention_mask]),
+        "input_ids": torch.tensor([input_ids]),
+    }
+
+
+def prepare_inputs(
+    inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
+) -> Dict[str, torch.Tensor]:
+    ret = {}
+    for k, v in inputs.items():
+        v = v.to(model.device)
+        if torch.is_floating_point(v):
+            v = v.to(model.dtype)
+        ret[k] = v
+    return ret
+
+
+def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
+    """
+    parse logits to orders
+
+    :param logits: logits from model
+    :param length: input length
+    :return: orders
+    """
+    logits = logits[1 : length + 1, :length]
+    orders = logits.argsort(descending=False).tolist()
+    ret = [o.pop() for o in orders]
+    while True:
+        order_to_idxes = defaultdict(list)
+        for idx, order in enumerate(ret):
+            order_to_idxes[order].append(idx)
+        # filter idxes len > 1
+        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
+        if not order_to_idxes:
+            break
+        # filter
+        for order, idxes in order_to_idxes.items():
+            # find original logits of idxes
+            idxes_to_logit = {}
+            for idx in idxes:
+                idxes_to_logit[idx] = logits[idx, order]
+            idxes_to_logit = sorted(
+                idxes_to_logit.items(), key=lambda x: x[1], reverse=True
+            )
+            # keep the highest logit as order, set others to next candidate
+            for idx, _ in idxes_to_logit[1:]:
+                ret[idx] = orders[idx].pop()
+
+    return ret
+
+
+def check_duplicate(a: List[int]) -> bool:
+    return len(a) != len(set(a))
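
A small worked example of what `DataCollator` produces for a single toy feature (values invented; it assumes the module is importable as `magic_pdf.v3.helpers`, the same path `draw_bbox.py` uses above):

    from magic_pdf.v3.helpers import DataCollator

    feature = {
        "source_boxes": [[10, 10, 500, 40], [10, 50, 500, 90]],  # two layout boxes
        "target_index": [1, 2],                                  # 1-indexed reading order
    }
    batch = DataCollator()([feature])
    print(batch["input_ids"])    # tensor([[0, 3, 3, 2]])  -> CLS, UNK, UNK, EOS
    print(batch["bbox"].shape)   # torch.Size([1, 4, 4])   -> CLS/EOS rows are all-zero boxes
    print(batch["labels"])       # tensor([[-100, 0, 1, -100]]) -> shifted to 0-indexed, specials masked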

+ 67 - 0
magic_pdf/v3/train.py

@@ -0,0 +1,67 @@
+import os
+from dataclasses import dataclass, field
+
+from datasets import load_dataset, Dataset
+from loguru import logger
+from transformers import (
+    TrainingArguments,
+    HfArgumentParser,
+    LayoutLMv3ForTokenClassification,
+    set_seed,
+)
+from transformers.trainer import Trainer
+
+from helpers import DataCollator, MAX_LEN
+
+
+@dataclass
+class Arguments(TrainingArguments):
+    model_dir: str = field(
+        default=None,
+        metadata={"help": "Path to model, based on `microsoft/layoutlmv3-base`"},
+    )
+    dataset_dir: str = field(
+        default=None,
+        metadata={"help": "Path to dataset"},
+    )
+
+
+def load_train_and_dev_dataset(path: str) -> (Dataset, Dataset):
+    datasets = load_dataset(
+        "json",
+        data_files={
+            "train": os.path.join(path, "train.jsonl.gz"),
+            "dev": os.path.join(path, "dev.jsonl.gz"),
+        },
+    )
+    return datasets["train"], datasets["dev"]
+
+
+def main():
+    parser = HfArgumentParser((Arguments,))
+    args: Arguments = parser.parse_args_into_dataclasses()[0]
+    set_seed(args.seed)
+
+    train_dataset, dev_dataset = load_train_and_dev_dataset(args.dataset_dir)
+    logger.info(
+        "Train dataset size: {}, Dev dataset size: {}".format(
+            len(train_dataset), len(dev_dataset)
+        )
+    )
+
+    model = LayoutLMv3ForTokenClassification.from_pretrained(
+        args.model_dir, num_labels=MAX_LEN, visual_embed=False
+    )
+    data_collator = DataCollator()
+    trainer = Trainer(
+        model=model,
+        args=args,
+        train_dataset=train_dataset,
+        eval_dataset=dev_dataset,
+        data_collator=data_collator,
+    )
+    trainer.train()
+
+
+if __name__ == "__main__":
+    main()

+ 32 - 0
magic_pdf/v3/train.sh

@@ -0,0 +1,32 @@
+#!/usr/bin/env bash
+
+set -x
+set -e
+
+DIR="$( cd "$( dirname "$0" )" && cd .. && pwd )"
+OUTPUT_DIR="${DIR}/checkpoint/v3/$(date +%F-%H)"
+DATA_DIR="${DIR}/ReadingBank/"
+
+mkdir -p "${OUTPUT_DIR}"
+
+deepspeed train.py \
+  --model_dir 'microsoft/layoutlmv3-large' \
+  --dataset_dir "${DATA_DIR}" \
+  --dataloader_num_workers 1 \
+  --deepspeed ds_config.json \
+  --per_device_train_batch_size 32 \
+  --per_device_eval_batch_size 64 \
+  --do_train \
+  --do_eval \
+  --logging_steps 100 \
+  --bf16 \
+  --seed 42 \
+  --num_train_epochs 10 \
+  --learning_rate 5e-5 \
+  --warmup_steps 1000 \
+  --save_strategy epoch \
+  --evaluation_strategy epoch \
+  --remove_unused_columns False \
+  --output_dir "${OUTPUT_DIR}" \
+  --overwrite_output_dir \
+  "$@"