"""Helpers for predicting reading order of layout boxes with a LayoutLMv3 token-classification model."""

from collections import defaultdict
from typing import Dict, List

import torch
from transformers import LayoutLMv3ForTokenClassification

MAX_LEN = 510
CLS_TOKEN_ID = 0
UNK_TOKEN_ID = 3
EOS_TOKEN_ID = 2


class DataCollator:
    def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
        bbox = []
        labels = []
        input_ids = []
        attention_mask = []

        # truncate boxes and labels to MAX_LEN, then build input_ids and attention_mask
        for feature in features:
            _bbox = feature["source_boxes"]
            if len(_bbox) > MAX_LEN:
                _bbox = _bbox[:MAX_LEN]
            _labels = feature["target_index"]
            if len(_labels) > MAX_LEN:
                _labels = _labels[:MAX_LEN]
            _input_ids = [UNK_TOKEN_ID] * len(_bbox)
            _attention_mask = [1] * len(_bbox)
            assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
            bbox.append(_bbox)
            labels.append(_labels)
            input_ids.append(_input_ids)
            attention_mask.append(_attention_mask)

        # add CLS and EOS tokens
        for i in range(len(bbox)):
            bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
            labels[i] = [-100] + labels[i] + [-100]
            input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
            attention_mask[i] = [1] + attention_mask[i] + [1]

        # pad every sequence to the longest one in the batch
        max_len = max(len(x) for x in bbox)
        for i in range(len(bbox)):
            bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
            labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
            input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
            attention_mask[i] = attention_mask[i] + [0] * (
                max_len - len(attention_mask[i])
            )

        ret = {
            "bbox": torch.tensor(bbox),
            "attention_mask": torch.tensor(attention_mask),
            "labels": torch.tensor(labels),
            "input_ids": torch.tensor(input_ids),
        }
        # labels greater than MAX_LEN point to boxes that were truncated; ignore them
        ret["labels"][ret["labels"] > MAX_LEN] = -100
        # original labels are 1-indexed; shift positive labels down to 0-indexed
        ret["labels"][ret["labels"] > 0] -= 1
        return ret


def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
    """Wrap a single list of boxes into model inputs with CLS/EOS added and batch size 1."""
    bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
    input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
    attention_mask = [1] + [1] * len(boxes) + [1]
    return {
        "bbox": torch.tensor([bbox]),
        "attention_mask": torch.tensor([attention_mask]),
        "input_ids": torch.tensor([input_ids]),
    }


def prepare_inputs(
    inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
) -> Dict[str, torch.Tensor]:
    """Move inputs to the model's device, casting floating-point tensors to the model's dtype."""
    ret = {}
    for k, v in inputs.items():
        v = v.to(model.device)
        if torch.is_floating_point(v):
            v = v.to(model.dtype)
        ret[k] = v
    return ret


def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
    """
    Parse token-classification logits into a reading order.

    :param logits: logits for a single sequence, shape (seq_len, num_labels)
    :param length: number of input boxes (excluding CLS/EOS)
    :return: predicted order index for each box, with duplicates resolved
    """
    logits = logits[1 : length + 1, :length]
    # argsort ascending, so the last element of each row is the best candidate
    orders = logits.argsort(descending=False).tolist()
    ret = [o.pop() for o in orders]
    while True:
        order_to_idxes = defaultdict(list)
        for idx, order in enumerate(ret):
            order_to_idxes[order].append(idx)
        # keep only orders assigned to more than one box (conflicts)
        order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
        if not order_to_idxes:
            break
        # resolve each conflict
        for order, idxes in order_to_idxes.items():
            # look up the logit each conflicting box has for this order
            idxes_to_logit = {}
            for idx in idxes:
                idxes_to_logit[idx] = logits[idx, order]
            idxes_to_logit = sorted(
                idxes_to_logit.items(), key=lambda x: x[1], reverse=True
            )
            # keep the highest-logit box at this order; move the others to their next candidate
            for idx, _ in idxes_to_logit[1:]:
                ret[idx] = orders[idx].pop()
    return ret


def check_duplicate(a: List[int]) -> bool:
    """Return True if the list contains duplicate values."""
    return len(a) != len(set(a))
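

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API).
# It shows how boxes2inputs, prepare_inputs, and parse_logits fit together at
# inference time. The checkpoint name "hantian/layoutreader" and the toy boxes
# below are assumptions; substitute the LayoutLMv3 token-classification
# checkpoint and box coordinates you actually use. Boxes are (x0, y0, x1, y1)
# scaled to the 0-1000 range expected by LayoutLMv3.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    boxes = [
        [100, 50, 300, 80],
        [100, 100, 300, 130],
        [100, 150, 300, 180],
    ]

    model = LayoutLMv3ForTokenClassification.from_pretrained(
        "hantian/layoutreader"  # assumed checkpoint name
    )
    model.eval()

    inputs = prepare_inputs(boxes2inputs(boxes), model)
    with torch.no_grad():
        # (1, seq_len, num_labels) -> (seq_len, num_labels) for a single sequence
        logits = model(**inputs).logits.cpu().squeeze(0)

    order = parse_logits(logits, len(boxes))
    assert not check_duplicate(order)
    print(order)  # e.g. [0, 1, 2] if the boxes are already in reading order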