Przeglądaj źródła

refactor(pdf_parse_union_core_v2): implement model initialization within class

Refactored model initialization to be handled by a singleton class to ensure that model
instances are reused across calls, avoiding redundant initializations. Removed logger
information that was commented out and ensured consistency in logging behavior.

myhloli 1 rok temu
rodzic
commit
b9dfdea3cb
1 zmienionych plików z 37 dodań i 6 usunięć
  1. 37 6
      magic_pdf/pdf_parse_union_core_v2.py

+ 37 - 6
magic_pdf/pdf_parse_union_core_v2.py

@@ -94,11 +94,39 @@ def replace_text_span(pymu_spans, ocr_spans):
     return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
 
 
-def do_predict(boxes: List[List[int]]) -> List[int]:
+def model_init(model_name: str):
     from transformers import LayoutLMv3ForTokenClassification
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if model_name == "layoutreader":
+        model = (
+            LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
+            # .bfloat16()
+            .to(device)
+            .eval()
+        )
+    else:
+        logger.error("model name not allow")
+        exit(1)
+    return model
+
+
+class ModelSingleton:
+    _instance = None
+    _models = {}
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def get_model(self, model_name: str):
+        if model_name not in self._models:
+            self._models[model_name] = model_init(model_name=model_name)
+        return self._models[model_name]
+
+
+def do_predict(boxes: List[List[int]], model) -> List[int]:
     from magic_pdf.v3.helpers import prepare_inputs, boxes2inputs, parse_logits
-    model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
-    # model.to("cuda")
     inputs = boxes2inputs(boxes)
     inputs = prepare_inputs(inputs, model)
     logits = model(**inputs).logits.cpu().squeeze(0)
@@ -184,7 +212,7 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
     x_scale = 1000.0 / page_w
     y_scale = 1000.0 / page_h
     boxes = []
-    logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
+    # logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
     for left, top, right, bottom in page_line_list:
         left = round(left * x_scale)
         top = round(top * y_scale)
@@ -194,9 +222,12 @@ def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter,
                 1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
         ), f"Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}"
         boxes.append([left, top, right, bottom])
+    model_manager = ModelSingleton()
+    model = model_manager.get_model("layoutreader")
     layoutreader_start = time.time()
-    orders = do_predict(boxes)
-    logger.info(f"layoutreader cost time{time.time() - layoutreader_start}")
+    with torch.no_grad():
+        orders = do_predict(boxes, model)
+    # logger.info(f"layoutreader cost time{time.time() - layoutreader_start}")
     sorted_bboxes = [page_line_list[i] for i in orders]
 
     '''根据line的中位数算block的序列关系'''