automatic/modules/mit_nunchaku.py

87 lines
2.7 KiB
Python

# MIT-Han-Lab Nunchaku: <https://github.com/mit-han-lab/nunchaku>
from installer import log, pip
from modules import devices
# Nunchaku release to use for each supported torch `major.minor` version.
# Keys are compared against the torch version with any `+local` build tag removed.
# NOTE(review): '2.10' maps to an older release than '2.9' and '2.11' - confirm this is intentional.
nunchaku_versions: dict[str, str] = {
    '2.5': '1.0.1',
    '2.6': '1.0.1',
    '2.7': '1.1.0',
    '2.8': '1.1.0',
    '2.9': '1.1.0',
    '2.10': '1.0.2',
    '2.11': '1.1.0',
}

# Module-level flag: set to True by `check()` once the installed package has been validated,
# after which `check()` returns immediately without re-importing anything.
ok: bool = False
def _expected_ver():
    """
    Look up the Nunchaku release that matches the installed torch.

    Returns:
        The value from `nunchaku_versions` for torch's `major.minor` version,
        or None when torch cannot be imported, its version cannot be parsed,
        or the version has no entry in the table.
    """
    try:
        import torch
        release = torch.__version__.split('+')[0] # drop local build tag, e.g. '+cu128'
        major_minor = '.'.join(release.split('.')[:2]) # keep only 'major.minor'
        return nunchaku_versions.get(major_minor)
    except Exception:
        # best-effort: any failure simply means "no expectation"
        return None
def check():
    """
    Verify that Nunchaku is importable and matches the version expected for torch.

    A successful result is cached in the module-level `ok` flag, so later calls
    return True without importing again. A failed result is not cached: every
    call repeats the import and version comparison until one succeeds.

    Returns:
        bool: True when the package imports and either no expected version is
        known or the installed version equals it; False on version mismatch or
        on any exception raised while importing/inspecting the package
        (the exception is logged as an error).
    """
    global ok # pylint: disable=global-statement
    if ok:
        return True
    try:
        import nunchaku
        import nunchaku.utils
        from nunchaku import __version__ # the `nunchaku.__version__` module; the version string is its `__version__` attribute
        expected = _expected_ver()
        log.info(f'Nunchaku: path={nunchaku.__path__} version={__version__.__version__} precision={nunchaku.utils.get_precision()}')
        installed = __version__.__version__
        # valid when there is nothing to compare against, or the versions agree
        ok = expected is None or installed == expected
    except Exception as e:
        log.error(f'Nunchaku: {e}')
        ok = False
    return ok
def install_nunchaku():
    """
    Make sure a usable Nunchaku package is installed, installing a wheel if needed.

    If `check()` already passes nothing is installed. Otherwise the python
    version, operating system, compute backend and torch version are validated,
    a wheel URL is built from them (or taken verbatim from the
    `NUNCHAKU_COMMAND` environment variable when it is set), the wheel is
    installed through the installer's `pip` helper, and `check()` is run again.

    Returns:
        bool: True when `check()` passes, either immediately or after the
        install. False when `devices.backend` is not initialized yet, when the
        python/platform/backend/torch combination is unsupported, or when the
        package still fails `check()` after the install attempt.
    """
    if devices.backend is None:
        return False # too early
    if check():
        return True # already installed and valid, nothing to do

    import importlib
    import os
    import sys
    import platform
    import torch

    python_ver = f'{sys.version_info.major}{sys.version_info.minor}' # e.g. '312', used in the wheel's cpXY tags
    if python_ver not in ['311', '312', '313']:
        log.error(f'Nunchaku: python={sys.version_info} unsupported')
        return False
    arch = platform.system().lower()
    if arch not in ['linux', 'windows']:
        log.error(f'Nunchaku: platform={arch} unsupported')
        return False
    if devices.backend not in ['cuda']:
        log.error(f'Nunchaku: backend={devices.backend} unsupported')
        return False
    torch_ver = '.'.join(torch.__version__.split('+')[0].split('.')[:2]) # 'major.minor' without local build tag
    nunchaku_ver = nunchaku_versions.get(torch_ver)
    if nunchaku_ver is None:
        log.error(f'Nunchaku: torch={torch.__version__} unsupported')
        return False

    url = os.environ.get('NUNCHAKU_COMMAND', None) # user override: used as-is as the pip install target
    if url is None:
        # wheel file name: nunchaku-<ver>+torch<X.Y>-cp<py>-cp<py>-<platform tag>.whl
        platform_tag = 'linux_x86_64' if arch == 'linux' else 'win_amd64'
        url = f'https://huggingface.co/nunchaku-ai/nunchaku/resolve/main/nunchaku-{nunchaku_ver}'
        url += f'+torch{torch_ver}-cp{python_ver}-cp{python_ver}-{platform_tag}.whl'
    cmd = f'install --upgrade {url}'
    log.debug(f'Nunchaku: install="{url}"')
    pip(cmd, ignore=False, uv=False)

    # The package was installed while this process is running: per the importlib
    # documentation the finder caches must be invalidated so the import performed
    # by `check()` is guaranteed to notice the newly installed files.
    importlib.invalidate_caches()
    # NOTE(review): if a mismatched nunchaku was already imported in this process,
    # the re-check presumably still sees the cached module - confirm whether a restart is required.
    if not check():
        log.error('Nunchaku: install failed')
        return False
    return True