# base_ocr_v20.py (1.3 KB)

  1. import os
  2. import torch
  3. from .modeling.architectures.base_model import BaseModel
  4. class BaseOCRV20:
  5. def __init__(self, config, **kwargs):
  6. self.config = config
  7. self.build_net(**kwargs)
  8. self.net.eval()
  9. def build_net(self, **kwargs):
  10. self.net = BaseModel(self.config, **kwargs)
  11. def read_pytorch_weights(self, weights_path):
  12. if not os.path.exists(weights_path):
  13. raise FileNotFoundError('{} is not existed.'.format(weights_path))
  14. weights = torch.load(weights_path)
  15. return weights
  16. def get_out_channels(self, weights):
  17. if list(weights.keys())[-1].endswith('.weight') and len(list(weights.values())[-1].shape) == 2:
  18. out_channels = list(weights.values())[-1].numpy().shape[1]
  19. else:
  20. out_channels = list(weights.values())[-1].numpy().shape[0]
  21. return out_channels
  22. def load_state_dict(self, weights):
  23. self.net.load_state_dict(weights)
  24. # print('weights is loaded.')
  25. def load_pytorch_weights(self, weights_path):
  26. self.net.load_state_dict(torch.load(weights_path, weights_only=True))
  27. # print('model is loaded: {}'.format(weights_path))
  28. def inference(self, inputs):
  29. with torch.no_grad():
  30. infer = self.net(inputs)
  31. return infer