```python
import os
from dataclasses import dataclass, field
from typing import Tuple

from datasets import load_dataset, Dataset
from loguru import logger
from transformers import (
    TrainingArguments,
    HfArgumentParser,
    LayoutLMv3ForTokenClassification,
    set_seed,
)
from transformers.trainer import Trainer

from helpers import DataCollator, MAX_LEN


@dataclass
class Arguments(TrainingArguments):
    # Extends the standard HF TrainingArguments with two script-specific paths.
    model_dir: str = field(
        default=None,
        metadata={"help": "Path to model, based on `microsoft/layoutlmv3-base`"},
    )
    dataset_dir: str = field(
        default=None,
        metadata={"help": "Path to dataset"},
    )


def load_train_and_dev_dataset(path: str) -> Tuple[Dataset, Dataset]:
    # The train/dev splits are stored as gzipped JSON-lines files.
    datasets = load_dataset(
        "json",
        data_files={
            "train": os.path.join(path, "train.jsonl.gz"),
            "dev": os.path.join(path, "dev.jsonl.gz"),
        },
    )
    return datasets["train"], datasets["dev"]


def main():
    parser = HfArgumentParser((Arguments,))
    args: Arguments = parser.parse_args_into_dataclasses()[0]
    set_seed(args.seed)

    train_dataset, dev_dataset = load_train_and_dev_dataset(args.dataset_dir)
    logger.info(
        "Train dataset size: {}, Dev dataset size: {}".format(
            len(train_dataset), len(dev_dataset)
        )
    )

    # num_labels equals the maximum sequence length, so each token is
    # classified into one of MAX_LEN positions; visual_embed=False drops the
    # image branch, leaving only text and layout (bounding-box) features.
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        args.model_dir, num_labels=MAX_LEN, visual_embed=False
    )
    data_collator = DataCollator()
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        data_collator=data_collator,
    )
    trainer.train()


if __name__ == "__main__":
    main()
```
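The script depends on a local `helpers` module (providing `DataCollator` and `MAX_LEN`) that is not shown above. Here is a minimal, hypothetical sketch of what such a module could look like; the constant's value, the feature keys (`input_ids`, `bboxes`, `labels`), and the padding choices are assumptions for illustration, not the source's actual implementation:

```python
# helpers.py -- hypothetical sketch; the real module is not shown in the source.
from typing import Any, Dict, List

import torch

MAX_LEN = 510  # assumed value; also used as num_labels in the training script


class DataCollator:
    """Pads each feature in a batch to the longest sequence in that batch."""

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        batch_max = max(len(f["input_ids"]) for f in features)
        input_ids, bboxes, labels, attention_mask = [], [], [], []
        for f in features:
            pad = batch_max - len(f["input_ids"])
            attention_mask.append([1] * len(f["input_ids"]) + [0] * pad)
            input_ids.append(f["input_ids"] + [1] * pad)       # 1 = RoBERTa pad id (assumed)
            bboxes.append(f["bboxes"] + [[0, 0, 0, 0]] * pad)  # zero boxes for padding
            labels.append(f["labels"] + [-100] * pad)          # -100 is ignored by the loss
        return {
            "input_ids": torch.tensor(input_ids),
            "attention_mask": torch.tensor(attention_mask),
            "bbox": torch.tensor(bboxes),
            "labels": torch.tensor(labels),
        }
```

Because `Arguments` subclasses `TrainingArguments`, `HfArgumentParser` also exposes every standard training flag (`--output_dir`, `--learning_rate`, `--per_device_train_batch_size`, and so on) on the command line alongside `--model_dir` and `--dataset_dir`.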