helpers.py

from collections import defaultdict
from typing import List, Dict

import torch
from transformers import LayoutLMv3ForTokenClassification

MAX_LEN = 510  # 512 positions minus the CLS and EOS tokens

# Special token ids of the RoBERTa-style tokenizer used by LayoutLMv3:
# <s> = 0, </s> = 2, <unk> = 3.
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2


class DataCollator:
    def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
        bbox = []
        labels = []
        input_ids = []
        attention_mask = []

        # Clip bbox and labels to the max length; build input_ids and
        # attention_mask. Every box is fed as an <unk> token, so the model
        # predicts the order from layout alone.
        for feature in features:
            _bbox = feature["source_boxes"]
            if len(_bbox) > MAX_LEN:
                _bbox = _bbox[:MAX_LEN]
            _labels = feature["target_index"]
            if len(_labels) > MAX_LEN:
                _labels = _labels[:MAX_LEN]
            _input_ids = [UNK_TOKEN_ID] * len(_bbox)
            _attention_mask = [1] * len(_bbox)
            assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
            bbox.append(_bbox)
            labels.append(_labels)
            input_ids.append(_input_ids)
            attention_mask.append(_attention_mask)

        # Add CLS and EOS tokens; their labels are -100 so the loss ignores them.
        for i in range(len(bbox)):
            bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
            labels[i] = [-100] + labels[i] + [-100]
            input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
            attention_mask[i] = [1] + attention_mask[i] + [1]

        # Pad every sequence in the batch to the same length. Padding positions
        # get attention_mask 0, so the input_ids value there never matters.
        max_len = max(len(x) for x in bbox)
        for i in range(len(bbox)):
            bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
            labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
            input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
            attention_mask[i] = attention_mask[i] + [0] * (
                max_len - len(attention_mask[i])
            )

        ret = {
            "bbox": torch.tensor(bbox),
            "attention_mask": torch.tensor(attention_mask),
            "labels": torch.tensor(labels),
            "input_ids": torch.tensor(input_ids),
        }
        # Mask out labels beyond MAX_LEN: after clipping, target indices may
        # still point past the truncated sequence.
        ret["labels"][ret["labels"] > MAX_LEN] = -100
        # Shift positive labels down by one: the original labels are 1-indexed.
        ret["labels"][ret["labels"] > 0] -= 1
        return ret
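
# A minimal usage sketch for the collator; the box coordinates and target
# indices below are made-up illustrative values.
#
# >>> collator = DataCollator()
# >>> batch = collator([
# ...     {"source_boxes": [[1, 2, 3, 4], [5, 6, 7, 8]], "target_index": [2, 1]},
# ...     {"source_boxes": [[1, 1, 2, 2]], "target_index": [1]},
# ... ])
# >>> batch["input_ids"].shape  # CLS + up to 2 boxes + EOS, padded to 4
# torch.Size([2, 4])
# >>> batch["labels"][0].tolist()  # 1-indexed targets shifted to 0-indexed
# [-100, 1, 0, -100]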


def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
    """Wrap a single page's boxes into a batch-of-one model input."""
    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
    attention_mask = [1] + [1] * len(boxes) + [1]
    return {
        "bbox": torch.tensor([bbox]),
        "attention_mask": torch.tensor([attention_mask]),
        "input_ids": torch.tensor([input_ids]),
    }
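
# Example: three boxes become a batch of one sequence with five positions
# (CLS + three <unk> box tokens + EOS). Coordinates are illustrative and
# assumed to already be normalized to the 0-1000 range LayoutLMv3 expects.
#
# >>> enc = boxes2inputs([[10, 10, 50, 20], [10, 30, 50, 40], [10, 50, 50, 60]])
# >>> enc["input_ids"]
# tensor([[0, 3, 3, 3, 2]])
# >>> enc["bbox"].shape
# torch.Size([1, 5, 4])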


def prepare_inputs(
    inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
) -> Dict[str, torch.Tensor]:
    """Move inputs to the model's device, casting only floating-point tensors
    to the model's dtype so integer ids and masks stay intact."""
    ret = {}
    for k, v in inputs.items():
        v = v.to(model.device)
        if torch.is_floating_point(v):
            v = v.to(model.dtype)
        ret[k] = v
    return ret


def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """
    Parse token-classification logits into a reading order.

    :param logits: logits from the model, shape (seq_len, num_labels)
    :param length: number of real boxes (excluding CLS/EOS)
    :return: orders, where orders[i] is the reading-order position of box i
    """
    # Drop the CLS row and any EOS/padding rows and columns.
    logits = logits[1 : length + 1, :length]
    # For each box, rank candidate positions from worst to best (ascending
    # argsort), then take the best one off the end of the list.
    orders = logits.argsort(descending=False).tolist()
    ret = [o.pop() for o in orders]
    # Resolve conflicts: whenever several boxes claim the same position, the
    # box with the highest logit keeps it and the others fall back to their
    # next-best candidate, until the assignment is duplicate-free.
    while True:
        order_to_idxes = defaultdict(list)
        for idx, order in enumerate(ret):
            order_to_idxes[order].append(idx)
        # Keep only positions claimed by more than one box.
        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
        if not order_to_idxes:
            break
        for order, idxes in order_to_idxes.items():
            # Look up the original logit each contender has for this position.
            idxes_to_logit = {}
            for idx in idxes:
                idxes_to_logit[idx] = logits[idx, order]
            idxes_to_logit = sorted(
                idxes_to_logit.items(), key=lambda x: x[1], reverse=True
            )
            # Keep the highest logit as order; move the rest to their next candidate.
            for idx, _ in idxes_to_logit[1:]:
                ret[idx] = orders[idx].pop()
    return ret
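
# Example with a hand-made logits tensor for two boxes: both rows prefer
# position 0, so the higher logit (row 1 of the sliced tensor, 0.9 vs 0.8)
# keeps it and the other box falls back to its next-best candidate. The
# first row stands in for the CLS position, which parse_logits skips.
#
# >>> toy = torch.tensor([[0.0, 0.0], [0.9, 0.1], [0.8, 0.2]])
# >>> parse_logits(toy, 2)
# [0, 1]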


def check_duplicate(a: List[int]) -> bool:
    """Return True if the list contains any repeated value."""
    return len(a) != len(set(a))
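

# End-to-end sketch tying the helpers together. The checkpoint name is an
# assumption; substitute the reading-order model you actually use. Boxes
# must already be scaled to the 0-1000 coordinate space.
if __name__ == "__main__":
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        "hantian/layoutreader"  # assumed checkpoint name
    )
    boxes = [[10, 120, 60, 140], [10, 20, 60, 40], [10, 70, 60, 90]]
    inputs = prepare_inputs(boxes2inputs(boxes), model)
    with torch.no_grad():
        logits = model(**inputs).logits[0]  # drop the batch dimension
    order = parse_logits(logits, len(boxes))
    assert not check_duplicate(order)
    print(order)  # e.g. [2, 0, 1] if the model reads top to bottom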