@@ -174,13 +174,14 @@ class PaddleOrientationClsModel:
         return x

     def batch_predict(
-        self, imgs: List[Dict], atom_model_manager, ocr_model_name: str, batch_size: int
+        self, imgs: List[Dict], batch_size: int
    ) -> None:
         """
         Batch-predict the rotation of the images passed in as a list of image-info dicts, and rotate any rotated images back to the correct orientation.
         """
-        # Group by language; skip images whose aspect ratio is below 1.2
-        lang_groups = defaultdict(list)
+        RESOLUTION_GROUP_STRIDE = 64
+        # Skip images whose aspect ratio is below 1.2
+        resolution_groups = defaultdict(list)
         for img in imgs:
             # Convert the RGB image to BGR
             table_img: np.ndarray = cv2.cvtColor(img["table_img"], cv2.COLOR_RGB2BGR)
@@ -189,89 +190,73 @@ class PaddleOrientationClsModel:
             img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
             img_is_portrait = img_aspect_ratio > 1.2
             if img_is_portrait:
-                lang = img["lang"]
-                lang_groups[lang].append(img)
-
-        # For each language, group by resolution and process in batches
-        for lang, lang_group_img_list in lang_groups.items():
-            if not lang_group_img_list:
-                continue
-
-            # Get the OCR model
-            ocr_model = atom_model_manager.get_atom_model(
-                atom_model_name=ocr_model_name, det_db_box_thresh=0.3, lang=lang
-            )
-
-            # Group by resolution and complete the padding at the same time
-            resolution_groups = defaultdict(list)
-            for img in lang_group_img_list:
                 h, w = img["table_img_bgr"].shape[:2]
-                normalized_h = ((h + 32) // 32) * 32  # round up to a multiple of 32
-                normalized_w = ((w + 32) // 32) * 32
+                normalized_h = ((h + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE  # round up to a multiple of RESOLUTION_GROUP_STRIDE
+                normalized_w = ((w + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
                 group_key = (normalized_h, normalized_w)
                 resolution_groups[group_key].append(img)

        # Process each resolution group in a batch
-            for group_key, group_imgs in tqdm(
-                resolution_groups.items(), desc=f"ORI CLS OCR-det {lang}"
-            ):
-
-                # Compute the target size (max size in the group, rounded up to a multiple of 32)
-                max_h = max(img["table_img_bgr"].shape[0] for img in group_imgs)
-                max_w = max(img["table_img_bgr"].shape[1] for img in group_imgs)
-                target_h = ((max_h + 32 - 1) // 32) * 32
-                target_w = ((max_w + 32 - 1) // 32) * 32
-
-                # Pad all images to the unified size
-                batch_images = []
-                for img in group_imgs:
-                    table_img_ndarray = img["table_img_bgr"]
-                    h, w = table_img_ndarray.shape[:2]
-                    # Create a white background of the target size
-                    padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
-                    # Paste the original image into the top-left corner
-                    padded_img[:h, :w] = table_img_ndarray
-                    batch_images.append(padded_img)
-
-                # Batched detection
-                det_batch_size = min(len(batch_images), batch_size)  # increase the batch size
-                batch_results = ocr_model.text_detector.batch_predict(
-                    batch_images, det_batch_size
-                )
+        for group_key, group_imgs in tqdm(
+            resolution_groups.items(), desc=f"ORI CLS OCR-det"
+        ):
+
+            # Compute the target size (max size in the group, rounded up to a multiple of RESOLUTION_GROUP_STRIDE)
+            max_h = max(img["table_img_bgr"].shape[0] for img in group_imgs)
+            max_w = max(img["table_img_bgr"].shape[1] for img in group_imgs)
+            target_h = ((max_h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+            target_w = ((max_w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+
+            # Pad all images to the unified size
+            batch_images = []
+            for img in group_imgs:
+                table_img_ndarray = img["table_img_bgr"]
+                h, w = table_img_ndarray.shape[:2]
+                # Create a white background of the target size
+                padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
+                # Paste the original image into the top-left corner
+                padded_img[:h, :w] = table_img_ndarray
+                batch_images.append(padded_img)
+
+            # Batched detection
+            det_batch_size = min(len(batch_images), batch_size)  # increase the batch size
+            batch_results = self.ocr_engine.text_detector.batch_predict(
+                batch_images, det_batch_size
+            )

-                rotated_imgs = []
-                # Based on the batch detection results, check whether each image is rotated; rotated images go into the list for the subsequent rotation-angle prediction
-                for index, (img_info, (dt_boxes, elapse)) in enumerate(
-                    zip(group_imgs, batch_results)
-                ):
-                    vertical_count = 0
-                    for box_ocr_res in dt_boxes:
-                        p1, p2, p3, p4 = box_ocr_res
+            rotated_imgs = []
+            # Based on the batch detection results, check whether each image is rotated; rotated images go into the list for the subsequent rotation-angle prediction
+            for index, (img_info, (dt_boxes, elapse)) in enumerate(
+                zip(group_imgs, batch_results)
+            ):
+                vertical_count = 0
+                for box_ocr_res in dt_boxes:
+                    p1, p2, p3, p4 = box_ocr_res

-                        # Calculate width and height
-                        width = p3[0] - p1[0]
-                        height = p3[1] - p1[1]
+                    # Calculate width and height
+                    width = p3[0] - p1[0]
+                    height = p3[1] - p1[1]

-                        aspect_ratio = width / height if height > 0 else 1.0
+                    aspect_ratio = width / height if height > 0 else 1.0

-                        # Count vertical text boxes
-                        if aspect_ratio < 0.8:  # Taller than wide - vertical text
-                            vertical_count += 1
+                    # Count vertical text boxes
+                    if aspect_ratio < 0.8:  # Taller than wide - vertical text
+                        vertical_count += 1

-                    if vertical_count >= len(dt_boxes) * 0.28 and vertical_count >= 3:
-                        rotated_imgs.append(img_info)
-                if len(rotated_imgs) > 0:
-                    x = self.batch_preprocess(rotated_imgs)
-                    results = self.sess.run(None, {"x": x})
-                    for img_info, res in zip(rotated_imgs, results[0]):
-                        label = self.labels[np.argmax(res)]
-                        if label == "270":
-                            img_info["table_img"] = cv2.rotate(
-                                np.asarray(img_info["table_img"]),
-                                cv2.ROTATE_90_CLOCKWISE,
-                            )
-                        elif label == "90":
-                            img_info["table_img"] = cv2.rotate(
-                                np.asarray(img_info["table_img"]),
-                                cv2.ROTATE_90_COUNTERCLOCKWISE,
-                            )
+                if vertical_count >= len(dt_boxes) * 0.28 and vertical_count >= 3:
+                    rotated_imgs.append(img_info)
+            if len(rotated_imgs) > 0:
+                x = self.batch_preprocess(rotated_imgs)
+                results = self.sess.run(None, {"x": x})
+                for img_info, res in zip(rotated_imgs, results[0]):
+                    label = self.labels[np.argmax(res)]
+                    if label == "270":
+                        img_info["table_img"] = cv2.rotate(
+                            np.asarray(img_info["table_img"]),
+                            cv2.ROTATE_90_CLOCKWISE,
+                        )
+                    elif label == "90":
+                        img_info["table_img"] = cv2.rotate(
+                            np.asarray(img_info["table_img"]),
+                            cv2.ROTATE_90_COUNTERCLOCKWISE,
+                        )
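Reviewer note: below is a minimal, self-contained sketch (not part of the patch) of the resolution-grouping and padding scheme the new code relies on. The helper `group_and_pad`, the example shapes, and the final `np.stack` into one array are illustrative assumptions; the patched method keeps each image inside its info dict and passes the padded list directly to `self.ocr_engine.text_detector.batch_predict`.

```python
# Standalone sketch of the bucketing + padding used by the new batch_predict.
from collections import defaultdict
import numpy as np

RESOLUTION_GROUP_STRIDE = 64  # same stride as in the diff


def group_and_pad(images):
    """Bucket images by their dimensions rounded up to the stride, then pad each
    bucket to a common target size so the bucket can run as one detector batch."""
    groups = defaultdict(list)
    for img in images:
        h, w = img.shape[:2]
        key = (
            ((h + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE,
            ((w + RESOLUTION_GROUP_STRIDE) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE,
        )
        groups[key].append(img)

    batches = []
    for group in groups.values():
        max_h = max(im.shape[0] for im in group)
        max_w = max(im.shape[1] for im in group)
        target_h = ((max_h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
        target_w = ((max_w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
        padded = []
        for im in group:
            canvas = np.full((target_h, target_w, 3), 255, dtype=np.uint8)  # white background
            canvas[: im.shape[0], : im.shape[1]] = im  # paste into the top-left corner
            padded.append(canvas)
        batches.append(np.stack(padded))
    return batches


# Two portrait crops of similar size land in the same bucket and are padded alike.
imgs = [np.zeros((510, 300, 3), np.uint8), np.zeros((480, 290, 3), np.uint8)]
print([b.shape for b in group_and_pad(imgs)])  # [(2, 512, 320, 3)]
```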