update requirements .py and .ipynb

pull/16/head
deforum 2022-10-05 15:50:36 -07:00 committed by GitHub
parent f8c72e2982
commit b25b0edbb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 1985 additions and 13 deletions

File diff suppressed because it is too large Load Diff

View File

@ -47,7 +47,7 @@ Please read the full license here: https://huggingface.co/spaces/CompVis/stable-
import subprocess
sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')
print("NVIDIA GPU:")
print(f" {sub_p_res[:-1]}")
print(f"{sub_p_res[:-1]}")
# %%
# !! {"metadata":{
@ -82,8 +82,8 @@ import os
os.makedirs(models_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)
print(f" models_path: {models_path}")
print(f" output_path: {output_path}")
print(f"models_path: {models_path}")
print(f"output_path: {output_path}")
# %%
# !! {"metadata":{
@ -168,6 +168,9 @@ from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from transformers import logging
logging.set_verbosity_error()
def sanitize(prompt):
whitelist = set('abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ')
tmp = ''.join(filter(whitelist.__contains__, prompt))
@ -1017,8 +1020,8 @@ else:
ckpt_valid = False
print("Config and Model Location:")
print(f" {ckpt_config_path}")
print(f" {ckpt_path}")
print(f"{ckpt_config_path}")
print(f"{ckpt_path}")
if check_sha256 and model_checkpoint != "custom" and ckpt_valid:
import hashlib
@ -1036,21 +1039,23 @@ if check_sha256 and model_checkpoint != "custom" and ckpt_valid:
if ckpt_valid:
print(f"..using {ckpt_path}")
def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):
def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True,print_flag=False):
map_location = "cuda" #@param ["cpu", "cuda"]
print(f"..loading model")
pl_sd = torch.load(ckpt, map_location=map_location)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if print_flag:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
if print_flag:
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
if half_precision:
model = model.half().to(device)