|
|
@@ -8,7 +8,7 @@ from magic_pdf.config.constants import MODEL_NAME
|
|
|
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
|
|
|
from magic_pdf.model.sub_modules.model_utils import (
|
|
|
clean_vram, crop_img, get_res_list_from_layout_res)
|
|
|
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
|
|
|
+from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
|
|
|
get_adjusted_mfdetrec_res, get_ocr_result_list)
|
|
|
|
|
|
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
|
|
|
@@ -85,8 +85,8 @@ class BatchAnalyze:
|
|
|
# 清理显存
|
|
|
clean_vram(self.model.device, vram_threshold=8)
|
|
|
|
|
|
- ocr_time = 0
|
|
|
- ocr_count = 0
|
|
|
+ det_time = 0
|
|
|
+ det_count = 0
|
|
|
table_time = 0
|
|
|
table_count = 0
|
|
|
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
|
|
|
@@ -100,7 +100,7 @@ class BatchAnalyze:
|
|
|
get_res_list_from_layout_res(layout_res)
|
|
|
)
|
|
|
# ocr识别
|
|
|
- ocr_start = time.time()
|
|
|
+ det_start = time.time()
|
|
|
# Process each area that requires OCR processing
|
|
|
for res in ocr_res_list:
|
|
|
new_image, useful_list = crop_img(
|
|
|
@@ -113,21 +113,21 @@ class BatchAnalyze:
|
|
|
# OCR recognition
|
|
|
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
|
|
|
|
|
|
- if ocr_enable:
|
|
|
- ocr_res = self.model.ocr_model.ocr(
|
|
|
- new_image, mfd_res=adjusted_mfdetrec_res
|
|
|
- )[0]
|
|
|
- else:
|
|
|
- ocr_res = self.model.ocr_model.ocr(
|
|
|
- new_image, mfd_res=adjusted_mfdetrec_res, rec=False
|
|
|
- )[0]
|
|
|
+ # if ocr_enable:
|
|
|
+ # ocr_res = self.model.ocr_model.ocr(
|
|
|
+ # new_image, mfd_res=adjusted_mfdetrec_res
|
|
|
+ # )[0]
|
|
|
+ # else:
|
|
|
+ ocr_res = self.model.ocr_model.ocr(
|
|
|
+ new_image, mfd_res=adjusted_mfdetrec_res, rec=False
|
|
|
+ )[0]
|
|
|
|
|
|
# Integration results
|
|
|
if ocr_res:
|
|
|
- ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
|
|
|
+ ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image)
|
|
|
layout_res.extend(ocr_result_list)
|
|
|
- ocr_time += time.time() - ocr_start
|
|
|
- ocr_count += len(ocr_res_list)
|
|
|
+ det_time += time.time() - det_start
|
|
|
+ det_count += len(ocr_res_list)
|
|
|
|
|
|
# 表格识别 table recognition
|
|
|
if self.model.apply_table:
|
|
|
@@ -172,9 +172,33 @@ class BatchAnalyze:
|
|
|
table_time += time.time() - table_start
|
|
|
table_count += len(table_res_list)
|
|
|
|
|
|
- if self.model.apply_ocr:
|
|
|
- logger.info(f'det or det time costs: {round(ocr_time, 2)}, image num: {ocr_count}')
|
|
|
+
|
|
|
+ logger.info(f'ocr-det time: {round(det_time, 2)}, image num: {det_count}')
|
|
|
if self.model.apply_table:
|
|
|
logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
|
|
|
|
|
|
+ need_ocr_list = []
|
|
|
+ img_crop_list = []
|
|
|
+ for layout_res in images_layout_res:
|
|
|
+ for layout_res_item in layout_res:
|
|
|
+ if layout_res_item['category_id'] in [15]:
|
|
|
+ if 'np_img' in layout_res_item:
|
|
|
+ need_ocr_list.append(layout_res_item)
|
|
|
+ img_crop_list.append(layout_res_item['np_img'])
|
|
|
+ layout_res_item.pop('np_img')
|
|
|
+
|
|
|
+ rec_time = 0
|
|
|
+ rec_start = time.time()
|
|
|
+ if len(img_crop_list) > 0:
|
|
|
+ ocr_res_list = self.model.ocr_model.ocr(img_crop_list, det=False)[0]
|
|
|
+ assert len(ocr_res_list)==len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
|
|
|
+ for index, layout_res_item in enumerate(need_ocr_list):
|
|
|
+ ocr_text, ocr_score = ocr_res_list[index]
|
|
|
+ layout_res_item['text'] = ocr_text
|
|
|
+ layout_res_item['score'] = float(round(ocr_score, 2))
|
|
|
+ rec_time += time.time() - rec_start
|
|
|
+ logger.info(f'ocr-rec time: {round(rec_time, 2)}, image num: {len(img_crop_list)}')
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
return images_layout_res
|