train.py

import os
from dataclasses import dataclass, field

from datasets import load_dataset, Dataset
from loguru import logger
from transformers import (
    TrainingArguments,
    HfArgumentParser,
    LayoutLMv3ForTokenClassification,
    set_seed,
)
from transformers.trainer import Trainer

from helpers import DataCollator, MAX_LEN


@dataclass
class Arguments(TrainingArguments):
    model_dir: str = field(
        default=None,
        metadata={"help": "Path to model, based on `microsoft/layoutlmv3-base`"},
    )
    dataset_dir: str = field(
        default=None,
        metadata={"help": "Path to dataset"},
    )


def load_train_and_dev_dataset(path: str) -> tuple[Dataset, Dataset]:
    # Expects gzipped JSON-lines splits named train.jsonl.gz and dev.jsonl.gz.
    datasets = load_dataset(
        "json",
        data_files={
            "train": os.path.join(path, "train.jsonl.gz"),
            "dev": os.path.join(path, "dev.jsonl.gz"),
        },
    )
    return datasets["train"], datasets["dev"]


def main():
    parser = HfArgumentParser((Arguments,))
    args: Arguments = parser.parse_args_into_dataclasses()[0]
    set_seed(args.seed)

    train_dataset, dev_dataset = load_train_and_dev_dataset(args.dataset_dir)
    logger.info(
        "Train dataset size: {}, Dev dataset size: {}".format(
            len(train_dataset), len(dev_dataset)
        )
    )

    # The token-classification head gets MAX_LEN output labels (defined in
    # helpers); visual_embed=False drops the visual embeddings so the model
    # runs on text and layout (bounding boxes) alone.
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        args.model_dir, num_labels=MAX_LEN, visual_embed=False
    )
    data_collator = DataCollator()
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        data_collator=data_collator,
    )
    trainer.train()


if __name__ == "__main__":
    main()
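
The DataCollator and MAX_LEN imported from helpers are not shown on this page. For orientation, here is a minimal sketch of what such a collator could look like; the feature names (input_ids, bbox, labels), the pad values, and the MAX_LEN constant are assumptions for illustration, not the actual helpers implementation.

from dataclasses import dataclass

import torch

MAX_LEN = 510  # assumed; the real constant lives in helpers


@dataclass
class DataCollator:
    """Sketch: pad pre-tokenized features to a common length for LayoutLMv3."""

    def __call__(self, features: list[dict]) -> dict:
        batch_len = min(MAX_LEN, max(len(f["input_ids"]) for f in features))
        batch = {"input_ids": [], "attention_mask": [], "bbox": [], "labels": []}
        for f in features:
            ids = f["input_ids"][:batch_len]
            pad = batch_len - len(ids)
            batch["input_ids"].append(ids + [1] * pad)  # 1 = RoBERTa-style <pad> id
            batch["attention_mask"].append([1] * len(ids) + [0] * pad)
            batch["bbox"].append(f["bbox"][:batch_len] + [[0, 0, 0, 0]] * pad)
            # -100 is ignored by the token-classification loss
            batch["labels"].append(f["labels"][:batch_len] + [-100] * pad)
        return {key: torch.tensor(value) for key, value in batch.items()}

Because Arguments subclasses TrainingArguments, HfArgumentParser exposes every stock training flag alongside the two custom paths, e.g. (paths hypothetical):

python train.py --model_dir microsoft/layoutlmv3-base --dataset_dir ./data --output_dir ./out --do_train --do_eval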