|
|
@@ -16,6 +16,7 @@
|
|
|
import os
|
|
|
import json
|
|
|
import os.path as osp
|
|
|
+from PIL import Image, ImageOps
|
|
|
from collections import defaultdict
|
|
|
from .....utils.errors import DatasetFileNotFoundError, CheckFailedError
|
|
|
|
|
|
@@ -58,11 +59,21 @@ def check(dataset_dir, output, dataset_type="PubTabTableRecDataset", sample_num=
|
|
|
structure = info["html"]["structure"]["tokens"].copy()
|
|
|
|
|
|
img_path = osp.join(dataset_dir, file_name)
|
|
|
- if len(sample_paths[tag]) < max_recorded_sample_cnts:
|
|
|
- sample_paths[tag].append(os.path.relpath(img_path, output))
|
|
|
|
|
|
if not os.path.exists(img_path):
|
|
|
raise DatasetFileNotFoundError(file_path=img_path)
|
|
|
+ vis_save_dir = osp.join(output, "demo_img")
|
|
|
+ if not osp.exists(vis_save_dir):
|
|
|
+ os.makedirs(vis_save_dir)
|
|
|
+ if len(sample_paths[tag]) < sample_num:
|
|
|
+ img = Image.open(img_path)
|
|
|
+ img = ImageOps.exif_transpose(img)
|
|
|
+ vis_path = osp.join(vis_save_dir, osp.basename(file_name))
|
|
|
+ img.save(vis_path)
|
|
|
+ sample_path = osp.join(
|
|
|
+ "check_dataset", os.path.relpath(vis_path, output)
|
|
|
+ )
|
|
|
+ sample_paths[tag].append(sample_path)
|
|
|
|
|
|
boxes_num = len(cells)
|
|
|
tokens_num = sum(
|