| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227 |
- import os
- import math
- from pathlib import Path
- import numpy as np
- import cv2
- import argparse
# Directory three levels above this file.
root_dir = Path(__file__).resolve().parent.parent.parent

# Location of the bundled architecture registry read by ``get_arch_config``.
DEFAULT_CFG_PATH = root_dir.joinpath(
    "pytorchocr", "utils", "resources", "arch_config.yaml"
)
def init_args():
    """Build the command-line parser shared by the OCR inference tools.

    Boolean options take an explicit value (``--use_gpu true``); the values
    ``true``, ``t`` and ``1`` (case-insensitive) mean True and anything else
    means False. List options (``--scales``, ``--label_list``) take a
    comma-separated value such as ``8,16,32``.

    Returns:
        argparse.ArgumentParser: parser with every option registered.
    """

    def str2bool(v):
        # ``type=bool`` would turn every non-empty string (even "False")
        # into True, so booleans are parsed explicitly.
        return v.lower() in ("true", "t", "1")

    def comma_list(item_type):
        # ``type=list`` would split the raw string into single characters;
        # parse "a,b,c" (optionally wrapped in brackets/quotes) instead.
        def parse(v):
            pieces = (p.strip().strip("'\"") for p in v.strip().strip("[]").split(","))
            return [item_type(p) for p in pieces if p]

        return parse

    # Base directory for the bundled font and dictionaries (three levels up).
    base_dir = os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    parser = argparse.ArgumentParser()

    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=False)
    parser.add_argument("--det", type=str2bool, default=True)
    parser.add_argument("--rec", type=str2bool, default=True)
    parser.add_argument("--device", type=str, default='cpu')
    # parser.add_argument("--ir_optim", type=str2bool, default=True)
    # parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    # parser.add_argument("--use_fp16", type=str2bool, default=False)
    parser.add_argument("--gpu_mem", type=int, default=500)
    parser.add_argument("--warmup", type=str2bool, default=False)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_path", type=str)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
    parser.add_argument("--det_sast_polygon", type=str2bool, default=False)

    # PSE params
    parser.add_argument("--det_pse_thresh", type=float, default=0)
    parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
    parser.add_argument("--det_pse_min_area", type=float, default=16)
    parser.add_argument("--det_pse_box_type", type=str, default='box')
    parser.add_argument("--det_pse_scale", type=int, default=1)

    # FCE params
    parser.add_argument("--scales", type=comma_list(int), default=[8, 16, 32])
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--fourier_degree", type=int, default=5)
    parser.add_argument("--det_fce_box_type", type=str, default='poly')

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_path", type=str)
    parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
    parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument("--drop_score", type=float, default=0.5)
    parser.add_argument("--limited_max_width", type=int, default=1280)
    parser.add_argument("--limited_min_width", type=int, default=16)
    parser.add_argument(
        "--vis_font_path", type=str,
        default=os.path.join(base_dir, 'doc/fonts/simfang.ttf'))
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default=os.path.join(base_dir, 'pytorchocr/utils/ppocr_keys_v1.txt'))

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_path", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    parser.add_argument("--label_list", type=comma_list(str), default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)
    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)

    # params for e2e
    parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
    parser.add_argument("--e2e_model_path", type=str)
    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
    parser.add_argument("--e2e_limit_type", type=str, default='max')

    # PGNet params
    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
    parser.add_argument(
        "--e2e_char_dict_path", type=str,
        default=os.path.join(base_dir, 'pytorchocr/utils/ic15_dict.txt'))
    parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
    parser.add_argument("--e2e_pgnet_polygon", type=str2bool, default=True)
    parser.add_argument("--e2e_pgnet_mode", type=str, default='fast')

    # SR params
    parser.add_argument("--sr_model_path", type=str)
    parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
    parser.add_argument("--sr_batch_num", type=int, default=1)

    # params .yaml
    parser.add_argument("--det_yaml_path", type=str, default=None)
    parser.add_argument("--rec_yaml_path", type=str, default=None)
    parser.add_argument("--cls_yaml_path", type=str, default=None)
    parser.add_argument("--e2e_yaml_path", type=str, default=None)
    parser.add_argument("--sr_yaml_path", type=str, default=None)

    # multi-process
    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)

    parser.add_argument("--benchmark", type=str2bool, default=False)
    parser.add_argument("--save_log_path", type=str, default="./log_output/")
    parser.add_argument("--show_log", type=str2bool, default=True)

    return parser
