__init__.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib
  15. import inspect
  16. import os
  17. import sys
  18. from types import ModuleType
  19. import filelock
  20. from paddle.utils.cpp_extension import load as paddle_jit_load
  21. from ..utils import logging
  22. def get_user_home() -> str:
  23. return os.path.expanduser("~")
  24. def get_pprndr_home() -> str:
  25. return os.path.join(get_user_home(), ".pprndr")
  26. def get_sub_home(directory: str) -> str:
  27. home = os.path.join(get_pprndr_home(), directory)
  28. os.makedirs(home, exist_ok=True)
  29. return home
  30. TMP_HOME = get_sub_home("tmp")
  31. custom_ops = {
  32. "voxelize": {
  33. "sources": ["voxel/voxelize_op.cc", "voxel/voxelize_op.cu"],
  34. "version": "0.1.0",
  35. },
  36. "iou3d_nms": {
  37. "sources": [
  38. "iou3d_nms/iou3d_cpu.cpp",
  39. "iou3d_nms/iou3d_nms_api.cpp",
  40. "iou3d_nms/iou3d_nms.cpp",
  41. "iou3d_nms/iou3d_nms_kernel.cu",
  42. ],
  43. "version": "0.1.0",
  44. },
  45. }
  46. class CustomOpNotFoundException(Exception):
  47. def __init__(self, op_name):
  48. self.op_name = op_name
  49. def __str__(self):
  50. return "Couldn't Found custom op {}".format(self.op_name)
  51. class CustomOperatorPathFinder:
  52. def find_module(self, fullname: str, path: str = None):
  53. if not fullname.startswith("paddlex.ops"):
  54. return None
  55. return CustomOperatorPathLoader()
  56. class CustomOperatorPathLoader:
  57. def load_module(self, fullname: str):
  58. modulename = fullname.split(".")[-1]
  59. if modulename not in custom_ops:
  60. raise CustomOpNotFoundException(modulename)
  61. if fullname not in sys.modules:
  62. try:
  63. sys.modules[fullname] = importlib.import_module(modulename)
  64. except ImportError:
  65. sys.modules[fullname] = PaddleXCustomOperatorModule(
  66. modulename, fullname
  67. )
  68. return sys.modules[fullname]
  69. class PaddleXCustomOperatorModule(ModuleType):
  70. def __init__(self, modulename: str, fullname: str):
  71. self.fullname = fullname
  72. self.modulename = modulename
  73. self.module = None
  74. super().__init__(modulename)
  75. def jit_build(self):
  76. try:
  77. lockfile = "paddlex.ops.{}".format(self.modulename)
  78. lockfile = os.path.join(TMP_HOME, lockfile)
  79. file = inspect.getabsfile(sys.modules["paddlex.ops"])
  80. rootdir = os.path.split(file)[0]
  81. args = custom_ops[self.modulename].copy()
  82. sources = args.pop("sources")
  83. sources = [os.path.join(rootdir, file) for file in sources]
  84. args.pop("version")
  85. with filelock.FileLock(lockfile):
  86. return paddle_jit_load(name=self.modulename, sources=sources, **args)
  87. except:
  88. logging.error("{} builded fail!".format(self.modulename))
  89. raise
  90. def _load_module(self):
  91. if self.module is None:
  92. try:
  93. self.module = importlib.import_module(self.modulename)
  94. except ImportError:
  95. logging.warning(
  96. "No custom op {} found, try JIT build".format(self.modulename)
  97. )
  98. self.module = self.jit_build()
  99. logging.info("{} builded success!".format(self.modulename))
  100. # refresh
  101. sys.modules[self.fullname] = self.module
  102. return self.module
  103. def __getattr__(self, attr: str):
  104. if attr in ["__path__", "__file__"]:
  105. return None
  106. if attr in ["__loader__", "__package__", "__name__", "__spec__"]:
  107. return super().__getattr__(attr)
  108. module = self._load_module()
  109. if not hasattr(module, attr):
  110. raise ImportError(
  111. "cannot import name '{}' from '{}' ({})".format(
  112. attr, self.modulename, module.__file__
  113. )
  114. )
  115. return getattr(module, attr)
  116. sys.meta_path.insert(0, CustomOperatorPathFinder())