|
|
@@ -2,7 +2,7 @@
|
|
|
import time
|
|
|
from collections import Counter
|
|
|
from uuid import uuid4
|
|
|
-
|
|
|
+import cv2
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from loguru import logger
|
|
|
@@ -29,7 +29,7 @@ def split_images(image, result_images=None):
|
|
|
if result_images is None:
|
|
|
result_images = []
|
|
|
|
|
|
- width, height = image.shape[:2]
|
|
|
+ height, width = image.shape[:2]
|
|
|
long_side = max(width, height) # 获取较长边长度
|
|
|
|
|
|
if long_side <= 400:
|
|
|
@@ -68,14 +68,8 @@ def resize_images_to_224(image):
|
|
|
Works directly with NumPy arrays.
|
|
|
"""
|
|
|
try:
|
|
|
- # Handle numpy array directly
|
|
|
- if len(image.shape) == 3: # Color image
|
|
|
- height, width, channels = image.shape
|
|
|
- else: # Grayscale image
|
|
|
- height, width = image.shape
|
|
|
- image = np.stack([image] * 3, axis=2) # Convert to RGB
|
|
|
-
|
|
|
- import cv2
|
|
|
+ height, width = image.shape[:2]
|
|
|
+
|
|
|
if width < 224 or height < 224:
|
|
|
# Create black background
|
|
|
new_image = np.zeros((224, 224, 3), dtype=np.uint8)
|