convert_dataset.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import os
  16. import pickle
  17. from collections import defaultdict
  18. from .....utils.deps import function_requires_deps, is_dep_available
  19. from .....utils.errors import ConvertFailedError
  20. if is_dep_available("imagesize"):
  21. import imagesize
  22. if is_dep_available("tqdm"):
  23. from tqdm import tqdm
  24. def check_src_dataset(root_dir, dataset_type):
  25. """check src dataset format validity"""
  26. if dataset_type in ("MSTextRecDataset"):
  27. pass
  28. else:
  29. raise ConvertFailedError(
  30. message=f"数据格式转换失败!不支持{dataset_type}格式数据集。当前仅支持 MSTextRecDataset 格式。"
  31. )
  32. err_msg_prefix = f"数据格式转换失败!请参考上述`{dataset_type}格式数据集示例`检查待转换数据集格式。"
  33. for anno in ["train.txt", "val.txt", "latex_ocr_tokenizer.json"]:
  34. src_anno_path = os.path.join(root_dir, anno)
  35. if not os.path.exists(src_anno_path):
  36. raise ConvertFailedError(
  37. message=f"{err_msg_prefix}保证{src_anno_path}文件存在。"
  38. )
  39. return None
  40. def convert(dataset_type, input_dir):
  41. """convert dataset to pkl format"""
  42. # check format validity
  43. check_src_dataset(input_dir, dataset_type)
  44. if dataset_type in ("MSTextRecDataset"):
  45. convert_pkl_dataset(input_dir)
  46. else:
  47. raise ConvertFailedError(
  48. message=f"数据格式转换失败!不支持{dataset_type}格式数据集。当前仅支持 MSTextRecDataset 格式。"
  49. )
  50. def convert_pkl_dataset(root_dir):
  51. for anno in ["train.txt", "val.txt"]:
  52. src_img_dir = root_dir
  53. src_anno_path = os.path.join(root_dir, anno)
  54. txt2pickle(src_img_dir, src_anno_path, root_dir)
  55. @function_requires_deps("tqdm", "imagesize")
  56. def txt2pickle(images, equations, save_dir):
  57. phase = os.path.basename(equations).replace(".txt", "")
  58. save_p = os.path.join(save_dir, "latexocr_{}.pkl".format(phase))
  59. min_dimensions = (32, 32)
  60. max_dimensions = (672, 192)
  61. data = defaultdict(lambda: [])
  62. pic_num = 0
  63. if images is not None and equations is not None:
  64. with open(equations, "r") as f:
  65. lines = f.readlines()
  66. for l in tqdm(lines, total=len(lines)):
  67. l = l.strip()
  68. img_name, equation = l.split("\t")
  69. img_path = os.path.join(images, img_name)
  70. width, height = imagesize.get(img_path)
  71. if (
  72. min_dimensions[0] <= width <= max_dimensions[0]
  73. and min_dimensions[1] <= height <= max_dimensions[1]
  74. ):
  75. divide_h = math.ceil(height / 16) * 16
  76. divide_w = math.ceil(width / 16) * 16
  77. data[(divide_w, divide_h)].append((equation, img_name))
  78. pic_num += 1
  79. data = dict(data)
  80. with open(save_p, "wb") as file:
  81. pickle.dump(data, file)