@@ -35,26 +35,67 @@ class RapidTableModel(object):
         # from rapidocr_onnxruntime import RapidOCR
         # self.ocr_engine = RapidOCR()

-        self.ocr_model_name = "PaddleOCR"
+        # self.ocr_model_name = "PaddleOCR"
         self.ocr_engine = ocr_engine


     def predict(self, image):
+        bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)

-        if self.ocr_model_name == "RapidOCR":
-            ocr_result, _ = self.ocr_engine(np.asarray(image))
-        elif self.ocr_model_name == "PaddleOCR":
-            bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
-            ocr_result = self.ocr_engine.ocr(bgr_image)[0]
-            if ocr_result:
-                ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
-                              len(item) == 2 and isinstance(item[1], tuple)]
-            else:
-                ocr_result = None
+        # First check the overall image aspect ratio (height/width)
+        img_height, img_width = bgr_image.shape[:2]
+        img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
+        img_is_portrait = img_aspect_ratio > 1.2
+
+        if img_is_portrait:
+
+            det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
+            # Check if table is rotated by analyzing text box aspect ratios
+            is_rotated = False
+            if det_res:
+                aspect_ratios = []
+                vertical_count = 0
+
+                for box_ocr_res in det_res:
+                    p1, p2, p3, p4 = box_ocr_res
+
+                    # Calculate width and height
+                    width = max(np.linalg.norm(np.array(p1) - np.array(p2)),
+                                np.linalg.norm(np.array(p3) - np.array(p4)))
+                    height = max(np.linalg.norm(np.array(p1) - np.array(p4)),
+                                 np.linalg.norm(np.array(p2) - np.array(p3)))
+
+                    aspect_ratio = width / height if height > 0 else 1.0
+                    aspect_ratios.append(aspect_ratio)
+
+                    # Count vertical vs horizontal text boxes
+                    if aspect_ratio < 0.8:  # Taller than wide - vertical text
+                        vertical_count += 1
+                    # elif aspect_ratio > 1.2:  # Wider than tall - horizontal text
+                    #     horizontal_count += 1
+
+                # If a significant share (>= 30%) of the detected text boxes
+                # are taller than they are wide, the table is likely rotated
+                if vertical_count >= len(det_res) * 0.3:
+                    is_rotated = True
+
+                # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
+
+            # Rotate image if necessary
+            if is_rotated:
+                # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise")
+                image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE)
+                bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+
+        # Continue with OCR on potentially rotated image
+        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+        if ocr_result:
+            ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
+                          len(item) == 2 and isinstance(item[1], tuple)]
         else:
-            logger.error("OCR model not supported")
             ocr_result = None

+
         if ocr_result:
             table_results = self.table_model(np.asarray(image), ocr_result)
             html_code = table_results.pred_html
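
For reference, a minimal standalone sketch of the rotation heuristic this patch adds, factored out of predict() so it can be exercised without an OCR engine. The function name looks_rotated, its parameter names, and the synthetic boxes are illustrative only; the patch applies the same 0.8 aspect-ratio cutoff and 30% vertical share inline.

import numpy as np

def looks_rotated(det_boxes, vertical_cutoff=0.8, min_share=0.3):
    """Guess whether a table image is rotated 90 degrees.

    det_boxes: 4-point text boxes [[x, y], ...] as returned by a
    PaddleOCR-style detector (rec=False). A box counts as vertical
    when its width/height ratio falls below vertical_cutoff.
    """
    if not det_boxes:
        return False
    vertical_count = 0
    for p1, p2, p3, p4 in det_boxes:
        width = max(np.linalg.norm(np.array(p1) - np.array(p2)),
                    np.linalg.norm(np.array(p3) - np.array(p4)))
        height = max(np.linalg.norm(np.array(p1) - np.array(p4)),
                     np.linalg.norm(np.array(p2) - np.array(p3)))
        aspect_ratio = width / height if height > 0 else 1.0
        if aspect_ratio < vertical_cutoff:  # taller than wide -> vertical text
            vertical_count += 1
    return vertical_count >= len(det_boxes) * min_share

# Synthetic example: one wide box (horizontal text) and one tall box (vertical text).
boxes = [
    [[0, 0], [100, 0], [100, 20], [0, 20]],  # aspect ratio ~5.0
    [[0, 0], [20, 0], [20, 100], [0, 100]],  # aspect ratio ~0.2
]
print(looks_rotated(boxes))  # True: 1 of 2 boxes (50%) is vertical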