@@ -267,13 +267,13 @@ class _TableRecognitionPipelineV2(BasePipeline):
         if input_params["use_doc_preprocessor"]:
             use_doc_orientation_classify = input_params["use_doc_orientation_classify"]
             use_doc_unwarping = input_params["use_doc_unwarping"]
-            doc_preprocessor_res = next(
+            doc_preprocessor_res = list(
                 self.doc_preprocessor_pipeline(
                     image_array,
                     use_doc_orientation_classify=use_doc_orientation_classify,
                     use_doc_unwarping=use_doc_unwarping,
                 )
-            )
+            )[0]
             doc_preprocessor_image = doc_preprocessor_res["output_img"]
         else:
             doc_preprocessor_res = {}
@@ -686,11 +686,11 @@ class _TableRecognitionPipelineV2(BasePipeline):
         for box in split_boxes:
             x1, y1, x2, y2 = int(box[0]), int(box[1]), int(box[2]), int(box[3])
             if y2 - y1 > 1 and x2 - x1 > 1:
-                ocr_result = next(
+                ocr_result = list(
                     self.general_ocr_pipeline.text_rec_model(
                         ori_img[y1:y2, x1:x2, :]
                     )
-                )
+                )[0]
                 # Extract the recognized text from the OCR result
                 if "rec_text" in ocr_result:
                     result = ocr_result[
@@ -738,7 +738,7 @@ class _TableRecognitionPipelineV2(BasePipeline):
             x1, y1, x2, y2 = [math.ceil(k) for k in cells_bboxes[i]]
             # Perform OCR on the defined region of the image and get the recognized text.
             if y2 - y1 > 1 and x2 - x1 > 1:
-                rec_te = next(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))
+                rec_te = list(self.general_ocr_pipeline(ori_img[y1:y2, x1:x2, :]))[0]
                 # Concatenate the texts and append them to the texts_list.
                 texts_list.append("".join(rec_te["rec_texts"]))
         # Return the list of recognized texts from each cell.
@@ -979,7 +979,7 @@ class _TableRecognitionPipelineV2(BasePipeline):
             SingleTableRecognitionResult: single table recognition result.
         """

-        table_cls_pred = next(self.table_cls_model(image_array))
+        table_cls_pred = list(self.table_cls_model(image_array))[0]
         table_cls_result = self.extract_results(table_cls_pred, "cls")
         use_e2e_model = False
         cells_trans_to_html = False
@@ -988,33 +988,41 @@ class _TableRecognitionPipelineV2(BasePipeline):
             if use_wired_table_cells_trans_to_html == True:
                 cells_trans_to_html = True
             else:
-                table_structure_pred = next(self.wired_table_rec_model(image_array))
+                table_structure_pred = list(self.wired_table_rec_model(image_array))[0]
             if use_e2e_wired_table_rec_model == True:
                 use_e2e_model = True
                 if cells_trans_to_html == True:
-                    table_structure_pred = next(self.wired_table_rec_model(image_array))
+                    table_structure_pred = list(
+                        self.wired_table_rec_model(image_array)
+                    )[0]
             else:
-                table_cells_pred = next(
+                table_cells_pred = list(
                     self.wired_table_cells_detection_model(image_array, threshold=0.3)
-                )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
+                )[
+                    0
+                ]  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
                 # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.
         elif table_cls_result == "wireless_table":
             if use_wireless_table_cells_trans_to_html == True:
                 cells_trans_to_html = True
             else:
-                table_structure_pred = next(self.wireless_table_rec_model(image_array))
+                table_structure_pred = list(self.wireless_table_rec_model(image_array))[
+                    0
+                ]
             if use_e2e_wireless_table_rec_model == True:
                 use_e2e_model = True
                 if cells_trans_to_html == True:
-                    table_structure_pred = next(
+                    table_structure_pred = list(
                         self.wireless_table_rec_model(image_array)
-                    )
+                    )[0]
             else:
-                table_cells_pred = next(
+                table_cells_pred = list(
                     self.wireless_table_cells_detection_model(
                         image_array, threshold=0.3
                     )
-                )  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
+                )[
+                    0
+                ]  # Setting the threshold to 0.3 can improve the accuracy of table cells detection.
                 # If you really want more or fewer table cells detection boxes, the threshold can be adjusted.

         if use_e2e_model == False:
@@ -1172,20 +1180,20 @@ class _TableRecognitionPipelineV2(BasePipeline):
         image_array = self.img_reader(batch_data.instances)[0]

         if model_settings["use_doc_preprocessor"]:
-            doc_preprocessor_res = next(
+            doc_preprocessor_res = list(
                 self.doc_preprocessor_pipeline(
                     image_array,
                     use_doc_orientation_classify=use_doc_orientation_classify,
                     use_doc_unwarping=use_doc_unwarping,
                 )
-            )
+            )[0]
         else:
             doc_preprocessor_res = {"output_img": image_array}

         doc_preprocessor_image = doc_preprocessor_res["output_img"]

         if model_settings["use_ocr_model"]:
-            overall_ocr_res = next(
+            overall_ocr_res = list(
                 self.general_ocr_pipeline(
                     doc_preprocessor_image,
                     text_det_limit_side_len=text_det_limit_side_len,
@@ -1195,7 +1203,7 @@ class _TableRecognitionPipelineV2(BasePipeline):
                     text_det_unclip_ratio=text_det_unclip_ratio,
                     text_rec_score_thresh=text_rec_score_thresh,
                 )
-            )
+            )[0]
         elif self.general_ocr_pipeline is None and (
             (
                 use_ocr_results_with_table_cells == True
@@ -1218,9 +1226,9 @@ class _TableRecognitionPipelineV2(BasePipeline):
             img_height, img_width = doc_preprocessor_image.shape[:2]
             table_box = [0, 0, img_width - 1, img_height - 1]
             if use_table_orientation_classify == True:
-                table_angle = next(
+                table_angle = list(
                     self.table_orientation_classify_model(doc_preprocessor_image)
-                )["label_names"][0]
+                )[0]["label_names"][0]
                 if table_angle == "90":
                     doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=1)
                 elif table_angle == "180":
@@ -1228,7 +1236,7 @@ class _TableRecognitionPipelineV2(BasePipeline):
                 elif table_angle == "270":
                     doc_preprocessor_image = np.rot90(doc_preprocessor_image, k=3)
                 if table_angle in ["90", "180", "270"]:
-                    overall_ocr_res = next(
+                    overall_ocr_res = list(
                         self.general_ocr_pipeline(
                             doc_preprocessor_image,
                             text_det_limit_side_len=text_det_limit_side_len,
@@ -1238,7 +1246,7 @@ class _TableRecognitionPipelineV2(BasePipeline):
                             text_det_unclip_ratio=text_det_unclip_ratio,
                             text_rec_score_thresh=text_rec_score_thresh,
                         )
-                    )
+                    )[0]
                     tbx1, tby1, tbx2, tby2 = (
                         table_box[0],
                         table_box[1],
@@ -1282,7 +1290,9 @@ class _TableRecognitionPipelineV2(BasePipeline):
             table_region_id += 1
         else:
             if model_settings["use_layout_detection"]:
-                layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
+                layout_det_res = list(
+                    self.layout_det_model(doc_preprocessor_image)
+                )[0]
             img_height, img_width = doc_preprocessor_image.shape[:2]
             for box_info in layout_det_res["boxes"]:
                 if box_info["label"].lower() in ["table"]:
@@ -1293,11 +1303,11 @@ class _TableRecognitionPipelineV2(BasePipeline):
                     table_box = crop_img_info["box"]
                     if use_table_orientation_classify == True:
                         doc_preprocessor_image_copy = doc_preprocessor_image.copy()
-                        table_angle = next(
+                        table_angle = list(
                             self.table_orientation_classify_model(
                                 crop_img_info["img"]
                             )
-                        )["label_names"][0]
+                        )[0]["label_names"][0]
                         if table_angle == "90":
                             crop_img_info["img"] = np.rot90(crop_img_info["img"], k=1)
                             doc_preprocessor_image_copy = np.rot90(
@@ -1314,7 +1324,7 @@ class _TableRecognitionPipelineV2(BasePipeline):
                                 doc_preprocessor_image_copy, k=3
                             )
                         if table_angle in ["90", "180", "270"]:
-                            overall_ocr_res = next(
+                            overall_ocr_res = list(
                                 self.general_ocr_pipeline(
                                     doc_preprocessor_image_copy,
                                     text_det_limit_side_len=text_det_limit_side_len,
@@ -1324,7 +1334,7 @@ class _TableRecognitionPipelineV2(BasePipeline):
                                     text_det_unclip_ratio=text_det_unclip_ratio,
                                     text_rec_score_thresh=text_rec_score_thresh,
                                 )
-                            )
+                            )[0]
                             tbx1, tby1, tbx2, tby2 = (
                                 table_box[0],
                                 table_box[1],
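
Note: every hunk in this patch makes the same mechanical substitution, next(gen) -> list(gen)[0]. The patch itself does not state the motivation, but the semantic difference between the two forms is worth seeing in isolation: next() pulls one item and leaves the generator suspended, while list() drains it to completion, so any finally-based cleanup inside the generator runs before the first result is used. A minimal standalone sketch (the predict generator below is a hypothetical stand-in, not part of the pipeline):

    def predict(image):
        # Hypothetical stand-in for a generator-based predictor.
        try:
            yield {"rec_text": f"text-from-{image}"}
        finally:
            # Fires only once the generator is exhausted or closed.
            print("cleanup ran")

    gen = predict("a.png")
    first = next(gen)                 # first result; generator suspended, cleanup pending
    results = list(predict("b.png"))  # fully drained: "cleanup ran" prints here
    first_drained = results[0]        # same value as next() would give, but after cleanup

One behavioral caveat of the new form: if the generator yields nothing, list(gen)[0] raises IndexError where next(gen) would raise StopIteration.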