diff --git a/scripts/rocm/rocm_mgr.py b/scripts/rocm/rocm_mgr.py index f9acd61a0..6ad3eebfc 100644 --- a/scripts/rocm/rocm_mgr.py +++ b/scripts/rocm/rocm_mgr.py @@ -7,8 +7,7 @@ from typing import Dict, Optional import installer from modules.logger import log from modules.json_helpers import readfile, writefile - -from modules.shared import cmd_opts +from modules.shared import cmd_opts, opts from modules.devices import has_rocm from scripts.rocm.rocm_vars import ROCM_ENV_VARS # pylint: disable=no-name-in-module @@ -16,13 +15,11 @@ from scripts.rocm import rocm_profiles # pylint: disable=no-name-in-module def _check_rocm() -> bool: - from modules import shared - if getattr(shared.cmd_opts, 'use_rocm', False): + if getattr(cmd_opts, 'use_rocm', False): return True if installer.torch_info.get('type') == 'rocm': return True - import torch # pylint: disable=import-outside-toplevel - return hasattr(torch.version, 'hip') and torch.version.hip is not None + return has_rocm() is_rocm = _check_rocm() @@ -111,8 +108,7 @@ def _resolve_dtype() -> str: except Exception: pass try: - from modules import shared as _sh # pylint: disable=import-outside-toplevel - v = getattr(_sh.opts, 'cuda_dtype', None) + v = getattr(opts, 'cuda_dtype', None) if v in ('FP16', 'BF16', 'FP32'): return v except Exception: @@ -483,7 +479,7 @@ def info() -> dict: # Apply saved config to os.environ at import time (only when ROCm is present) -if has_rocm() and getattr(cmd_opts, 'use_rocm', False) and sys.platform == 'win32': +if is_rocm and sys.platform == 'win32': try: apply_env() except Exception as _e: