| 123456789101112131415161718192021222324252627282930313233343536373839 |
- import os
- import torch
- from .modeling.architectures.base_model import BaseModel
class BaseOCRV20:
    """Wrapper that builds a ``BaseModel`` network and loads PyTorch weights into it.

    The network is switched to evaluation mode (``self.net.eval()``) right
    after construction.
    """

    def __init__(self, config, **kwargs):
        """Build the network from *config* and put it in eval mode.

        Args:
            config: Model configuration forwarded to ``BaseModel``; its schema
                is defined by ``BaseModel``, not by this class.
            **kwargs: Extra keyword arguments forwarded to ``build_net``.
        """
        self.config = config
        self.build_net(**kwargs)
        self.net.eval()

    def build_net(self, **kwargs):
        """Create ``self.net`` as ``BaseModel(self.config, **kwargs)``."""
        self.net = BaseModel(self.config, **kwargs)

    def read_pytorch_weights(self, weights_path, *, weights_only=True):
        """Load and return the object stored at *weights_path* via ``torch.load``.

        Args:
            weights_path: Path to a file written with ``torch.save``.
            weights_only: Forwarded to ``torch.load``. Defaults to ``True`` so
                only tensors and plain containers are unpickled. Pass
                ``False`` only for trusted checkpoints that need arbitrary
                Python objects.

        Returns:
            The deserialized object (presumably a state dict; the caller
            decides).

        Raises:
            FileNotFoundError: If *weights_path* does not exist.
        """
        if not os.path.exists(weights_path):
            raise FileNotFoundError('{} is not existed.'.format(weights_path))
        # weights_only=True prevents executing arbitrary pickled code and
        # matches how load_pytorch_weights loads checkpoints.
        return torch.load(weights_path, weights_only=weights_only)

    def get_out_channels(self, weights):
        """Infer an output-channel count from the LAST entry of *weights*.

        If the last key ends with ``'.weight'`` and its tensor is 2-D, the
        size of dimension 1 is returned. Otherwise the size of dimension 0 is
        returned.

        NOTE(review): the 2-D branch presumably targets a fully-connected
        weight stored as (in, out). PyTorch ``nn.Linear`` stores (out, in),
        so confirm which layout these checkpoints use.

        Args:
            weights: Ordered mapping of parameter name -> tensor-like object
                exposing ``.shape``.

        Returns:
            int: the inferred channel count.

        Raises:
            IndexError: If *weights* is empty.
        """
        # Read the last item once instead of rebuilding key/value lists. Use
        # ``.shape`` directly rather than ``.numpy().shape``, which fails for
        # CUDA tensors and tensors that require grad.
        last_key, last_value = list(weights.items())[-1]
        shape = last_value.shape
        if last_key.endswith('.weight') and len(shape) == 2:
            return shape[1]
        return shape[0]

    def load_state_dict(self, weights):
        """Load the state dict *weights* into ``self.net`` (strict by default)."""
        self.net.load_state_dict(weights)

    def load_pytorch_weights(self, weights_path):
        """Read the checkpoint at *weights_path* and load it into ``self.net``.

        Raises:
            FileNotFoundError: If *weights_path* does not exist.
        """
        # Reuse read_pytorch_weights so both loading paths share the same
        # existence check and weights_only=True deserialization.
        self.net.load_state_dict(self.read_pytorch_weights(weights_path))

    def inference(self, inputs):
        """Run ``self.net(inputs)`` with autograd disabled and return the result."""
        with torch.no_grad():
            infer = self.net(inputs)
        return infer
|