# pytorchocr_utility.py (9.1 KB)


  1. import os
  2. import math
  3. from pathlib import Path
  4. import numpy as np
  5. import cv2
  6. import argparse
  7. root_dir = Path(__file__).resolve().parent.parent.parent
  8. DEFAULT_CFG_PATH = root_dir / "pytorchocr" / "utils" / "resources" / "arch_config.yaml"
  9. def init_args():
  10. def str2bool(v):
  11. return v.lower() in ("true", "t", "1")
  12. parser = argparse.ArgumentParser()
  13. # params for prediction engine
  14. parser.add_argument("--use_gpu", type=str2bool, default=False)
  15. parser.add_argument("--det", type=str2bool, default=True)
  16. parser.add_argument("--rec", type=str2bool, default=True)
  17. parser.add_argument("--device", type=str, default='cpu')
  18. # parser.add_argument("--ir_optim", type=str2bool, default=True)
  19. # parser.add_argument("--use_tensorrt", type=str2bool, default=False)
  20. # parser.add_argument("--use_fp16", type=str2bool, default=False)
  21. parser.add_argument("--gpu_mem", type=int, default=500)
  22. parser.add_argument("--warmup", type=str2bool, default=False)
  23. # params for text detector
  24. parser.add_argument("--image_dir", type=str)
  25. parser.add_argument("--det_algorithm", type=str, default='DB')
  26. parser.add_argument("--det_model_path", type=str)
  27. parser.add_argument("--det_limit_side_len", type=float, default=960)
  28. parser.add_argument("--det_limit_type", type=str, default='max')
  29. # DB parmas
  30. parser.add_argument("--det_db_thresh", type=float, default=0.3)
  31. parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
  32. parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
  33. parser.add_argument("--max_batch_size", type=int, default=10)
  34. parser.add_argument("--use_dilation", type=str2bool, default=False)
  35. parser.add_argument("--det_db_score_mode", type=str, default="fast")
  36. # EAST parmas
  37. parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
  38. parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
  39. parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
  40. # SAST parmas
  41. parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
  42. parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
  43. parser.add_argument("--det_sast_polygon", type=str2bool, default=False)
  44. # PSE parmas
  45. parser.add_argument("--det_pse_thresh", type=float, default=0)
  46. parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
  47. parser.add_argument("--det_pse_min_area", type=float, default=16)
  48. parser.add_argument("--det_pse_box_type", type=str, default='box')
  49. parser.add_argument("--det_pse_scale", type=int, default=1)
  50. # FCE parmas
  51. parser.add_argument("--scales", type=list, default=[8, 16, 32])
  52. parser.add_argument("--alpha", type=float, default=1.0)
  53. parser.add_argument("--beta", type=float, default=1.0)
  54. parser.add_argument("--fourier_degree", type=int, default=5)
  55. parser.add_argument("--det_fce_box_type", type=str, default='poly')
  56. # params for text recognizer
  57. parser.add_argument("--rec_algorithm", type=str, default='CRNN')
  58. parser.add_argument("--rec_model_path", type=str)
  59. parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
  60. parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
  61. parser.add_argument("--rec_char_type", type=str, default='ch')
  62. parser.add_argument("--rec_batch_num", type=int, default=6)
  63. parser.add_argument("--max_text_length", type=int, default=25)
  64. parser.add_argument("--use_space_char", type=str2bool, default=True)
  65. parser.add_argument("--drop_score", type=float, default=0.5)
  66. parser.add_argument("--limited_max_width", type=int, default=1280)
  67. parser.add_argument("--limited_min_width", type=int, default=16)
  68. parser.add_argument(
  69. "--vis_font_path", type=str,
  70. default=os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'doc/fonts/simfang.ttf'))
  71. parser.add_argument(
  72. "--rec_char_dict_path",
  73. type=str,
  74. default=os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
  75. 'pytorchocr/utils/ppocr_keys_v1.txt'))
  76. # params for text classifier
  77. parser.add_argument("--use_angle_cls", type=str2bool, default=False)
  78. parser.add_argument("--cls_model_path", type=str)
  79. parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
  80. parser.add_argument("--label_list", type=list, default=['0', '180'])
  81. parser.add_argument("--cls_batch_num", type=int, default=6)
  82. parser.add_argument("--cls_thresh", type=float, default=0.9)
  83. parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
  84. parser.add_argument("--use_pdserving", type=str2bool, default=False)
  85. # params for e2e
  86. parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
  87. parser.add_argument("--e2e_model_path", type=str)
  88. parser.add_argument("--e2e_limit_side_len", type=float, default=768)
  89. parser.add_argument("--e2e_limit_type", type=str, default='max')
  90. # PGNet parmas
  91. parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
  92. parser.add_argument(
  93. "--e2e_char_dict_path", type=str,
  94. default=os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
  95. 'pytorchocr/utils/ic15_dict.txt'))
  96. parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
  97. parser.add_argument("--e2e_pgnet_polygon", type=bool, default=True)
  98. parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')
  99. # SR parmas
  100. parser.add_argument("--sr_model_path", type=str)
  101. parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
  102. parser.add_argument("--sr_batch_num", type=int, default=1)
  103. # params .yaml
  104. parser.add_argument("--det_yaml_path", type=str, default=None)
  105. parser.add_argument("--rec_yaml_path", type=str, default=None)
  106. parser.add_argument("--cls_yaml_path", type=str, default=None)
  107. parser.add_argument("--e2e_yaml_path", type=str, default=None)
  108. parser.add_argument("--sr_yaml_path", type=str, default=None)
  109. # multi-process
  110. parser.add_argument("--use_mp", type=str2bool, default=False)
  111. parser.add_argument("--total_process_num", type=int, default=1)
  112. parser.add_argument("--process_id", type=int, default=0)
  113. parser.add_argument("--benchmark", type=str2bool, default=False)
  114. parser.add_argument("--save_log_path", type=str, default="./log_output/")
  115. parser.add_argument("--show_log", type=str2bool, default=True)
  116. return parser
  117. def parse_args():
  118. parser = init_args()
  119. return parser.parse_args()
  120. def get_default_config(args):
  121. return vars(args)
  122. def read_network_config_from_yaml(yaml_path, char_num=None):
  123. if not os.path.exists(yaml_path):
  124. raise FileNotFoundError('{} is not existed.'.format(yaml_path))
  125. import yaml
  126. with open(yaml_path, encoding='utf-8') as f:
  127. res = yaml.safe_load(f)
  128. if res.get('Architecture') is None:
  129. raise ValueError('{} has no Architecture'.format(yaml_path))
  130. if res['Architecture']['Head']['name'] == 'MultiHead' and char_num is not None:
  131. res['Architecture']['Head']['out_channels_list'] = {
  132. 'CTCLabelDecode': char_num,
  133. 'SARLabelDecode': char_num + 2,
  134. 'NRTRLabelDecode': char_num + 3
  135. }
  136. return res['Architecture']
  137. def AnalysisConfig(weights_path, yaml_path=None, char_num=None):
  138. if not os.path.exists(os.path.abspath(weights_path)):
  139. raise FileNotFoundError('{} is not found.'.format(weights_path))
  140. if yaml_path is not None:
  141. return read_network_config_from_yaml(yaml_path, char_num=char_num)
  142. def resize_img(img, input_size=600):
  143. """
  144. resize img and limit the longest side of the image to input_size
  145. """
  146. img = np.array(img)
  147. im_shape = img.shape
  148. im_size_max = np.max(im_shape[0:2])
  149. im_scale = float(input_size) / float(im_size_max)
  150. img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
  151. return img
  152. def str_count(s):
  153. """
  154. Count the number of Chinese characters,
  155. a single English character and a single number
  156. equal to half the length of Chinese characters.
  157. args:
  158. s(string): the input of string
  159. return(int):
  160. the number of Chinese characters
  161. """
  162. import string
  163. count_zh = count_pu = 0
  164. s_len = len(s)
  165. en_dg_count = 0
  166. for c in s:
  167. if c in string.ascii_letters or c.isdigit() or c.isspace():
  168. en_dg_count += 1
  169. elif c.isalpha():
  170. count_zh += 1
  171. else:
  172. count_pu += 1
  173. return s_len - math.ceil(en_dg_count / 2)
  174. def base64_to_cv2(b64str):
  175. import base64
  176. data = base64.b64decode(b64str.encode('utf8'))
  177. data = np.fromstring(data, np.uint8)
  178. data = cv2.imdecode(data, cv2.IMREAD_COLOR)
  179. return data
  180. def get_arch_config(model_path):
  181. from omegaconf import OmegaConf
  182. all_arch_config = OmegaConf.load(DEFAULT_CFG_PATH)
  183. path = Path(model_path)
  184. file_name = path.stem
  185. if file_name not in all_arch_config:
  186. raise ValueError(f"architecture {file_name} is not in arch_config.yaml")
  187. arch_config = all_arch_config[file_name]
  188. return arch_config