|
|
@@ -93,18 +93,19 @@ class BatchAnalyze:
|
|
|
})
|
|
|
|
|
|
for table_res in table_res_list:
|
|
|
- # table_img, _ = crop_img(table_res, pil_img)
|
|
|
- # bbox = (241, 208, 1475, 2019)
|
|
|
- scale = 10/3
|
|
|
- # scale = 1
|
|
|
- crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
|
|
|
- crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
|
|
|
- bbox = (int(crop_xmin/scale), int(crop_ymin/scale), int(crop_xmax/scale), int(crop_ymax/scale))
|
|
|
- table_img = get_crop_np_img(bbox, np_img, scale=scale)
|
|
|
+ def get_crop_table_img(scale):
|
|
|
+ crop_xmin, crop_ymin = int(table_res['poly'][0]), int(table_res['poly'][1])
|
|
|
+ crop_xmax, crop_ymax = int(table_res['poly'][4]), int(table_res['poly'][5])
|
|
|
+ bbox = (int(crop_xmin / scale), int(crop_ymin / scale), int(crop_xmax / scale), int(crop_ymax / scale))
|
|
|
+ return get_crop_np_img(bbox, np_img, scale=scale)
|
|
|
+
|
|
|
+ wireless_table_img = get_crop_table_img(scale = 1)
|
|
|
+ wired_table_img = get_crop_table_img(scale = 10/3)
|
|
|
|
|
|
table_res_list_all_page.append({'table_res':table_res,
|
|
|
'lang':_lang,
|
|
|
- 'table_img':table_img,
|
|
|
+ 'table_img':wireless_table_img,
|
|
|
+ 'wired_table_img':wired_table_img,
|
|
|
})
|
|
|
|
|
|
# 表格识别 table recognition
|
|
|
@@ -193,8 +194,14 @@ class BatchAnalyze:
|
|
|
# 单独拿出有线表格进行预测
|
|
|
wired_table_res_list = []
|
|
|
for table_res_dict in table_res_list_all_page:
|
|
|
- if table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable:
|
|
|
+ # logger.debug(f"Table classification result: {table_res_dict["table_res"]["cls_label"]} with confidence {table_res_dict["table_res"]["cls_score"]}")
|
|
|
+ if (
|
|
|
+ (table_res_dict["table_res"]["cls_label"] == AtomicModel.WirelessTable and table_res_dict["table_res"]["cls_score"] < 0.9)
|
|
|
+ or table_res_dict["table_res"]["cls_label"] == AtomicModel.WiredTable
|
|
|
+ ):
|
|
|
wired_table_res_list.append(table_res_dict)
|
|
|
+ del table_res_dict["table_res"]["cls_label"]
|
|
|
+ del table_res_dict["table_res"]["cls_score"]
|
|
|
if wired_table_res_list:
|
|
|
for table_res_dict in tqdm(
|
|
|
wired_table_res_list, desc="Table-wired Predict"
|
|
|
@@ -207,7 +214,7 @@ class BatchAnalyze:
|
|
|
lang=table_res_dict["lang"],
|
|
|
)
|
|
|
table_res_dict["table_res"]["html"] = wired_table_model.predict(
|
|
|
- table_res_dict["table_img"],
|
|
|
+ table_res_dict["wired_table_img"],
|
|
|
table_res_dict["ocr_result"],
|
|
|
table_res_dict["table_res"].get("html", None)
|
|
|
)
|