|
|
@@ -57,20 +57,19 @@ class MinerUPreprocessor(BasePreprocessor):
|
|
|
if isinstance(image, Image.Image):
|
|
|
image = np.array(image)
|
|
|
|
|
|
- rotate_map = {0: 0, 1: 90, 2: 180, 3: 270}
|
|
|
- rotate_label = 0
|
|
|
+ rotate_angle = 0
|
|
|
processed_image = image
|
|
|
|
|
|
# 方向校正
|
|
|
if self.orientation_classifier is not None:
|
|
|
try:
|
|
|
- rotate_label = self.orientation_classifier.predict(image)
|
|
|
- processed_image = self._apply_rotation(processed_image, rotate_label)
|
|
|
- logger.info(f"📐 Applied rotation: {rotate_label}")
|
|
|
+ rotate_angle = int(self.orientation_classifier.predict(image))
|
|
|
+ processed_image = self._apply_rotation(processed_image, rotate_angle)
|
|
|
+ logger.info(f"📐 Applied rotation: {rotate_angle}")
|
|
|
except Exception as e:
|
|
|
logger.error(f"⚠️ Orientation classification failed: {e}")
|
|
|
|
|
|
- return processed_image, rotate_map.get(rotate_label, 0)
|
|
|
+ return processed_image, rotate_angle
|
|
|
|
|
|
class MinerULayoutDetector(BaseLayoutDetector):
|
|
|
"""MinerU版式检测适配器"""
|
|
|
@@ -87,7 +86,7 @@ class MinerULayoutDetector(BaseLayoutDetector):
|
|
|
"""初始化版式检测模型"""
|
|
|
try:
|
|
|
# 获取模型配置
|
|
|
- model_name = self.config.get('model_name', 'RT-DETR-H_layout_17cls')
|
|
|
+ model_name = self.config.get('model_name', AtomicModel.Layout)
|
|
|
model_dir = self.config.get('model_dir')
|
|
|
device = self.config.get('device', 'cpu')
|
|
|
|
|
|
@@ -100,9 +99,12 @@ class MinerULayoutDetector(BaseLayoutDetector):
|
|
|
device=device
|
|
|
)
|
|
|
else:
|
|
|
- # 使用默认模型
|
|
|
+ import os
|
|
|
+ from mineru.utils.enum_class import ModelPath
|
|
|
+ from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
|
|
|
self.layout_model = self.atom_model_manager.get_atom_model(
|
|
|
atom_model_name=AtomicModel.Layout,
|
|
|
+ doclayout_yolo_weights=os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo),
|
|
|
device=device
|
|
|
)
|
|
|
print(f"✅ Layout detector initialized: {model_name}")
|
|
|
@@ -130,7 +132,7 @@ class MinerULayoutDetector(BaseLayoutDetector):
|
|
|
|
|
|
# 转换结果格式
|
|
|
formatted_results = []
|
|
|
- for result in layout_results[0]: # 第一页结果
|
|
|
+ for result in layout_results: # 第一页结果
|
|
|
# 提取坐标信息
|
|
|
poly = result.get('poly', [0, 0, 0, 0, 0, 0, 0, 0])
|
|
|
if len(poly) >= 8:
|