|
@@ -10,12 +10,17 @@ from paddleocr import PaddleOCR
|
|
|
from paddleocr.ppocr.utils.logging import get_logger
|
|
from paddleocr.ppocr.utils.logging import get_logger
|
|
|
from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img
|
|
from paddleocr.ppocr.utils.utility import check_and_read, alpha_to_color, binarize_img
|
|
|
from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
|
|
from paddleocr.tools.infer.utility import draw_ocr_box_txt, get_rotate_crop_image, get_minarea_rect_crop
|
|
|
|
|
+
|
|
|
|
|
+from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
|
|
|
|
|
+
|
|
|
# Module-level logger obtained from PaddleOCR's logging helper (get_logger).
logger = get_logger()
|
|
logger = get_logger()
|
|
|
|
|
|
|
|
|
|
+
|
|
|
def img_decode(content: bytes):
    """Decode an encoded image held in ``content`` into an OpenCV image.

    The raw bytes are viewed as a flat ``uint8`` buffer and handed to
    ``cv2.imdecode`` with the ``cv2.IMREAD_UNCHANGED`` flag; whatever
    ``cv2.imdecode`` returns is passed straight back to the caller.
    """
    raw_buffer = np.frombuffer(content, dtype=np.uint8)
    decoded = cv2.imdecode(raw_buffer, cv2.IMREAD_UNCHANGED)
    return decoded
