mirror of https://github.com/vladmandic/automatic
further fixes requested by review
parent
ee3b141297
commit
2177609e54
|
|
@ -7,8 +7,7 @@ from typing import Dict, Optional
|
||||||
import installer
|
import installer
|
||||||
from modules.logger import log
|
from modules.logger import log
|
||||||
from modules.json_helpers import readfile, writefile
|
from modules.json_helpers import readfile, writefile
|
||||||
|
from modules.shared import cmd_opts, opts
|
||||||
from modules.shared import cmd_opts
|
|
||||||
from modules.devices import has_rocm
|
from modules.devices import has_rocm
|
||||||
|
|
||||||
from scripts.rocm.rocm_vars import ROCM_ENV_VARS # pylint: disable=no-name-in-module
|
from scripts.rocm.rocm_vars import ROCM_ENV_VARS # pylint: disable=no-name-in-module
|
||||||
|
|
@ -16,13 +15,11 @@ from scripts.rocm import rocm_profiles # pylint: disable=no-name-in-module
|
||||||
|
|
||||||
|
|
||||||
def _check_rocm() -> bool:
|
def _check_rocm() -> bool:
|
||||||
from modules import shared
|
if getattr(cmd_opts, 'use_rocm', False):
|
||||||
if getattr(shared.cmd_opts, 'use_rocm', False):
|
|
||||||
return True
|
return True
|
||||||
if installer.torch_info.get('type') == 'rocm':
|
if installer.torch_info.get('type') == 'rocm':
|
||||||
return True
|
return True
|
||||||
import torch # pylint: disable=import-outside-toplevel
|
return has_rocm()
|
||||||
return hasattr(torch.version, 'hip') and torch.version.hip is not None
|
|
||||||
|
|
||||||
|
|
||||||
is_rocm = _check_rocm()
|
is_rocm = _check_rocm()
|
||||||
|
|
@ -111,8 +108,7 @@ def _resolve_dtype() -> str:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
from modules import shared as _sh # pylint: disable=import-outside-toplevel
|
v = getattr(opts, 'cuda_dtype', None)
|
||||||
v = getattr(_sh.opts, 'cuda_dtype', None)
|
|
||||||
if v in ('FP16', 'BF16', 'FP32'):
|
if v in ('FP16', 'BF16', 'FP32'):
|
||||||
return v
|
return v
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -483,7 +479,7 @@ def info() -> dict:
|
||||||
|
|
||||||
|
|
||||||
# Apply saved config to os.environ at import time (only when ROCm is present)
|
# Apply saved config to os.environ at import time (only when ROCm is present)
|
||||||
if has_rocm() and getattr(cmd_opts, 'use_rocm', False) and sys.platform == 'win32':
|
if is_rocm and sys.platform == 'win32':
|
||||||
try:
|
try:
|
||||||
apply_env()
|
apply_env()
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue