# File: batch.py
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# !/usr/bin/env python3
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
Author: PaddlePaddle Authors
"""
  12. import inspect
  13. import functools
  14. import itertools
  15. __all__ = ['batchable_method', 'apply_batch', 'Batcher']
  16. def batchable_method(func):
  17. """ batchable """
  18. @functools.wraps(func)
  19. def _wrapper(self, input_, *args, **kwargs):
  20. if isinstance(input_, list):
  21. output = []
  22. for ele in input_:
  23. out = func(self, ele, *args, **kwargs)
  24. output.append(out)
  25. return output
  26. else:
  27. return func(self, input_, *args, **kwargs)
  28. sig = inspect.signature(func)
  29. if not len(sig.parameters) >= 2:
  30. raise TypeError(
  31. "The function to wrap should have at least two parameters.")
  32. return _wrapper
  33. def apply_batch(batch, callable_, *args, **kwargs):
  34. """ apply batch """
  35. output = []
  36. for ele in batch:
  37. out = callable_(ele, *args, **kwargs)
  38. output.append(out)
  39. return output
  40. class Batcher(object):
  41. """ Batcher """
  42. def __init__(self, iterable, batch_size=None):
  43. super().__init__()
  44. self.iterable = iterable
  45. self.batch_size = batch_size
  46. def __iter__(self):
  47. if self.batch_size is None:
  48. all_data = list(self.iterable)
  49. yield all_data
  50. else:
  51. iterator = iter(self.iterable)
  52. while True:
  53. batch = list(itertools.islice(iterator, self.batch_size))
  54. if not batch:
  55. break
  56. yield batch