queue.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from __future__ import unicode_literals
  18. import sys
  19. import six
  20. if six.PY3:
  21. import pickle
  22. from io import BytesIO as StringIO
  23. else:
  24. import cPickle as pickle
  25. from cStringIO import StringIO
  26. import logging
  27. import traceback
  28. import multiprocessing as mp
  29. from multiprocessing.queues import Queue
  30. from .sharedmemory import SharedMemoryMgr
  31. logger = logging.getLogger(__name__)
  32. class SharedQueueError(ValueError):
  33. """ SharedQueueError
  34. """
  35. pass
  36. class SharedQueue(Queue):
  37. """ a Queue based on shared memory to communicate data between Process,
  38. and it's interface is compatible with 'multiprocessing.queues.Queue'
  39. """
  40. def __init__(self, maxsize=0, mem_mgr=None, memsize=None, pagesize=None):
  41. """ init
  42. """
  43. if six.PY3:
  44. super(SharedQueue, self).__init__(maxsize, ctx=mp.get_context())
  45. else:
  46. super(SharedQueue, self).__init__(maxsize)
  47. if mem_mgr is not None:
  48. self._shared_mem = mem_mgr
  49. else:
  50. self._shared_mem = SharedMemoryMgr(
  51. capacity=memsize, pagesize=pagesize)
  52. def put(self, obj, **kwargs):
  53. """ put an object to this queue
  54. """
  55. obj = pickle.dumps(obj, -1)
  56. buff = None
  57. try:
  58. buff = self._shared_mem.malloc(len(obj))
  59. buff.put(obj)
  60. super(SharedQueue, self).put(buff, **kwargs)
  61. except Exception as e:
  62. stack_info = traceback.format_exc()
  63. err_msg = 'failed to put a element to SharedQueue '\
  64. 'with stack info[%s]' % (stack_info)
  65. logger.warn(err_msg)
  66. if buff is not None:
  67. buff.free()
  68. raise e
  69. def get(self, **kwargs):
  70. """ get an object from this queue
  71. """
  72. buff = None
  73. try:
  74. buff = super(SharedQueue, self).get(**kwargs)
  75. data = buff.get()
  76. return pickle.load(StringIO(data))
  77. except Exception as e:
  78. stack_info = traceback.format_exc()
  79. err_msg = 'failed to get element from SharedQueue '\
  80. 'with stack info[%s]' % (stack_info)
  81. logger.warn(err_msg)
  82. raise e
  83. finally:
  84. if buff is not None:
  85. buff.free()
  86. def release(self):
  87. self._shared_mem.release()
  88. self._shared_mem = None