Forráskód Böngészése

support more format of data size

gaotingquan 1 éve
szülő
commit
31e7ea876b
1 módosított fájl, 16 hozzáadás és 7 törlés
  1. 16 7
      paddlex/inference/components/transforms/image/common.py

+ 16 - 7
paddlex/inference/components/transforms/image/common.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import ast
 import math
 from pathlib import Path
 from copy import deepcopy
@@ -95,6 +96,20 @@ class ReadImage(_BaseRead):
     def apply(self, img):
         """apply"""
 
+        def rand_data():
+            def parse_size(s):
+                res = ast.literal_eval(s)
+                if isinstance(res, int):
+                    return (res, res)
+                else:
+                    assert isinstance(res, (tuple, list))
+                    assert len(res) == 2
+                    assert all(isinstance(item, int) for item in res)
+                    return res
+
+            size = parse_size(INFER_BENCHMARK_DATA_SIZE)
+            return np.random.randint(0, 256, (*size, 3), dtype=np.uint8)
+
         def process_ndarray(img):
             with temp_file_manager.temp_file_context(suffix=".png") as temp_file:
                 img_path = Path(temp_file.name)
@@ -110,14 +125,8 @@ class ReadImage(_BaseRead):
                 }
 
         if INFER_BENCHMARK and img is None:
-            size = int(INFER_BENCHMARK_DATA_SIZE)
             for _ in range(INFER_BENCHMARK_ITER):
-                yield [
-                    process_ndarray(
-                        np.random.randint(0, 256, (size, size, 3), dtype=np.uint8)
-                    )
-                    for _ in range(self.batch_size)
-                ]
+                yield [process_ndarray(rand_data()) for _ in range(self.batch_size)]
 
         elif isinstance(img, np.ndarray):
             yield [process_ndarray(img)]