|
|
@@ -12,6 +12,8 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
+import os
|
|
|
+
|
|
|
from ...utils.device import parse_device, set_env_for_device, get_default_device
|
|
|
from ...utils import logging
|
|
|
from .new_ir_blacklist import NEWIR_BLOCKLIST
|
|
|
@@ -122,6 +124,9 @@ class PaddlePredictorOption(object):
|
|
|
if device_type not in ("cpu"):
|
|
|
if device_ids is None or len(device_ids) > 1:
|
|
|
logging.debug(f"The device ID has been set to {device_id}.")
|
|
|
+ # XXX(gaotingquan): set flag to accelerate inference in paddle 3.0b2
|
|
|
+ if device_type in ("gpu", "cpu"):
|
|
|
+ os.environ["FLAGS_enable_pir_api"] = "1"
|
|
|
|
|
|
@property
|
|
|
def min_subgraph_size(self):
|