diff --git a/modules/sdnq/common.py b/modules/sdnq/common.py index 72d5b6b3a..10351d3ba 100644 --- a/modules/sdnq/common.py +++ b/modules/sdnq/common.py @@ -1,6 +1,7 @@ # pylint: disable=redefined-builtin,no-member,protected-access import os +import json import torch from modules import shared, devices @@ -365,6 +366,9 @@ if use_torch_compile: kwargs["fullgraph"] = True if kwargs.get("dynamic", None) is None: kwargs["dynamic"] = False + if os.environ.get("SDNQ_COMPILE_KWARGS", None) is not None: + for key, value in json.loads(os.environ.get("SDNQ_COMPILE_KWARGS")).items(): + kwargs[key] = value return torch.compile(fn, **kwargs) else: def compile_func(fn, **kwargs): # pylint: disable=unused-argument