# funsd.py
  1. # coding=utf-8
  2. '''
  3. Reference: https://huggingface.co/datasets/nielsr/funsd/blob/main/funsd.py
  4. '''
  5. import json
  6. import os
  7. import datasets
  8. from .image_utils import load_image, normalize_bbox
  9. logger = datasets.logging.get_logger(__name__)
  10. _CITATION = """\
  11. @article{Jaume2019FUNSDAD,
  12. title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},
  13. author={Guillaume Jaume and H. K. Ekenel and J. Thiran},
  14. journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},
  15. year={2019},
  16. volume={2},
  17. pages={1-6}
  18. }
  19. """
  20. _DESCRIPTION = """\
  21. https://guillaumejaume.github.io/FUNSD/
  22. """
  23. class FunsdConfig(datasets.BuilderConfig):
  24. """BuilderConfig for FUNSD"""
  25. def __init__(self, **kwargs):
  26. """BuilderConfig for FUNSD.
  27. Args:
  28. **kwargs: keyword arguments forwarded to super.
  29. """
  30. super(FunsdConfig, self).__init__(**kwargs)
  31. class Funsd(datasets.GeneratorBasedBuilder):
  32. """Conll2003 dataset."""
  33. BUILDER_CONFIGS = [
  34. FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
  35. ]
  36. def _info(self):
  37. return datasets.DatasetInfo(
  38. description=_DESCRIPTION,
  39. features=datasets.Features(
  40. {
  41. "id": datasets.Value("string"),
  42. "tokens": datasets.Sequence(datasets.Value("string")),
  43. "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
  44. "ner_tags": datasets.Sequence(
  45. datasets.features.ClassLabel(
  46. names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
  47. )
  48. ),
  49. "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
  50. "image_path": datasets.Value("string"),
  51. }
  52. ),
  53. supervised_keys=None,
  54. homepage="https://guillaumejaume.github.io/FUNSD/",
  55. citation=_CITATION,
  56. )
  57. def _split_generators(self, dl_manager):
  58. """Returns SplitGenerators."""
  59. downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
  60. return [
  61. datasets.SplitGenerator(
  62. name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
  63. ),
  64. datasets.SplitGenerator(
  65. name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
  66. ),
  67. ]
  68. def get_line_bbox(self, bboxs):
  69. x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
  70. y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
  71. x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
  72. assert x1 >= x0 and y1 >= y0
  73. bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
  74. return bbox
  75. def _generate_examples(self, filepath):
  76. logger.info("⏳ Generating examples from = %s", filepath)
  77. ann_dir = os.path.join(filepath, "annotations")
  78. img_dir = os.path.join(filepath, "images")
  79. for guid, file in enumerate(sorted(os.listdir(ann_dir))):
  80. tokens = []
  81. bboxes = []
  82. ner_tags = []
  83. file_path = os.path.join(ann_dir, file)
  84. with open(file_path, "r", encoding="utf8") as f:
  85. data = json.load(f)
  86. image_path = os.path.join(img_dir, file)
  87. image_path = image_path.replace("json", "png")
  88. image, size = load_image(image_path)
  89. for item in data["form"]:
  90. cur_line_bboxes = []
  91. words, label = item["words"], item["label"]
  92. words = [w for w in words if w["text"].strip() != ""]
  93. if len(words) == 0:
  94. continue
  95. if label == "other":
  96. for w in words:
  97. tokens.append(w["text"])
  98. ner_tags.append("O")
  99. cur_line_bboxes.append(normalize_bbox(w["box"], size))
  100. else:
  101. tokens.append(words[0]["text"])
  102. ner_tags.append("B-" + label.upper())
  103. cur_line_bboxes.append(normalize_bbox(words[0]["box"], size))
  104. for w in words[1:]:
  105. tokens.append(w["text"])
  106. ner_tags.append("I-" + label.upper())
  107. cur_line_bboxes.append(normalize_bbox(w["box"], size))
  108. # by default: --segment_level_layout 1
  109. # if do not want to use segment_level_layout, comment the following line
  110. cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
  111. # box = normalize_bbox(item["box"], size)
  112. # cur_line_bboxes = [box for _ in range(len(words))]
  113. bboxes.extend(cur_line_bboxes)
  114. yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags,
  115. "image": image, "image_path": image_path}