def parse_args():
    """Parse the process command line with the parser from ``init_args``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    return init_args().parse_args()
def get_default_config(args):
    """Return the attributes of ``args`` as a dictionary (via ``vars``)."""
    config = vars(args)
    return config
def read_network_config_from_yaml(yaml_path, char_num=None):
    """Load the ``Architecture`` section of a YAML model configuration.

    Args:
        yaml_path: path of the YAML file to read.
        char_num: optional character count. When given and the head is named
            ``MultiHead``, the head's ``out_channels_list`` is overwritten
            with per-decoder output sizes derived from it.

    Returns:
        The value stored under the top-level ``Architecture`` key.

    Raises:
        FileNotFoundError: ``yaml_path`` does not exist.
        ValueError: the file has no ``Architecture`` entry.
    """
    if not os.path.exists(yaml_path):
        raise FileNotFoundError(f'{yaml_path} is not existed.')

    # Imported lazily so the module loads without PyYAML installed.
    import yaml

    with open(yaml_path, encoding='utf-8') as stream:
        config = yaml.safe_load(stream)

    if config.get('Architecture') is None:
        raise ValueError(f'{yaml_path} has no Architecture')

    architecture = config['Architecture']
    head = architecture['Head']
    if head['name'] == 'MultiHead' and char_num is not None:
        head['out_channels_list'] = {
            'CTCLabelDecode': char_num,
            'SARLabelDecode': char_num + 2,
            'NRTRLabelDecode': char_num + 3,
        }
    return architecture
def AnalysisConfig(weights_path, yaml_path=None, char_num=None):
    """Validate the weights path and optionally load a network config.

    Args:
        weights_path: path to the model weights; must exist.
        yaml_path: optional YAML configuration path.
        char_num: forwarded to ``read_network_config_from_yaml``.

    Returns:
        The result of ``read_network_config_from_yaml`` when ``yaml_path`` is
        given, otherwise ``None``.

    Raises:
        FileNotFoundError: ``weights_path`` does not exist.
    """
    weights_exist = os.path.exists(os.path.abspath(weights_path))
    if not weights_exist:
        raise FileNotFoundError(f'{weights_path} is not found.')

    if yaml_path is None:
        return None
    return read_network_config_from_yaml(yaml_path, char_num=char_num)
def resize_img(img, input_size=600):
    """
    Scale an image uniformly so that the larger of its first two
    dimensions becomes ``input_size``.

    args:
        img: image data; converted with ``np.array`` before resizing
        input_size: target length of the longest side
    return:
        the resized image as returned by ``cv2.resize``
    """
    image = np.array(img)
    longest_side = np.max(image.shape[:2])
    # One factor for both axes keeps the aspect ratio.
    scale = float(input_size) / float(longest_side)
    return cv2.resize(image, None, None, fx=scale, fy=scale)
def str_count(s):
    """
    Measure a string in full-width character units.

    ASCII letters, digits and whitespace each count as half a unit (the
    total of such characters is halved and rounded down); every other
    character counts as one unit.

    args:
        s(string): the input of string
    return(int):
        the length of ``s`` in full-width units
    """
    import string

    half_width_count = sum(
        1 for c in s
        if c in string.ascii_letters or c.isdigit() or c.isspace()
    )
    # len - ceil(n / 2) == (other characters) + floor(n / 2)
    return len(s) - math.ceil(half_width_count / 2)
def base64_to_cv2(b64str):
    """Decode a base64-encoded image into an OpenCV image.

    Args:
        b64str: base64 text (``str``) or already-encoded ``bytes`` holding
            the contents of an encoded image file.

    Returns:
        The result of ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``.
        NOTE(review): presumably ``None`` when the payload is not a
        decodable image -- confirm against the OpenCV documentation.
    """
    import base64

    encoded = b64str.encode('utf8') if isinstance(b64str, str) else b64str
    raw = base64.b64decode(encoded)
    # np.fromstring is deprecated for binary input; frombuffer is the
    # supported replacement and avoids copying the payload.
    buffer = np.frombuffer(raw, dtype=np.uint8)
    return cv2.imdecode(buffer, cv2.IMREAD_COLOR)
def get_arch_config(model_path):
    """Look up the architecture config matching a model file name.

    The file name of ``model_path`` without its final suffix is used as the
    key into the configuration loaded from ``DEFAULT_CFG_PATH``.

    Args:
        model_path: path to the model file.

    Returns:
        The entry stored under that key.

    Raises:
        ValueError: the key is not present in the loaded configuration.
    """
    # Imported lazily so the module loads without omegaconf installed.
    from omegaconf import OmegaConf

    registry = OmegaConf.load(DEFAULT_CFG_PATH)
    arch_name = Path(model_path).stem
    if arch_name in registry:
        return registry[arch_name]
    raise ValueError(f"architecture {arch_name} is not in arch_config.yaml")
|