__init__.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib
  15. import inspect
  16. import os
  17. import sys
  18. from types import ModuleType
  19. import filelock
  20. from ..utils import logging
  21. from ..utils.deps import class_requires_deps
  22. def get_user_home() -> str:
  23. return os.path.expanduser("~")
  24. def get_pprndr_home() -> str:
  25. return os.path.join(get_user_home(), ".pprndr")
  26. def get_sub_home(directory: str) -> str:
  27. home = os.path.join(get_pprndr_home(), directory)
  28. os.makedirs(home, exist_ok=True)
  29. return home
  30. TMP_HOME = get_sub_home("tmp")
  31. custom_ops = {
  32. "voxelize": {
  33. "sources": ["voxel/voxelize_op.cc", "voxel/voxelize_op.cu"],
  34. "version": "0.1.0",
  35. },
  36. "iou3d_nms": {
  37. "sources": [
  38. "iou3d_nms/iou3d_cpu.cpp",
  39. "iou3d_nms/iou3d_nms_api.cpp",
  40. "iou3d_nms/iou3d_nms.cpp",
  41. "iou3d_nms/iou3d_nms_kernel.cu",
  42. ],
  43. "version": "0.1.0",
  44. },
  45. }
  46. class CustomOpNotFoundException(Exception):
  47. def __init__(self, op_name):
  48. self.op_name = op_name
  49. def __str__(self):
  50. return "Couldn't Found custom op {}".format(self.op_name)
  51. class CustomOperatorPathFinder:
  52. def find_spec(self, fullname: str, path, target=None):
  53. if not fullname.startswith("paddlex.ops"):
  54. return None
  55. return importlib.machinery.ModuleSpec(
  56. name=fullname,
  57. loader=CustomOperatorPathLoader(),
  58. is_package=False,
  59. )
  60. class CustomOperatorPathLoader:
  61. def load_module(self, fullname: str):
  62. modulename = fullname.split(".")[-1]
  63. if modulename not in custom_ops:
  64. raise CustomOpNotFoundException(modulename)
  65. if fullname not in sys.modules:
  66. try:
  67. sys.modules[fullname] = importlib.import_module(modulename)
  68. except ImportError:
  69. sys.modules[fullname] = PaddleXCustomOperatorModule(
  70. modulename, fullname
  71. )
  72. return sys.modules[fullname]
  73. class PaddleXCustomOperatorModule(ModuleType):
  74. def __init__(self, modulename: str, fullname: str):
  75. self.fullname = fullname
  76. self.modulename = modulename
  77. self.module = None
  78. super().__init__(modulename)
  79. def jit_build(self):
  80. from paddle.utils.cpp_extension import load as paddle_jit_load
  81. try:
  82. lockfile = "paddlex.ops.{}".format(self.modulename)
  83. lockfile = os.path.join(TMP_HOME, lockfile)
  84. file = inspect.getabsfile(sys.modules["paddlex.ops"])
  85. rootdir = os.path.split(file)[0]
  86. args = custom_ops[self.modulename].copy()
  87. sources = args.pop("sources")
  88. sources = [os.path.join(rootdir, file) for file in sources]
  89. args.pop("version")
  90. with filelock.FileLock(lockfile):
  91. return paddle_jit_load(name=self.modulename, sources=sources, **args)
  92. except:
  93. logging.error("{} built fail!".format(self.modulename))
  94. raise
  95. def _load_module(self):
  96. if self.module is None:
  97. try:
  98. self.module = importlib.import_module(self.modulename)
  99. except ImportError:
  100. logging.warning(
  101. "No custom op {} found, try JIT build".format(self.modulename)
  102. )
  103. self.module = self.jit_build()
  104. logging.info("{} built success!".format(self.modulename))
  105. # refresh
  106. sys.modules[self.fullname] = self.module
  107. return self.module
  108. def __getattr__(self, attr: str):
  109. if attr in ["__path__", "__file__"]:
  110. return None
  111. if attr in ["__loader__", "__package__", "__name__", "__spec__"]:
  112. return super().__getattr__(attr)
  113. module = self._load_module()
  114. if not hasattr(module, attr):
  115. raise ImportError(
  116. "cannot import name '{}' from '{}' ({})".format(
  117. attr, self.modulename, module.__file__
  118. )
  119. )
  120. return getattr(module, attr)
  121. sys.meta_path.insert(0, CustomOperatorPathFinder())