eval.py

"""Evaluate a LayoutLMv3 reading-order model.

Reports corpus-averaged sentence BLEU over the predicted token indices and over
the re-ordered token text.
"""
import gzip
import json

import torch
import typer
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from tqdm import tqdm
from transformers import LayoutLMv3ForTokenClassification

from helpers import (
    DataCollator,
    check_duplicate,
    MAX_LEN,
    parse_logits,
    prepare_inputs,
)

app = typer.Typer()
chen_cherry = SmoothingFunction()


@app.command()
def main(
    input_file: str = typer.Argument(..., help="input file"),
    model_path: str = typer.Argument(..., help="model path"),
    batch_size: int = typer.Option(16, help="batch size"),
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = (
        LayoutLMv3ForTokenClassification.from_pretrained(model_path, num_labels=MAX_LEN)
        .bfloat16()
        .to(device)
        .eval()
    )
    data_collator = DataCollator()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Load the gzipped JSONL evaluation set (one example per line).
    datasets = []
    with gzip.open(input_file, "rt") as f:
        for line in tqdm(f):
            datasets.append(json.loads(line))

    # Sort by number of source boxes so each batch has similar lengths (faster batching).
    datasets.sort(key=lambda x: len(x["source_boxes"]), reverse=True)

    total = 0
    total_out_idx = 0.0
    total_out_token = 0.0
    for i in tqdm(range(0, len(datasets), batch_size)):
        batch = datasets[i : i + batch_size]
        model_inputs = data_collator(batch)
        model_inputs = prepare_inputs(model_inputs, model)
        # forward
        with torch.no_grad():
            model_outputs = model(**model_inputs)
        logits = model_outputs.logits.cpu()

        for data, logit in zip(batch, logits):
            target_index = data["target_index"][:MAX_LEN]
            pred_index = parse_logits(logit, len(target_index))
            assert len(pred_index) == len(target_index)
            assert not check_duplicate(pred_index)

            target_texts = data["target_texts"][:MAX_LEN]
            source_texts = data["source_texts"][:MAX_LEN]
            pred_texts = []
            for idx in pred_index:
                pred_texts.append(source_texts[idx])

            total += 1
            # BLEU over indices; shift pred_index to 1-based to match target_index.
            total_out_idx += sentence_bleu(
                [target_index],
                [i + 1 for i in pred_index],
                smoothing_function=chen_cherry.method2,
            )
            # BLEU over the re-ordered token text.
            total_out_token += sentence_bleu(
                [" ".join(target_texts).split()],
                " ".join(pred_texts).split(),
                smoothing_function=chen_cherry.method2,
            )

    print("total: ", total)
    print("out_idx: ", round(100 * total_out_idx / total, 1))
    print("out_token: ", round(100 * total_out_token / total, 1))


if __name__ == "__main__":
    app()
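

# Usage sketch (assumptions: a gzipped JSONL eval set and a fine-tuned
# checkpoint directory; the paths below are placeholders, not files shipped
# with this script). Typer exposes batch_size as --batch-size.
#
#   python eval.py data/eval.jsonl.gz path/to/checkpoint --batch-size 16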