amend pr #9894b2d8

pull/356/head
Kahsolt 2023-10-16 21:05:02 +08:00
parent 9894b2d822
commit 4a6e3312d5
1 changed files with 8 additions and 13 deletions

View File

@ -60,7 +60,6 @@
import os
import json
import torch
import modules
import numpy as np
import gradio as gr
@ -75,12 +74,11 @@ from tile_methods.mixtureofdiffusers import MixtureOfDiffusers
from tile_utils.utils import *
from tile_utils.typing import *
CFG_PATH = os.path.join(scripts.basedir(), 'region_configs')
BBOX_MAX_NUM = min(getattr(shared.cmd_opts, 'md_max_regions', 8), 16)
class Script(modules.scripts.Script):
class Script(scripts.Script):
def __init__(self):
self.controlnet_script: ModuleType = None
@ -92,7 +90,7 @@ class Script(modules.scripts.Script):
return 'Tiled Diffusion'
def show(self, is_img2img):
return modules.scripts.AlwaysVisible
return scripts.AlwaysVisible
def ui(self, is_img2img):
tab = 't2i' if not is_img2img else 'i2i'
@ -358,11 +356,8 @@ class Script(modules.scripts.Script):
break
''' hijack inner APIs '''
if not hasattr(Script, "create_sampler_original_md"):
if getattr(Script, "create_sampler_original_md", None) is None:
Script.create_sampler_original_md = sd_samplers.create_sampler
else:
if Script.create_sampler_original_md is None:
Script.create_sampler_original_md = sd_samplers.create_sampler
sd_samplers.create_sampler = lambda name, model: self.create_sampler_hijack(
name, model, p, Method(method),
tile_width, tile_height, overlap, tile_batch_size,
@ -433,10 +428,8 @@ class Script(modules.scripts.Script):
print('warn: noise inversion only supports the Euler sampler, switch to it sliently...')
name = 'Euler'
p.sampler_name = 'Euler'
if name is None:
print('name is empty')
if model is None:
print('model is empty')
if name is None: print('>> name is empty')
if model is None: print('>> model is empty')
sampler = Script.create_sampler_original_md(name, model)
if method == Method.MULTI_DIFF: delegate_cls = MultiDiffusion
elif method == Method.MIX_DIFF: delegate_cls = MixtureOfDiffusers
@ -600,6 +593,8 @@ class Script(modules.scripts.Script):
mem = psutil.Process(os.getpid()).memory_info()
print(f'[Mem] rss: {mem.rss/2**30:.3f} GB, vms: {mem.vms/2**30:.3f} GB')
from modules.shared import mem_mon as vram_mon
from modules.memmon import MemUsageMonitor
vram_mon: MemUsageMonitor
free, total = vram_mon.cuda_mem_get_info()
print(f'[VRAM] free: {free/2**30:.3f} GB, total: {total/2**30:.3f} GB')
except: