||
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import numpy as np
- import random
- import imghdr
- import os
- import signal
- from paddle.io import Dataset, DataLoader, DistributedBatchSampler
- from . import imaug
- from .imaug import transform
- from paddlex.ppcls.utils import logger
# Trainer count and this trainer's id, read once from the environment at
# import time; they default to 1 and 0 when the variables are not set.
trainers_num = int(os.getenv('PADDLE_TRAINERS_NUM', 1))
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
class ModeException(Exception):
    """
    Error raised when a reader is requested for an unsupported mode

    Args:
        message(str): text placed in front of the fixed explanation
        mode(str): the mode that was given, echoed in the message
    """

    def __init__(self, message='', mode=''):
        explanation = "\nOnly the following 3 modes are supported: " \
            "train, valid, test. Given mode is {}".format(mode)
        super(ModeException, self).__init__(message + explanation)
class SampleNumException(Exception):
    """
    Error raised when the data set holds fewer samples than one batch

    Args:
        message(str): text placed in front of the fixed explanation
        sample_num(int): number of samples available
        batch_size(int): batch size that was requested
    """

    def __init__(self, message='', sample_num=0, batch_size=1):
        template = "\nError: The number of the whole data ({}) " \
            "is smaller than the batch_size ({}), and drop_last " \
            "is turnning on, so nothing will feed in program, " \
            "Terminated now. Please reset batch_size to a smaller " \
            "number or feed more data!"
        explanation = template.format(sample_num, batch_size)
        super(SampleNumException, self).__init__(message + explanation)
class ShuffleSeedException(Exception):
    """
    Error raised when several trainers run without a shared shuffle seed

    Args:
        message(str): text placed in front of the fixed explanation
    """

    def __init__(self, message=''):
        explanation = "\nIf trainers_num > 1, the shuffle_seed must be set, " \
            "because the order of batch data generated by reader " \
            "must be the same in the respective processes."
        super(ShuffleSeedException, self).__init__(message + explanation)
def check_params(params):
    """
    validate the reader settings before any data is read

    A missing 'shuffle_seed' entry is filled in with None as a side effect.

    Args:
        params(dict): reader settings; 'mode' must be present, while
            'data_dir', 'file_list' and 'shuffle_seed' are looked up
            with defaults
    Raises:
        ShuffleSeedException: more than one trainer and no shuffle seed
        AssertionError: 'data_dir' is not a directory, or (outside test
            mode) 'file_list' is not a file
    """
    params.setdefault('shuffle_seed', None)

    seed_missing = params['shuffle_seed'] is None
    if seed_missing and trainers_num > 1:
        raise ShuffleSeedException()

    data_dir = params.get('data_dir', '')
    assert os.path.isdir(data_dir), \
        "{} doesn't exist, please check datadir path".format(data_dir)

    # test mode needs no file list to be supplied up front
    if params['mode'] == 'test':
        return

    file_list = params.get('file_list', '')
    assert os.path.isfile(file_list), \
        "{} doesn't exist, please check file list path".format(file_list)
def create_file_list(params):
    """
    if mode is test, create the file list

    Every regular file directly inside 'data_dir' that imghdr recognises as
    one of the accepted image types is written to ".tmp.txt" (in the
    current working directory) as "<file name> 0", and params['file_list']
    is pointed at that file.

    Args:
        params(dict): reader settings; 'data_dir' is read and 'file_list'
            is overwritten
    """
    data_dir = params.get('data_dir', '')
    params['file_list'] = ".tmp.txt"
    imgtype_list = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff'}
    with open(params['file_list'], "w") as fout:
        for file_name in os.listdir(data_dir):
            file_path = os.path.join(data_dir, file_name)
            # imghdr.what() opens the path it is given, so a sub-directory
            # (or any other non-regular entry) would raise instead of being
            # ignored; skip those entries up front.
            if not os.path.isfile(file_path):
                continue
            # NOTE(review): imghdr is deprecated since Python 3.11 and
            # removed in 3.13 - plan a replacement before upgrading.
            if imghdr.what(file_path) not in imgtype_list:
                continue
            fout.write(file_name + " 0" + "\n")
def shuffle_lines(full_lines, seed=None):
    """
    shuffle the given list in place and hand it back

    Args:
        full_lines(list): lines to reorder; the same object is returned
        seed(int): when not None, a dedicated RandomState seeded with it
            is used; otherwise numpy's global generator is used
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    rng.shuffle(full_lines)
    return full_lines
def get_file_list(params):
    """
    read label list from file and shuffle the list

    In test mode the file list is generated first by create_file_list.
    Lines are stripped of surrounding whitespace and blank lines are
    dropped; in train mode the result is shuffled with
    params['shuffle_seed'].

    Args:
        params(dict): reader settings; 'mode' and 'file_list' are read
            ('file_list' is filled in by create_file_list in test mode),
            and 'shuffle_seed' is read in train mode
    Returns:
        list of the non-empty, stripped lines of the file list
    """
    if params['mode'] == 'test':
        create_file_list(params)

    with open(params['file_list']) as flist:
        stripped = (line.strip() for line in flist)
        # A blank line cannot be split into "path label", so keeping it
        # would only yield a sample that always fails to load.
        full_lines = [line for line in stripped if line]

    if params["mode"] == "train":
        full_lines = shuffle_lines(full_lines, seed=params['shuffle_seed'])

    return full_lines
def create_operators(params):
    """
    instantiate the operators described by a config list

    Each entry must be a dict with exactly one key: the key names an
    attribute of the imaug module and the value (a dict, or None for
    "no arguments") is passed to it as keyword arguments.

    Args:
        params(list): a dict list, used to create some operators
    Returns:
        list of the created operators, in config order
    """
    assert isinstance(params, list), ('operator config should be a list')

    built = []
    for entry in params:
        well_formed = isinstance(entry, dict) and len(entry) == 1
        assert well_formed, "yaml format error"
        op_name, op_args = next(iter(entry.items()))
        kwargs = op_args if op_args is not None else {}
        built.append(getattr(imaug, op_name)(**kwargs))
    return built
def term_mp(sig_num, frame):
    """ signal handler: log, then SIGKILL the current process group

    Args:
        sig_num: signal number handed over by the signal module (unused)
        frame: stack frame handed over by the signal module (unused)
    """
    main_pid = os.getpid()
    group_id = os.getpgid(main_pid)
    logger.info("main proc {} exit, kill process group "
                "{}".format(main_pid, group_id))
    os.killpg(group_id, signal.SIGKILL)
class CommonDataset(Dataset):
    """
    Define dataset class for single-label image classification

    Every entry of the file list is "<image path><delimiter><label>", with
    the image path taken relative to params['data_dir'].

    Args:
        params(dict): reader settings; 'transforms' and (when a sample is
            fetched) 'data_dir' are required, 'mode' defaults to "train"
            and 'delimiter' to a single space
    """

    def __init__(self, params):
        self.params = params
        self.mode = params.get("mode", "train")
        self.full_lines = get_file_list(params)
        self.delimiter = params.get('delimiter', ' ')
        self.ops = create_operators(params['transforms'])
        self.num_samples = len(self.full_lines)

    def __getitem__(self, idx):
        """
        Return (transform(raw image bytes, self.ops), int label) for idx.

        If anything goes wrong while reading the sample, the error is
        logged and a randomly chosen sample is returned instead.
        """
        # Bound before the try block so the error log below can always
        # format it, even when the index lookup itself is what failed.
        line = None
        try:
            line = self.full_lines[idx]
            img_path, label = line.split(self.delimiter)
            img_path = os.path.join(self.params['data_dir'], img_path)
            with open(img_path, 'rb') as f:
                img = f.read()
            return (transform(img, self.ops), int(label))
        except Exception as e:
            logger.error("data read faild: {}, exception info: {}".format(line,
                                                                          e))
            # random.randint() includes both end points, so the upper bound
            # has to be the last valid index rather than the length.
            # NOTE(review): if no sample can be read this recursion does
            # not terminate normally.
            return self.__getitem__(random.randint(0, len(self) - 1))

    def __len__(self):
        """Number of entries read from the file list."""
        return self.num_samples
class MultiLabelDataset(Dataset):
    """
    Define dataset class for multilabel image classification

    Every entry of the file list is "<image path><delimiter><labels>",
    where <labels> is a comma-separated list of integers and the image
    path is taken relative to params["data_dir"].

    Args:
        params(dict): reader settings; "transforms" and (when a sample is
            fetched) "data_dir" are required, "mode" defaults to "train"
            and "delimiter" to a tab
    """

    def __init__(self, params):
        self.params = params
        self.mode = params.get("mode", "train")
        self.full_lines = get_file_list(params)
        self.delimiter = params.get("delimiter", "\t")
        self.ops = create_operators(params["transforms"])
        self.num_samples = len(self.full_lines)

    def __getitem__(self, idx):
        """
        Return (transform(raw image bytes, self.ops), float32 label array).

        If anything goes wrong while reading the sample, the error is
        logged and a randomly chosen sample is returned instead.
        """
        # Bound before the try block so the error log below can always
        # format it, even when the index lookup itself is what failed.
        line = None
        try:
            line = self.full_lines[idx]
            img_path, label_str = line.split(self.delimiter)
            img_path = os.path.join(self.params["data_dir"], img_path)
            with open(img_path, "rb") as f:
                img = f.read()

            labels = [int(i) for i in label_str.split(',')]

            return (transform(img, self.ops),
                    np.array(labels).astype("float32"))
        except Exception as e:
            logger.error("data read failed: {}, exception info: {}".format(
                line, e))
            # random.randint() includes both end points, so the upper bound
            # has to be the last valid index rather than the length.
            # NOTE(review): if no sample can be read this recursion does
            # not terminate normally.
            return self.__getitem__(random.randint(0, len(self) - 1))

    def __len__(self):
        """Number of entries read from the file list."""
        return self.num_samples
class Reader:
    """
    Create a reader for training/validation/test

    Args:
        config(dict): whole configuration; the section stored under
            mode.upper() holds the reader settings, and the top-level
            'use_mix', 'use_xpu' and 'multilabel' switches are read too
        mode(str): name of the section to use, e.g. train, valid or test
        places: forwarded unchanged to the DataLoader
    Returns:
        calling the instance returns the DataLoader for that mode
    Raises:
        ModeException: config has no section for the given mode
    """

    def __init__(self, config, mode='train', places=None):
        section = mode.upper()
        try:
            self.params = config[section]
        except KeyError:
            raise ModeException(mode=mode)

        use_mix = config.get('use_mix')
        # the section dict itself is tagged with the mode it is used for
        self.params['mode'] = mode

        training = mode == "train"
        self.shuffle = training

        # batch-level (mix) operators are only installed for training
        if use_mix and training:
            self.batch_ops = create_operators(self.params['mix'])
            self.collate_fn = self.mix_collate_fn
        else:
            self.batch_ops = []
            self.collate_fn = None

        self.places = places
        self.use_xpu = config.get("use_xpu", False)
        self.multilabel = config.get("multilabel", False)

    def mix_collate_fn(self, batch):
        """
        Apply the batch operators, then stack every field across samples.

        Args:
            batch: samples as delivered by the DataLoader
        Returns:
            list with one stacked array per field
        """
        batch = transform(batch, self.batch_ops)
        # regroup sample-major data into one list per field
        fields = []
        for sample in batch:
            width = len(sample)
            for pos, value in enumerate(sample):
                if len(fields) >= width:
                    fields[pos].append(value)
                else:
                    fields.append([value])
        return [np.stack(column, axis=0) for column in fields]

    def __call__(self):
        """Build the dataset and return the DataLoader wrapped around it."""
        # the configured batch size is divided among the trainers
        per_trainer_batch = int(self.params['batch_size']) // trainers_num

        dataset_cls = MultiLabelDataset if self.multilabel else CommonDataset
        dataset = dataset_cls(self.params)

        is_train = self.params['mode'] == "train"

        # evaluation on XPU: plain sequential loader, no batch sampler
        if not is_train and self.use_xpu:
            return DataLoader(
                dataset,
                places=self.places,
                batch_size=per_trainer_batch,
                drop_last=False,
                return_list=True,
                shuffle=False,
                num_workers=self.params["num_workers"])

        batch_sampler = DistributedBatchSampler(
            dataset,
            batch_size=per_trainer_batch,
            shuffle=self.shuffle and is_train,
            drop_last=is_train)
        return DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=self.collate_fn if is_train else None,
            places=self.places,
            return_list=True,
            num_workers=self.params["num_workers"])
# Install term_mp as the handler for both SIGINT and SIGTERM when the
# module is imported.
for _signum in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_signum, term_mp)
del _signum
|