```python
import gzip
import json

import torch
import typer
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from tqdm import tqdm
from transformers import LayoutLMv3ForTokenClassification

from helpers import (
    DataCollator,
    check_duplicate,
    MAX_LEN,
    parse_logits,
    prepare_inputs,
)

app = typer.Typer()
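# BLEU smoothing for short sequences (Chen & Cherry, 2014); method2 adds one to n-gram counts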
chen_cherry = SmoothingFunction()

@app.command()
def main(
    input_file: str = typer.Argument(..., help="gzipped JSON-lines evaluation file"),
    model_path: str = typer.Argument(..., help="LayoutLMv3 token-classification checkpoint"),
    batch_size: int = typer.Option(16, help="inference batch size"),
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
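    # load the token-classification model in bfloat16 and switch it to inference mode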
    model = (
        LayoutLMv3ForTokenClassification.from_pretrained(model_path, num_labels=MAX_LEN)
        .bfloat16()
        .to(device)
        .eval()
    )
    data_collator = DataCollator()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
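    # read the gzipped JSON-lines evaluation set, one example per line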
    datasets = []
    with gzip.open(input_file, "rt") as f:
        for line in tqdm(f):
            datasets.append(json.loads(line))
    # sort by input length so each batch holds similarly sized examples,
    # which reduces padding and speeds up batched inference
    datasets.sort(key=lambda x: len(x["source_boxes"]), reverse=True)
    total = 0
    total_out_idx = 0.0
    total_out_token = 0.0
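    # run batched inference over the length-sorted dataset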
    for i in tqdm(range(0, len(datasets), batch_size)):
        batch = datasets[i : i + batch_size]
        model_inputs = data_collator(batch)
        model_inputs = prepare_inputs(model_inputs, model)
        # forward pass without tracking gradients
        with torch.no_grad():
            model_outputs = model(**model_inputs)
        logits = model_outputs.logits.cpu()
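        # score every example in the batch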
        for data, logit in zip(batch, logits):
            target_index = data["target_index"][:MAX_LEN]
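            # decode the logits into predicted reading-order indices; the asserts
            # check the prediction is a complete, duplicate-free ordering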
            pred_index = parse_logits(logit, len(target_index))
            assert len(pred_index) == len(target_index)
            assert not check_duplicate(pred_index)
            target_texts = data["target_texts"][:MAX_LEN]
            source_texts = data["source_texts"][:MAX_LEN]
            pred_texts = [source_texts[idx] for idx in pred_index]
            total += 1
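            # index-level BLEU: parse_logits yields 0-based indices, while
            # target_index is stored 1-based, hence the +1 shift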
            total_out_idx += sentence_bleu(
                [target_index],
                [idx + 1 for idx in pred_index],
                smoothing_function=chen_cherry.method2,
            )
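            # token-level BLEU: re-ordered source tokens against the target text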
            total_out_token += sentence_bleu(
                [" ".join(target_texts).split()],
                " ".join(pred_texts).split(),
                smoothing_function=chen_cherry.method2,
            )
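    # report corpus-level averages as percentages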
- print("total: ", total)
- print("out_idx: ", round(100 * total_out_idx / total, 1))
- print("out_token: ", round(100 * total_out_token / total, 1))

if __name__ == "__main__":
    app()
```