|
|
@@ -9,7 +9,7 @@ import numpy as np
|
|
|
from .model_init import AtomModelSingleton
|
|
|
from .model_list import AtomicModel
|
|
|
from ...utils.config_reader import get_formula_enable, get_table_enable
|
|
|
-from ...utils.model_utils import crop_img, get_res_list_from_layout_res
|
|
|
+from ...utils.model_utils import crop_img, get_res_list_from_layout_res, clean_vram
|
|
|
from ...utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
|
|
|
from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence, get_rotate_crop_image
|
|
|
from ...utils.pdf_image_tools import get_crop_np_img
|
|
|
@@ -71,7 +71,7 @@ class BatchAnalyze:
|
|
|
mfr_count += len(images_formula_list[image_index])
|
|
|
|
|
|
# 清理显存
|
|
|
- # clean_vram(self.model.device, vram_threshold=8)
|
|
|
+ clean_vram(self.model.device, vram_threshold=8)
|
|
|
|
|
|
ocr_res_list_all_page = []
|
|
|
table_res_list_all_page = []
|
|
|
@@ -93,18 +93,19 @@ class BatchAnalyze:
|
|
|
})
|
|
|
|
|
|
for table_res in table_res_list:
|
|
|
- # table_img, _ = crop_img(table_res, pil_img)
|
|
|
- # bbox = (241, 208, 1475, 2019)
|
|
|
- scale = 10/3
|
|
|
- # scale = 1
|
|
|
- crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
|
|
|
- crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
|
|
|
- bbox = (int(crop_xmin/scale), int(crop_ymin/scale), int(crop_xmax/scale), int(crop_ymax/scale))
|
|
|
- table_img = get_crop_np_img(bbox, np_img, scale=scale)
|
|
|
+ def get_crop_table_img(scale):
|
|
|
+ crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
|
|
|
+ crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
|
|
|
+ bbox = (int(crop_xmin / scale), int(crop_ymin / scale), int(crop_xmax / scale), int(crop_ymax / scale))
|
|
|
+ return get_crop_np_img(bbox, np_img, scale=scale)
|
|
|
+
|
|
|
+ wireless_table_img = get_crop_table_img(scale = 1)
|
|
|
+ wired_table_img = get_crop_table_img(scale = 10/3)
|
|
|
|
|
|
table_res_list_all_page.append({'table_res':table_res,
|
|
|
'lang':_lang,
|
|
|
- 'table_img':table_img,
|
|
|
+ 'table_img':wireless_table_img,
|
|
|
+ 'wired_table_img':wired_table_img,
|
|
|
})
|
|
|
|
|
|
# 表格识别 table recognition
|
|
|
@@ -137,18 +138,17 @@ class BatchAnalyze:
|
|
|
|
|
|
# OCR det 过程,顺序执行
|
|
|
rec_img_lang_group = defaultdict(list)
|
|
|
+ det_ocr_engine = atom_model_manager.get_atom_model(
|
|
|
+ atom_model_name=AtomicModel.OCR,
|
|
|
+ det_db_box_thresh=0.5,
|
|
|
+ det_db_unclip_ratio=1.6,
|
|
|
+ enable_merge_det_boxes=False,
|
|
|
+ )
|
|
|
for index, table_res_dict in enumerate(
|
|
|
tqdm(table_res_list_all_page, desc="Table-ocr det")
|
|
|
):
|
|
|
- ocr_engine = atom_model_manager.get_atom_model(
|
|
|
- atom_model_name=AtomicModel.OCR,
|
|
|
- det_db_box_thresh=0.5,
|
|
|
- det_db_unclip_ratio=1.6,
|
|
|
- # lang= table_res_dict["lang"],
|
|
|
- enable_merge_det_boxes=False,
|
|
|
- )
|
|
|
bgr_image = cv2.cvtColor(table_res_dict["table_img"], cv2.COLOR_RGB2BGR)
|
|
|
- ocr_result = ocr_engine.ocr(bgr_image, rec=False)[0]
|
|
|
+ ocr_result = det_ocr_engine.ocr(bgr_image, rec=False)[0]
|
|
|
# 构造需要 OCR 识别的图片字典,包括cropped_img, dt_box, table_id,并按照语言进行分组
|
|
|
for dt_box in ocr_result:
|
|
|
rec_img_lang_group[_lang].append(
|
|
|
@@ -171,8 +171,7 @@ class BatchAnalyze:
|
|
|
enable_merge_det_boxes=False,
|
|
|
)
|
|
|
cropped_img_list = [item["cropped_img"] for item in rec_img_list]
|
|
|
- ocr_res_list = \
|
|
|
- ocr_engine.ocr(cropped_img_list, det=False, tqdm_enable=True, tqdm_desc="Table-ocr rec")[0]
|
|
|
+ ocr_res_list = ocr_engine.ocr(cropped_img_list, det=False, tqdm_enable=True, tqdm_desc=f"Table-ocr rec {_lang}")[0]
|
|
|
# 按照 table_id 将识别结果进行回填
|
|
|
for img_dict, ocr_res in zip(rec_img_list, ocr_res_list):
|
|
|
if table_res_list_all_page[img_dict["table_id"]].get("ocr_result"):
|
|
|
@@ -184,6 +183,8 @@ class BatchAnalyze:
|
|
|
[img_dict["dt_box"], html.escape(ocr_res[0]), ocr_res[1]]
|
|
|
]
|
|
|
|
|
|
+ clean_vram(self.model.device, vram_threshold=8)
|
|
|
+
|
|
|
# 先对所有表格使用无线表格模型,然后对分类为有线的表格使用有线表格模型
|
|
|
wireless_table_model = atom_model_manager.get_atom_model(
|
|
|
atom_model_name=AtomicModel.WirelessTable,
|
|
|
@@ -193,18 +194,27 @@ class BatchAnalyze:
|
|
|
# 单独拿出有线表格进行预测
|
|
|
wired_table_res_list = []
|
|
|
for table_res_dict in table_res_list_all_page:
|
|
|
- if table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable:
|
|
|
+ # logger.debug(f"Table classification result: {table_res_dict["table_res"]["cls_label"]} with confidence {table_res_dict["table_res"]["cls_score"]}")
|
|
|
+ if (
|
|
|
+ (table_res_dict["table_res"]["cls_label"] == AtomicModel.WirelessTable and table_res_dict["table_res"]["cls_score"] < 0.9)
|
|
|
+ or table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable
|
|
|
+ ):
|
|
|
wired_table_res_list.append(table_res_dict)
|
|
|
+ del table_res_dict["table_res"]["cls_label"]
|
|
|
+ del table_res_dict["table_res"]["cls_score"]
|
|
|
if wired_table_res_list:
|
|
|
for table_res_dict in tqdm(
|
|
|
wired_table_res_list, desc="Table-wired Predict"
|
|
|
):
|
|
|
+ if not table_res_dict.get("ocr_result", None):
|
|
|
+ continue
|
|
|
+
|
|
|
wired_table_model = atom_model_manager.get_atom_model(
|
|
|
atom_model_name=AtomicModel.WiredTable,
|
|
|
lang=table_res_dict["lang"],
|
|
|
)
|
|
|
table_res_dict["table_res"]["html"] = wired_table_model.predict(
|
|
|
- table_res_dict["table_img"],
|
|
|
+ table_res_dict["wired_table_img"],
|
|
|
table_res_dict["ocr_result"],
|
|
|
table_res_dict["table_res"].get("html", None)
|
|
|
)
|
|
|
@@ -428,7 +438,7 @@ class BatchAnalyze:
|
|
|
layout_res_item['poly'][4], layout_res_item['poly'][5]]
|
|
|
layout_res_width = layout_res_bbox[2] - layout_res_bbox[0]
|
|
|
layout_res_height = layout_res_bbox[3] - layout_res_bbox[1]
|
|
|
- if ocr_text in ['(204号', '(20', '(2', '(2号', '(20号'] and ocr_score < 0.8 and layout_res_width < layout_res_height:
|
|
|
+ if ocr_text in ['(204号', '(20', '(2', '(2号', '(20号', '号', '(204'] and ocr_score < 0.8 and layout_res_width < layout_res_height:
|
|
|
layout_res_item['category_id'] = 16
|
|
|
|
|
|
total_processed += len(img_crop_list)
|