# xfund.py

import os
import json

import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image

from .image_utils import Compose, RandomResizedCropAndInterpolationWithTwoPic

# BIO label map for XFUND's three entity types (header, question, answer).
XFund_label2ids = {
    "O": 0,
    "B-HEADER": 1,
    "I-HEADER": 2,
    "B-QUESTION": 3,
    "I-QUESTION": 4,
    "B-ANSWER": 5,
    "I-ANSWER": 6,
}


class xfund_dataset(Dataset):
    def box_norm(self, box, width, height):
        """Normalize an absolute pixel box to the 0-1000 coordinate space."""
        def clip(min_num, num, max_num):
            return min(max(num, min_num), max_num)

        x0, y0, x1, y1 = box
        x0 = clip(0, int((x0 / width) * 1000), 1000)
        y0 = clip(0, int((y0 / height) * 1000), 1000)
        x1 = clip(0, int((x1 / width) * 1000), 1000)
        y1 = clip(0, int((y1 / height) * 1000), 1000)
        assert x1 >= x0
        assert y1 >= y0
        return [x0, y0, x1, y1]
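
    # Illustrative example (not part of the original file): on a 200x100 page,
    # box_norm([20, 10, 100, 50], width=200, height=100) -> [100, 100, 500, 500],
    # i.e. every coordinate is rescaled onto the 0-1000 grid that
    # LayoutLM-style models expect.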

    def get_segment_ids(self, bboxs):
        """Assign one segment id per token; tokens sharing a bbox share a segment."""
        segment_ids = []
        for i in range(len(bboxs)):
            if i == 0:
                segment_ids.append(0)
            else:
                if bboxs[i - 1] == bboxs[i]:
                    segment_ids.append(segment_ids[-1])
                else:
                    segment_ids.append(segment_ids[-1] + 1)
        return segment_ids

    def get_position_ids(self, segment_ids):
        """Restart position ids at 2 for each new segment (0/1 are reserved)."""
        position_ids = []
        for i in range(len(segment_ids)):
            if i == 0:
                position_ids.append(2)
            else:
                if segment_ids[i] == segment_ids[i - 1]:
                    position_ids.append(position_ids[-1] + 1)
                else:
                    position_ids.append(2)
        return position_ids
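
    # Illustrative example (not part of the original file): for bboxes
    # [b, b, c], get_segment_ids returns [0, 0, 1] and get_position_ids
    # returns [2, 3, 2] — positions count up within a segment and reset
    # to 2 at each segment boundary.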

    def load_data(self, data_file):
        # Re-organize the raw XFUND json into per-document lists.
        total_data = {"id": [], "lines": [], "bboxes": [], "ner_tags": [], "image_path": []}
        for i in range(len(data_file['documents'])):
            width = data_file['documents'][i]['img']['width']
            height = data_file['documents'][i]['img']['height']
            cur_doc_lines, cur_doc_bboxes, cur_doc_ner_tags = [], [], []
            for j in range(len(data_file['documents'][i]['document'])):
                cur_item = data_file['documents'][i]['document'][j]
                cur_doc_lines.append(cur_item['text'])
                cur_doc_bboxes.append(self.box_norm(cur_item['box'], width=width, height=height))
                cur_doc_ner_tags.append(cur_item['label'])
            total_data['id'] += [len(total_data['id'])]
            total_data['lines'] += [cur_doc_lines]
            total_data['bboxes'] += [cur_doc_bboxes]
            total_data['ner_tags'] += [cur_doc_ner_tags]
            total_data['image_path'] += [data_file['documents'][i]['img']['fname']]

        # Tokenize each text line and expand its bbox/label to token level.
        total_input_ids, total_bboxs, total_label_ids = [], [], []
        for i in range(len(total_data['lines'])):
            cur_doc_input_ids, cur_doc_bboxs, cur_doc_labels = [], [], []
            for j in range(len(total_data['lines'][i])):
                cur_input_ids = self.tokenizer(
                    total_data['lines'][i][j],
                    truncation=False,
                    add_special_tokens=False,
                    return_attention_mask=False,
                )['input_ids']
                if len(cur_input_ids) == 0:
                    continue
                cur_label = total_data['ner_tags'][i][j].upper()
                if cur_label == 'OTHER':
                    # Non-entity lines: every token is tagged "O".
                    cur_labels = [self.label2ids["O"]] * len(cur_input_ids)
                else:
                    # Entity lines: BIO scheme, "B-" on the first token, "I-" after.
                    cur_labels = [cur_label] * len(cur_input_ids)
                    cur_labels[0] = self.label2ids['B-' + cur_labels[0]]
                    for k in range(1, len(cur_labels)):
                        cur_labels[k] = self.label2ids['I-' + cur_labels[k]]
                assert len(cur_input_ids) == len(cur_labels)
                cur_doc_input_ids += cur_input_ids
                # Every token of a line inherits that line's bbox.
                cur_doc_bboxs += [total_data['bboxes'][i][j]] * len(cur_input_ids)
                cur_doc_labels += cur_labels
            assert len(cur_doc_input_ids) == len(cur_doc_bboxs) == len(cur_doc_labels)
            assert len(cur_doc_input_ids) > 0
            total_input_ids.append(cur_doc_input_ids)
            total_bboxs.append(cur_doc_bboxs)
            total_label_ids.append(cur_doc_labels)
        assert len(total_input_ids) == len(total_bboxs) == len(total_label_ids)

        # Split over-length documents into 510-token windows, each wrapped
        # with CLS/SEP and their dummy boxes / ignored (-100) labels.
        input_ids, bboxs, labels = [], [], []
        segment_ids, position_ids = [], []
        image_path = []
        for i in range(len(total_input_ids)):
            start = 0
            while start < len(total_input_ids[i]):
                end = min(start + 510, len(total_input_ids[i]))
                input_ids.append(
                    [self.tokenizer.cls_token_id] + total_input_ids[i][start:end] + [self.tokenizer.sep_token_id])
                bboxs.append([[0, 0, 0, 0]] + total_bboxs[i][start:end] + [[1000, 1000, 1000, 1000]])
                labels.append([-100] + total_label_ids[i][start:end] + [-100])
                cur_segment_ids = self.get_segment_ids(bboxs[-1])
                cur_position_ids = self.get_position_ids(cur_segment_ids)
                segment_ids.append(cur_segment_ids)
                position_ids.append(cur_position_ids)
                image_path.append(os.path.join(self.args.data_dir, "images", total_data['image_path'][i]))
                start = end
        assert len(input_ids) == len(bboxs) == len(labels) == len(segment_ids) == len(position_ids)
        assert len(segment_ids) == len(image_path)
        res = {
            'input_ids': input_ids,
            'bbox': bboxs,
            'labels': labels,
            'segment_ids': segment_ids,
            'position_ids': position_ids,
            'image_path': image_path,
        }
        return res
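
    # Illustrative example (not part of the original file): a 1200-token
    # document yields three windows of lengths 512, 512, and 182
    # (510 content tokens + CLS/SEP each; the last window holds the
    # remaining 180 content tokens). Each window repeats the same
    # image_path, so every slice of a page is paired with its full image.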

    def __init__(self, args, tokenizer, mode):
        self.args = args
        self.mode = mode
        self.cur_la = args.language
        self.tokenizer = tokenizer
        self.label2ids = XFund_label2ids

        self.common_transform = Compose([
            RandomResizedCropAndInterpolationWithTwoPic(
                size=args.input_size, interpolation=args.train_interpolation,
            ),
        ])
        self.patch_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor((0.5, 0.5, 0.5)),
                std=torch.tensor((0.5, 0.5, 0.5))),
        ])

        data_path = os.path.join(
            args.data_dir, "{}.{}.json".format(self.cur_la, 'train' if mode == 'train' else 'val'))
        with open(data_path, 'r') as f:
            data_file = json.load(f)
        self.feature = self.load_data(data_file)

    def __len__(self):
        return len(self.feature['input_ids'])

    def __getitem__(self, index):
        input_ids = self.feature["input_ids"][index]
        # Every token in a window is a real token, so the mask is all ones.
        attention_mask = [1] * len(input_ids)
        labels = self.feature["labels"][index]
        bbox = self.feature["bbox"][index]
        segment_ids = self.feature['segment_ids'][index]
        position_ids = self.feature['position_ids'][index]

        img = pil_loader(self.feature['image_path'][index])
        for_patches, _ = self.common_transform(img, augmentation=False)
        patch = self.patch_transform(for_patches)

        assert len(input_ids) == len(attention_mask) == len(labels) == len(bbox) == len(segment_ids)
        res = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "bbox": bbox,
            "segment_ids": segment_ids,
            "position_ids": position_ids,
            "images": patch,
        }
        return res


def pil_loader(path: str) -> Image.Image:
    # Open path as file to avoid ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')
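

# Illustrative usage sketch (not part of the original file). The attribute
# names mirror what __init__ reads from `args`; the XLM-RoBERTa tokenizer and
# the data layout ({lang}.train.json under data_dir, page images under
# data_dir/images/) are assumptions inferred from how load_data builds paths.
# Kept as a comment because this module uses a relative import and is not
# meant to run standalone.
#
#     from types import SimpleNamespace
#     from transformers import AutoTokenizer
#
#     args = SimpleNamespace(
#         data_dir="/path/to/xfund",       # expects zh.train.json / zh.val.json
#         language="zh",                   # XFUND language code
#         input_size=224,                  # image crop size
#         train_interpolation="bicubic",   # interpolation mode
#     )
#     tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
#     dataset = xfund_dataset(args, tokenizer, mode="train")
#     sample = dataset[0]  # dict with input_ids, bbox, labels, images, ...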