|
|
@@ -94,11 +94,39 @@ def replace_text_span(pymu_spans, ocr_spans):
|
|
|
return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
|
|
|
|
|
|
|
|
|
-def do_predict(boxes: List[List[int]]) -> List[int]:
|
|
|
+def model_init(model_name: str):
|
|
|
from transformers import LayoutLMv3ForTokenClassification
|
|
|
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
+ if model_name == "layoutreader":
|
|
|
+ model = (
|
|
|
+ LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
|
|
|
+ # .bfloat16()
|
|
|
+ .to(device)
|
|
|
+ .eval()
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ logger.error("model name not allow")
|
|
|
+ exit(1)
|
|
|
+ return model
|
|
|
+
|
|
|
+
|
|
|
+class ModelSingleton:
|
|
|
+ _instance = None
|
|
|
+ _models = {}
|
|
|
+
|
|
|
+ def __new__(cls, *args, **kwargs):
|
|
|
+ if cls._instance is None:
|
|
|
+ cls._instance = super().__new__(cls)
|
|
|
+ return cls._instance
|
|
|
+
|
|
|
+ def get_model(self, model_name: str):
|
|
|
+ if model_name not in self._models:
|
|
|
+ self._models[model_name] = model_init(model_name=model_name)
|
|
|
+ return self._models[model_name]
|
|
|
+
|
|
|
+
|
|
|
+def do_predict(boxes: List[List[int]], model) -> List[int]:
|
|
|
from magic_pdf.v3.helpers import prepare_inputs, boxes2inputs, parse_logits
|
|
|
- model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
|
|
|
- # model.to("cuda")
|
|
|
inputs = boxes2inputs(boxes)
|
|
|
inputs = prepare_inputs(inputs, model)
|
|
|
logits = model(**inputs).logits.cpu().squeeze(0)
|
|
|
@@ -184,7 +212,7 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
|
|
|
x_scale = 1000.0 / page_w
|
|
|
y_scale = 1000.0 / page_h
|
|
|
boxes = []
|
|
|
- logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
|
|
|
+ # logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
|
|
|
for left, top, right, bottom in page_line_list:
|
|
|
left = round(left * x_scale)
|
|
|
top = round(top * y_scale)
|
|
|
@@ -194,9 +222,12 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
|
|
|
1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
|
|
|
), f"Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}"
|
|
|
boxes.append([left, top, right, bottom])
|
|
|
+ model_manager = ModelSingleton()
|
|
|
+ model = model_manager.get_model("layoutreader")
|
|
|
layoutreader_start = time.time()
|
|
|
- orders = do_predict(boxes)
|
|
|
- logger.info(f"layoutreader cost time{time.time() - layoutreader_start}")
|
|
|
+ with torch.no_grad():
|
|
|
+ orders = do_predict(boxes, model)
|
|
|
+ # logger.info(f"layoutreader cost time{time.time() - layoutreader_start}")
|
|
|
sorted_bboxes = [page_line_list[i] for i in orders]
|
|
|
|
|
|
'''根据line的中位数算block的序列关系'''
|