__init__.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib
  15. import inspect
  16. import os
  17. import sys
  18. from types import ModuleType
  19. import filelock
  20. from ..utils import logging
  21. from ..utils.deps import class_requires_deps
  22. def get_user_home() -> str:
  23. return os.path.expanduser("~")
  24. def get_pprndr_home() -> str:
  25. return os.path.join(get_user_home(), ".pprndr")
  26. def get_sub_home(directory: str) -> str:
  27. home = os.path.join(get_pprndr_home(), directory)
  28. os.makedirs(home, exist_ok=True)
  29. return home
  30. TMP_HOME = get_sub_home("tmp")
  31. custom_ops = {
  32. "voxelize": {
  33. "sources": ["voxel/voxelize_op.cc", "voxel/voxelize_op.cu"],
  34. "version": "0.1.0",
  35. },
  36. "iou3d_nms": {
  37. "sources": [
  38. "iou3d_nms/iou3d_cpu.cpp",
  39. "iou3d_nms/iou3d_nms_api.cpp",
  40. "iou3d_nms/iou3d_nms.cpp",
  41. "iou3d_nms/iou3d_nms_kernel.cu",
  42. ],
  43. "version": "0.1.0",
  44. },
  45. }
  46. class CustomOpNotFoundException(Exception):
  47. def __init__(self, op_name):
  48. self.op_name = op_name
  49. def __str__(self):
  50. return "Couldn't Found custom op {}".format(self.op_name)
  51. class CustomOperatorPathFinder:
  52. def find_module(self, fullname: str, path: str = None):
  53. if not fullname.startswith("paddlex.ops"):
  54. return None
  55. return CustomOperatorPathLoader()
  56. class CustomOperatorPathLoader:
  57. def load_module(self, fullname: str):
  58. modulename = fullname.split(".")[-1]
  59. if modulename not in custom_ops:
  60. raise CustomOpNotFoundException(modulename)
  61. if fullname not in sys.modules:
  62. try:
  63. sys.modules[fullname] = importlib.import_module(modulename)
  64. except ImportError:
  65. sys.modules[fullname] = PaddleXCustomOperatorModule(
  66. modulename, fullname
  67. )
  68. return sys.modules[fullname]
  69. class PaddleXCustomOperatorModule(ModuleType):
  70. def __init__(self, modulename: str, fullname: str):
  71. self.fullname = fullname
  72. self.modulename = modulename
  73. self.module = None
  74. super().__init__(modulename)
  75. def jit_build(self):
  76. from paddle.utils.cpp_extension import load as paddle_jit_load
  77. try:
  78. lockfile = "paddlex.ops.{}".format(self.modulename)
  79. lockfile = os.path.join(TMP_HOME, lockfile)
  80. file = inspect.getabsfile(sys.modules["paddlex.ops"])
  81. rootdir = os.path.split(file)[0]
  82. args = custom_ops[self.modulename].copy()
  83. sources = args.pop("sources")
  84. sources = [os.path.join(rootdir, file) for file in sources]
  85. args.pop("version")
  86. with filelock.FileLock(lockfile):
  87. return paddle_jit_load(name=self.modulename, sources=sources, **args)
  88. except:
  89. logging.error("{} builded fail!".format(self.modulename))
  90. raise
  91. def _load_module(self):
  92. if self.module is None:
  93. try:
  94. self.module = importlib.import_module(self.modulename)
  95. except ImportError:
  96. logging.warning(
  97. "No custom op {} found, try JIT build".format(self.modulename)
  98. )
  99. self.module = self.jit_build()
  100. logging.info("{} builded success!".format(self.modulename))
  101. # refresh
  102. sys.modules[self.fullname] = self.module
  103. return self.module
  104. def __getattr__(self, attr: str):
  105. if attr in ["__path__", "__file__"]:
  106. return None
  107. if attr in ["__loader__", "__package__", "__name__", "__spec__"]:
  108. return super().__getattr__(attr)
  109. module = self._load_module()
  110. if not hasattr(module, attr):
  111. raise ImportError(
  112. "cannot import name '{}' from '{}' ({})".format(
  113. attr, self.modulename, module.__file__
  114. )
  115. )
  116. return getattr(module, attr)
  117. sys.meta_path.insert(0, CustomOperatorPathFinder())