diff --git a/install.py b/install.py index a37ff62..8872035 100644 --- a/install.py +++ b/install.py @@ -1,48 +1,89 @@ -''' -Author: SpenserCai -Date: 2023-07-28 14:37:09 -version: -LastEditors: SpenserCai -LastEditTime: 2023-08-04 09:47:33 -Description: file content -''' import os import launch from modules import paths_internal import urllib.request from tqdm import tqdm -# 从huggingface下载权重 +import pkg_resources +from pkg_resources import DistributionNotFound, VersionConflict + +# Define model URLs +MODEL_URLS = { + "stable_model": "https://huggingface.co/spensercai/DeOldify/resolve/main/ColorizeStable_gen.pth", + "artistic_model": "https://huggingface.co/spensercai/DeOldify/resolve/main/ColorizeArtistic_gen.pth", + "video_model": "https://huggingface.co/spensercai/DeOldify/resolve/main/ColorizeVideo_gen.pth" +} + +DEPS = [ + 'wandb', + 'fastai==1.0.60', + 'tensorboardX', + 'ffmpeg', + 'ffmpeg-python', + 'yt-dlp', + 'opencv-python', + 'Pillow' +] models_dir = os.path.join(paths_internal.models_path, "deoldify") -stable_model_url = "https://huggingface.co/spensercai/DeOldify/resolve/main/ColorizeStable_gen.pth" -artistic_model_url = "https://huggingface.co/spensercai/DeOldify/resolve/main/ColorizeArtistic_gen.pth" -video_model_url = "https://huggingface.co/spensercai/DeOldify/resolve/main/ColorizeVideo_gen.pth" -stable_model_name = os.path.basename(stable_model_url) -artistic_model_name = os.path.basename(artistic_model_url) -video_model_name = os.path.basename(video_model_url) -stable_model_path = os.path.join(models_dir, stable_model_name) -artistic_model_path = os.path.join(models_dir, artistic_model_name) -video_model_path = os.path.join(models_dir, video_model_name) -if not os.path.exists(models_dir): - os.makedirs(models_dir) +# Ensure models directory exists +os.makedirs(models_dir, exist_ok=True) def download(url, path): - request = urllib.request.urlopen(url) - total = int(request.headers.get('Content-Length', 0)) - with tqdm(total=total, desc='Downloading', unit='B', unit_scale=True, unit_divisor=1024) as progress: - urllib.request.urlretrieve(url, path, reporthook=lambda count, block_size, total_size: progress.update(block_size)) + """Download a file from a URL to a specific path, showing progress.""" + if os.path.exists(path): + return False # File already exists, no need to download + try: + with urllib.request.urlopen(url) as request, open(path, 'wb') as f, tqdm( + desc=f"Downloading {os.path.basename(path)}", + total=int(request.headers.get('Content-Length', 0)), + unit='B', + unit_scale=True, + unit_divisor=1024 + ) as progress: + for chunk in iter(lambda: request.read(4096), b""): + f.write(chunk) + progress.update(len(chunk)) + return True + except Exception as e: + print(f"Failed to download {os.path.basename(path)}: {e}") + return False -if not os.path.exists(stable_model_path): - download(stable_model_url, stable_model_path) +def download_models(): + """Download all models if they don't exist locally.""" + models_already_downloaded = True + for name, url in MODEL_URLS.items(): + path = os.path.join(models_dir, os.path.basename(url)) + if download(url, path): + models_already_downloaded = False + if models_already_downloaded: + print("All models for DeOldify are already downloaded.") +def check_and_install_dependencies(): + """Install required dependencies for DeOldify, with version checks.""" + dependencies_met = True + for dep in DEPS: + package_name, *version = dep.split('==') + version = version[0] if version else None + if not check_package_installed(package_name, version): + dependencies_met = False + print(f"Installing {dep} for DeOldify extension.") + launch.run_pip(f"install {dep}", package_name) + if dependencies_met: + print("All requirements for the DeOldify extension are already installed.") -if not os.path.exists(artistic_model_path): - download(artistic_model_url, artistic_model_path) +def check_package_installed(package_name, version=None): + """Check if a package is installed with an optional version.""" + if version: + package_spec = f"{package_name}=={version}" + else: + package_spec = package_name + try: + pkg_resources.require(package_spec) + return True + except (DistributionNotFound, VersionConflict): + return False -if not os.path.exists(video_model_path): - download(video_model_url, video_model_path) - -for dep in ['wandb','fastai==1.0.60', 'tensorboardX', 'ffmpeg', 'ffmpeg-python', 'yt-dlp', 'opencv-python','Pillow']: - if not launch.is_installed(dep): - launch.run_pip(f"install {dep}", f"{dep} for DeOldify extension") +if __name__ == "__main__": + download_models() + check_and_install_dependencies()