|
|
|
|
|
|
|
|
|
|
+
|
|
|
def check_img(img):
|
|
def check_img(img):
|
|
|
if isinstance(img, bytes):
|
|
if isinstance(img, bytes):
|
|
|
img = img_decode(img)
|
|
img = img_decode(img)
|
|
@@ -51,6 +56,7 @@ def check_img(img):
|
|
|
|
|
|
|
|
return img
|
|
return img
|
|
|
|
|
|
|
|
|
|
+
|
|
|
def sorted_boxes(dt_boxes):
|
|
def sorted_boxes(dt_boxes):
|
|
|
"""
|
|
"""
|
|
|
Sort text boxes in order from top to bottom, left to right
|
|
Sort text boxes in order from top to bottom, left to right
|
|
@@ -75,49 +81,87 @@ def sorted_boxes(dt_boxes):
|
|
|
return _boxes
|
|
return _boxes
|
|
|
|
|
|
|
|
|
|
|
|
|
-def formula_in_text(mf_bbox, text_bbox):
|
|
|
|
|
- x1, y1, x2, y2 = mf_bbox
|
|
|
|
|
- x3, y3 = text_bbox[0]
|
|
|
|
|
- x4, y4 = text_bbox[2]
|
|
|
|
|
- left_box, right_box = None, None
|
|
|
|
|
- same_line = abs((y1+y2)/2 - (y3+y4)/2) / abs(y4-y3) < 0.2
|
|
|
|
|
- if not same_line:
|
|
|
|
|
- return False, left_box, right_box
|
|
|
|
|
- else:
|
|
|
|
|
- drop_origin = False
|
|
|
|
|
- left_x = x1 - 1
|
|
|
|
|
- right_x = x2 + 1
|
|
|
|
|
- if x3 < x1 and x2 < x4:
|
|
|
|
|
- drop_origin = True
|
|
|
|
|
- left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
|
|
|
|
|
- right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
|
|
|
|
|
- if x3 < x1 and x1 <= x4 <= x2:
|
|
|
|
|
- drop_origin = True
|
|
|
|
|
- left_box = np.array([text_bbox[0], [left_x, text_bbox[1][1]], [left_x, text_bbox[2][1]], text_bbox[3]]).astype('float32')
|
|
|
|
|
- if x1 <= x3 <= x2 and x2 < x4:
|
|
|
|
|
- drop_origin = True
|
|
|
|
|
- right_box = np.array([[right_x, text_bbox[0][1]], text_bbox[1], text_bbox[2], [right_x, text_bbox[3][1]]]).astype('float32')
|
|
|
|
|
- if x1 <= x3 < x4 <= x2:
|
|
|
|
|
- drop_origin = True
|
|
|
|
|
- return drop_origin, left_box, right_box
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def update_det_boxes(dt_boxes, mfdetrec_res):
|
|
|
|
|
- new_dt_boxes = dt_boxes
|
|
|
|
|
- for mf_box in mfdetrec_res:
|
|
|
|
|
- flag, left_box, right_box = False, None, None
|
|
|
|
|
- for idx, text_box in enumerate(new_dt_boxes):
|
|
|
|
|
- ret, left_box, right_box = formula_in_text(mf_box['bbox'], text_box)
|
|
|
|
|
- if ret:
|
|
|
|
|
- new_dt_boxes.pop(idx)
|
|
|
|
|
- if left_box is not None:
|
|
|
|
|
- new_dt_boxes.append(left_box)
|
|
|
|
|
- if right_box is not None:
|
|
|
|
|
- new_dt_boxes.append(right_box)
|
|
|
|
|
- break
|
|
|
|
|
-
|
|
|
|
|
|
|
def bbox_to_points(bbox):
    """Convert an ``[x0, y0, x1, y1]`` bbox into an array of four corner points.

    Returns:
        A ``float32`` array of shape (4, 2) whose rows are, in order,
        ``(x0, y0)``, ``(x1, y0)``, ``(x1, y1)`` and ``(x0, y1)``.
    """
    x_start, y_start, x_end, y_end = bbox
    corners = [
        [x_start, y_start],
        [x_end, y_start],
        [x_end, y_end],
        [x_start, y_end],
    ]
    return np.array(corners).astype('float32')
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def points_to_bbox(points):
    """Convert an array of four corner points into ``[x0, y0, x1, y1]``.

    ``x0`` and ``y0`` come from the first point, ``x1`` from the second
    point and ``y1`` from the third point; the fourth point is not read.
    """
    (x0, y0), (x1, _), (_, y1) = points[0], points[1], points[2]
    return [x0, y0, x1, y1]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def merge_intervals(intervals):
    """Merge overlapping ``[start, end]`` intervals.

    Args:
        intervals: Iterable of sequences whose first two items are the
            interval's start and end.

    Returns:
        A new list of ``[start, end]`` lists ordered by start. An interval
        is folded into its predecessor when its start is not greater than
        the predecessor's end. Neither ``intervals`` nor the interval
        objects it contains are modified.
    """
    merged = []
    # sorted() instead of list.sort(): the caller's list keeps its order.
    for interval in sorted(intervals, key=lambda item: item[0]):
        start, end = interval[0], interval[1]
        if merged and merged[-1][1] >= start:
            # Overlaps the previous interval, so stretch that one.
            merged[-1][1] = max(merged[-1][1], end)
        else:
            # Store a fresh list so a later stretch never writes into an
            # interval object owned by the caller.
            merged.append([start, end])

    return merged


def remove_intervals(original, masks):
    """Cut the ``masks`` intervals out of the ``original`` range.

    A margin of 1 is kept on each side of a mask: the piece in front of a
    mask ends at ``mask_start - 1`` and the piece after it starts at
    ``mask_end + 1``.

    Args:
        original: Two-item ``[start, end]`` range to cut from.
        masks: Iterable of ``[start, end]`` intervals to remove; they may
            overlap each other and are not modified.

    Returns:
        List of ``[start, end]`` pieces of ``original`` left after the
        masks are removed, in ascending order. A piece that would end
        before it starts (a gap narrower than the margin) is omitted.
    """
    result = []
    original_start, original_end = original

    for mask_start, mask_end in merge_intervals(masks):
        # Merged masks are ordered by start, so once one begins past the end
        # of the range none of the following ones can intersect it either.
        if mask_start > original_end:
            break

        # Mask lies entirely before what is left of the range.
        if mask_end < original_start:
            continue

        # Keep the part in front of the mask, but only when it is a valid
        # (non-inverted) interval after applying the 1-unit margin.
        piece_end = mask_start - 1
        if original_start <= piece_end:
            result.append([original_start, piece_end])

        original_start = max(mask_end + 1, original_start)

    # Whatever remains after the last mask.
    if original_start <= original_end:
        result.append([original_start, original_end])

    return result
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def update_det_boxes(dt_boxes, mfd_res):
    """Cut formula regions out of the detected text boxes.

    For each text box, the x-ranges of every formula box that
    ``__is_overlaps_y_exceeds_threshold`` accepts for it are removed from the
    text box's own x-range with ``remove_intervals``. Each x-segment that
    survives becomes a new four-point box with the text box's original y
    values; a text box with no surviving segment contributes nothing.

    Args:
        dt_boxes: Iterable of four-point text boxes readable by
            ``points_to_bbox``.
        mfd_res: Iterable of dicts with a ``'bbox'`` entry; items 0 and 2 of
            that entry are used as the formula's x-range (presumably
            ``[x0, y0, x1, y1]`` -- confirm against the caller).

    Returns:
        List of four-point arrays as produced by ``bbox_to_points``.
    """
    split_boxes = []
    for quad in dt_boxes:
        text_bbox = points_to_bbox(quad)
        formula_x_spans = [
            [formula['bbox'][0], formula['bbox'][2]]
            for formula in mfd_res
            if __is_overlaps_y_exceeds_threshold(text_bbox, formula['bbox'])
        ]
        kept_spans = remove_intervals([text_bbox[0], text_bbox[2]], formula_x_spans)
        split_boxes.extend([
            bbox_to_points([span[0], text_bbox[1], span[1], text_bbox[3]])
            for span in kept_spans
        ])
    return split_boxes
|
|
|
|
|
|
|
|
|
|
+
|
|
|
class ModifiedPaddleOCR(PaddleOCR):
|
|
class ModifiedPaddleOCR(PaddleOCR):
|
|
|
def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
|
|
def ocr(self, img, det=True, rec=True, cls=True, bin=False, inv=False, mfd_res=None, alpha_color=(255, 255, 255)):
|
|
|
"""
|
|
"""
|
|
@@ -197,7 +241,7 @@ class ModifiedPaddleOCR(PaddleOCR):
|
|
|
if not rec:
|
|
if not rec:
|
|
|
return cls_res
|
|
return cls_res
|
|
|
return ocr_res
|
|
return ocr_res
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
def __call__(self, img, cls=True, mfd_res=None):
|
|
def __call__(self, img, cls=True, mfd_res=None):
|
|
|
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
|
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
|
|
|
|
|
|
@@ -226,7 +270,7 @@ class ModifiedPaddleOCR(PaddleOCR):
|
|
|
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
|
|
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
|
|
|
aft = time.time()
|
|
aft = time.time()
|
|
|
logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
|
|
logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
|
|
|
- len(dt_boxes), aft-bef))
|
|
|
|
|
|
|
+ len(dt_boxes), aft - bef))
|
|
|
|
|
|
|
|
for bno in range(len(dt_boxes)):
|
|
for bno in range(len(dt_boxes)):
|
|
|
tmp_box = copy.deepcopy(dt_boxes[bno])
|
|
tmp_box = copy.deepcopy(dt_boxes[bno])
|
|
@@ -257,4 +301,60 @@ class ModifiedPaddleOCR(PaddleOCR):
|
|
|
filter_rec_res.append(rec_result)
|
|
filter_rec_res.append(rec_result)
|
|
|
end = time.time()
|
|
end = time.time()
|
|
|
time_dict['all'] = end - start
|
|
time_dict['all'] = end - start
|
|
|
- return filter_boxes, filter_rec_res, time_dict
|
|
|
|
|
|
|
+ return filter_boxes, filter_rec_res, time_dict
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke check for the module-level remove_intervals(). The helpers
    # are not re-defined here: local copies would shadow the real
    # implementations and could silently drift away from them.
    original_range = [1, 100]
    masks = [[0, 15], [25, 40], [55, 80]]
    result = remove_intervals(original_range, masks)
    print(result)  # Expected output: [[16, 24], [41, 54], [81, 100]]
|