mirror of https://github.com/vladmandic/automatic
cleanup references to p.sd_model
Signed-off-by: Vladimir Mandic <mandic00@live.com>pull/4663/head
parent
6a72aa8083
commit
47543663f9
|
|
@ -1427,7 +1427,7 @@ def run_deferred_tasks():
|
||||||
log.debug('Starting deferred tasks')
|
log.debug('Starting deferred tasks')
|
||||||
time.sleep(1.0) # wait for server to start
|
time.sleep(1.0) # wait for server to start
|
||||||
try:
|
try:
|
||||||
from modules.sd_models import write_metadata
|
from modules.sd_checkpoint import write_metadata
|
||||||
write_metadata()
|
write_metadata()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f'Deferred task error: write_metadata {e}')
|
log.error(f'Deferred task error: write_metadata {e}')
|
||||||
|
|
@ -1446,7 +1446,6 @@ def run_deferred_tasks():
|
||||||
log.debug(f'Deferred tasks complete: time={round(time.time() - t_start, 2)}')
|
log.debug(f'Deferred tasks complete: time={round(time.time() - t_start, 2)}')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def get_state():
|
def get_state():
|
||||||
state = {
|
state = {
|
||||||
|
|
|
||||||
|
|
@ -212,6 +212,7 @@ def start_server(immediate=True, server=None):
|
||||||
log.debug(f'Starting module: {server}')
|
log.debug(f'Starting module: {server}')
|
||||||
module_spec.loader.exec_module(server)
|
module_spec.loader.exec_module(server)
|
||||||
threading.Thread(target=installer.run_deferred_tasks, daemon=True).start()
|
threading.Thread(target=installer.run_deferred_tasks, daemon=True).start()
|
||||||
|
|
||||||
uvicorn = None
|
uvicorn = None
|
||||||
if args.test:
|
if args.test:
|
||||||
log.info("Test only")
|
log.info("Test only")
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import cv2
|
import cv2
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from modules.logger import log
|
||||||
|
from modules import devices, shared, errors, processing, sd_models, sd_vae, scripts_manager, masking
|
||||||
from modules.control import util # helper functions
|
from modules.control import util # helper functions
|
||||||
from modules.control import unit # control units
|
from modules.control import unit # control units
|
||||||
from modules.control import processors # image preprocessors
|
from modules.control import processors # image preprocessors
|
||||||
|
|
@ -12,8 +14,6 @@ from modules.control.units import lite # Kohya ControlLLLite
|
||||||
from modules.control.units import t2iadapter # TencentARC T2I-Adapter
|
from modules.control.units import t2iadapter # TencentARC T2I-Adapter
|
||||||
from modules.control.units import reference # ControlNet-Reference
|
from modules.control.units import reference # ControlNet-Reference
|
||||||
from modules.control.processor import preprocess_image
|
from modules.control.processor import preprocess_image
|
||||||
from modules import devices, shared, errors, processing, sd_models, sd_vae, scripts_manager, masking
|
|
||||||
from modules.logger import log
|
|
||||||
from modules.processing_class import StableDiffusionProcessingControl
|
from modules.processing_class import StableDiffusionProcessingControl
|
||||||
from modules.ui_common import infotext_to_html
|
from modules.ui_common import infotext_to_html
|
||||||
from modules.api import script
|
from modules.api import script
|
||||||
|
|
@ -449,8 +449,8 @@ def control_run(state: str = '', # pylint: disable=keyword-arg-before-vararg
|
||||||
setattr(p, k, v)
|
setattr(p, k, v)
|
||||||
p_extra_args = {}
|
p_extra_args = {}
|
||||||
|
|
||||||
if shared.sd_model is None:
|
if shared.sd_model is None: # triggers load if not loaded
|
||||||
log.warning('Aborted: op=control model not loaded')
|
log.warning('Aborted: op=generate model not loaded')
|
||||||
return [], '', '', 'Error: model not loaded'
|
return [], '', '', 'Error: model not loaded'
|
||||||
|
|
||||||
unit_type = unit_type.strip().lower() if unit_type is not None else ''
|
unit_type = unit_type.strip().lower() if unit_type is not None else ''
|
||||||
|
|
|
||||||
|
|
@ -235,7 +235,7 @@ def face_id(
|
||||||
faceid_model_name = None
|
faceid_model_name = None
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
ipadapter.unapply(p.sd_model)
|
ipadapter.unapply(shared.sd_model)
|
||||||
extra_networks.deactivate(p, p.network_data)
|
extra_networks.deactivate(p, p.network_data)
|
||||||
|
|
||||||
p.extra_generation_params["IP Adapter"] = f"{basename}:{scale}"
|
p.extra_generation_params["IP Adapter"] = f"{basename}:{scale}"
|
||||||
|
|
@ -243,7 +243,7 @@ def face_id(
|
||||||
if faceid_model is not None and original_load_ip_adapter is not None:
|
if faceid_model is not None and original_load_ip_adapter is not None:
|
||||||
faceid_model.__class__.load_ip_adapter = original_load_ip_adapter
|
faceid_model.__class__.load_ip_adapter = original_load_ip_adapter
|
||||||
if shared.opts.cuda_compile_backend == 'none':
|
if shared.opts.cuda_compile_backend == 'none':
|
||||||
token_merge.remove_token_merging(p.sd_model)
|
token_merge.remove_token_merging(shared.sd_model)
|
||||||
script_callbacks.after_process_callback(p)
|
script_callbacks.after_process_callback(p)
|
||||||
|
|
||||||
return processed_images
|
return processed_images
|
||||||
|
|
|
||||||
|
|
@ -178,6 +178,8 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||||
debug_log(f'Network check: type=LoRA requested={requested} status=forced')
|
debug_log(f'Network check: type=LoRA requested={requested} status=forced')
|
||||||
return True
|
return True
|
||||||
sd_model = shared.sd_model.pipe if hasattr(shared.sd_model, 'pipe') else shared.sd_model
|
sd_model = shared.sd_model.pipe if hasattr(shared.sd_model, 'pipe') else shared.sd_model
|
||||||
|
if sd_model is None:
|
||||||
|
return False
|
||||||
if not hasattr(sd_model, 'loaded_loras'):
|
if not hasattr(sd_model, 'loaded_loras'):
|
||||||
sd_model.loaded_loras = {}
|
sd_model.loaded_loras = {}
|
||||||
if include is None or len(include) == 0:
|
if include is None or len(include) == 0:
|
||||||
|
|
|
||||||
|
|
@ -167,47 +167,46 @@ class ModelData:
|
||||||
self.sd_refiner = v
|
self.sd_refiner = v
|
||||||
|
|
||||||
|
|
||||||
|
model_data = ModelData()
|
||||||
|
|
||||||
|
|
||||||
# provides shared.sd_model field as a property
|
# provides shared.sd_model field as a property
|
||||||
class Shared(sys.modules[__name__].__class__):
|
class Shared(sys.modules[__name__].__class__):
|
||||||
@property
|
@property
|
||||||
def sd_loaded(self):
|
def sd_loaded(self):
|
||||||
import modules.sd_models # pylint: disable=W0621
|
return model_data.sd_model is not None
|
||||||
return modules.sd_models.model_data.sd_model is not None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sd_model(self):
|
def sd_model(self):
|
||||||
import modules.sd_models # pylint: disable=W0621
|
if model_data.sd_model is None:
|
||||||
if modules.sd_models.model_data.sd_model is None:
|
|
||||||
fn = f'{os.path.basename(sys._getframe(2).f_code.co_filename)}:{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
|
fn = f'{os.path.basename(sys._getframe(2).f_code.co_filename)}:{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
|
||||||
log.debug(f'Model requested: fn={fn}') # pylint: disable=protected-access
|
log.debug(f'Model requested: fn={fn}')
|
||||||
return modules.sd_models.model_data.get_sd_model()
|
model = model_data.get_sd_model()
|
||||||
|
return model
|
||||||
|
|
||||||
@sd_model.setter
|
@sd_model.setter
|
||||||
def sd_model(self, value):
|
def sd_model(self, value):
|
||||||
import modules.sd_models # pylint: disable=W0621
|
|
||||||
if value is None:
|
if value is None:
|
||||||
fn = f'{os.path.basename(sys._getframe(2).f_code.co_filename)}:{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
|
fn = f'{os.path.basename(sys._getframe(2).f_code.co_filename)}:{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
|
||||||
log.debug(f'Model unloaded: fn={fn}') # pylint: disable=protected-access
|
if model_data.sd_model is not None:
|
||||||
modules.sd_models.model_data.set_sd_model(value)
|
log.debug(f'Model unloaded: fn={fn}')
|
||||||
|
model_data.set_sd_model(value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sd_refiner(self):
|
def sd_refiner(self):
|
||||||
import modules.sd_models # pylint: disable=W0621
|
return model_data.get_sd_refiner()
|
||||||
return modules.sd_models.model_data.get_sd_refiner()
|
|
||||||
|
|
||||||
@sd_refiner.setter
|
@sd_refiner.setter
|
||||||
def sd_refiner(self, value):
|
def sd_refiner(self, value):
|
||||||
import modules.sd_models # pylint: disable=W0621
|
model_data.set_sd_refiner(value)
|
||||||
modules.sd_models.model_data.set_sd_refiner(value)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sd_model_type(self):
|
def sd_model_type(self):
|
||||||
try:
|
try:
|
||||||
import modules.sd_models # pylint: disable=W0621
|
if model_data.sd_model is None:
|
||||||
if modules.sd_models.model_data.sd_model is None:
|
|
||||||
model_type = 'none'
|
model_type = 'none'
|
||||||
return model_type
|
return model_type
|
||||||
model_type = get_model_type(modules.sd_models.model_data.sd_model)
|
model_type = get_model_type(model_data.sd_model)
|
||||||
except Exception:
|
except Exception:
|
||||||
model_type = 'unknown'
|
model_type = 'unknown'
|
||||||
return model_type
|
return model_type
|
||||||
|
|
@ -215,11 +214,10 @@ class Shared(sys.modules[__name__].__class__):
|
||||||
@property
|
@property
|
||||||
def sd_refiner_type(self):
|
def sd_refiner_type(self):
|
||||||
try:
|
try:
|
||||||
import modules.sd_models # pylint: disable=W0621
|
if model_data.sd_refiner is None:
|
||||||
if modules.sd_models.model_data.sd_refiner is None:
|
|
||||||
model_type = 'none'
|
model_type = 'none'
|
||||||
return model_type
|
return model_type
|
||||||
model_type = get_model_type(modules.sd_models.model_data.sd_refiner)
|
model_type = get_model_type(model_data.sd_refiner)
|
||||||
except Exception:
|
except Exception:
|
||||||
model_type = 'unknown'
|
model_type = 'unknown'
|
||||||
return model_type
|
return model_type
|
||||||
|
|
@ -231,6 +229,3 @@ class Shared(sys.modules[__name__].__class__):
|
||||||
return get_console()
|
return get_console()
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
model_data = ModelData()
|
|
||||||
|
|
|
||||||
|
|
@ -76,8 +76,10 @@ class YoloRestorer(Detailer):
|
||||||
return list(self.list)
|
return list(self.list)
|
||||||
|
|
||||||
def dependencies(self):
|
def dependencies(self):
|
||||||
import installer
|
from installer import install
|
||||||
installer.install('ultralytics==8.3.40', ignore=True, quiet=True)
|
install('ultralytics==8.3.40', ignore=True, quiet=True)
|
||||||
|
install('omegaconf')
|
||||||
|
install('antlr4-python3-runtime')
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ from modules.processing_class import ( # pylint: disable=unused-import
|
||||||
StableDiffusionProcessingControl,
|
StableDiffusionProcessingControl,
|
||||||
)
|
)
|
||||||
from modules.processing_info import create_infotext
|
from modules.processing_info import create_infotext
|
||||||
from modules.modeldata import model_data
|
|
||||||
|
|
||||||
|
|
||||||
opt_C = 4
|
opt_C = 4
|
||||||
|
|
@ -39,7 +38,7 @@ processed = None # last known processed results
|
||||||
|
|
||||||
class Processed:
|
class Processed:
|
||||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info=None, subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments="", binary=None, audio=None):
|
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info=None, subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments="", binary=None, audio=None):
|
||||||
self.sd_model_hash = getattr(shared.sd_model, 'sd_model_hash', '') if model_data.sd_model is not None else ''
|
self.sd_model_hash = getattr(shared.sd_model, 'sd_model_hash', '') if shared.sd_loaded is not None else ''
|
||||||
|
|
||||||
self.prompt = p.prompt or ''
|
self.prompt = p.prompt or ''
|
||||||
self.negative_prompt = p.negative_prompt or ''
|
self.negative_prompt = p.negative_prompt or ''
|
||||||
|
|
@ -139,8 +138,11 @@ def get_processed(*args, **kwargs):
|
||||||
def process_images(p: StableDiffusionProcessing) -> Processed:
|
def process_images(p: StableDiffusionProcessing) -> Processed:
|
||||||
timer.process.reset()
|
timer.process.reset()
|
||||||
debug(f'Process images: class={p.__class__.__name__} {vars(p)}')
|
debug(f'Process images: class={p.__class__.__name__} {vars(p)}')
|
||||||
if not hasattr(p.sd_model, 'sd_checkpoint_info'):
|
if shared.sd_model is None:
|
||||||
log.error('Processing: incomplete model')
|
log.warning('Aborted: op=process model not loaded')
|
||||||
|
return None
|
||||||
|
if not hasattr(shared.sd_model, 'sd_checkpoint_info'):
|
||||||
|
log.error('Aborted: op=process incomplete model')
|
||||||
return None
|
return None
|
||||||
if p.abort:
|
if p.abort:
|
||||||
log.debug('Processing: aborted')
|
log.debug('Processing: aborted')
|
||||||
|
|
|
||||||
|
|
@ -70,8 +70,10 @@ def restore_state(p: processing.StableDiffusionProcessing):
|
||||||
|
|
||||||
def process_pre(p: processing.StableDiffusionProcessing):
|
def process_pre(p: processing.StableDiffusionProcessing):
|
||||||
from modules import ipadapter, sd_hijack_freeu, para_attention, teacache, hidiffusion, ras, pag, cfgzero, transformer_cache, token_merge, linfusion, cachedit
|
from modules import ipadapter, sd_hijack_freeu, para_attention, teacache, hidiffusion, ras, pag, cfgzero, transformer_cache, token_merge, linfusion, cachedit
|
||||||
|
if shared.sd_model is None:
|
||||||
|
log.warning('Processing modifiers: model not loaded')
|
||||||
|
return
|
||||||
log.info('Processing modifiers: apply')
|
log.info('Processing modifiers: apply')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# apply-with-unapply
|
# apply-with-unapply
|
||||||
sd_models_compile.check_deepcache(enable=True)
|
sd_models_compile.check_deepcache(enable=True)
|
||||||
|
|
@ -505,19 +507,25 @@ def process_decode(p: processing.StableDiffusionProcessing, output):
|
||||||
|
|
||||||
|
|
||||||
def update_pipeline(sd_model, p: processing.StableDiffusionProcessing):
|
def update_pipeline(sd_model, p: processing.StableDiffusionProcessing):
|
||||||
|
if sd_model is None:
|
||||||
|
sd_model = shared.sd_model
|
||||||
|
if sd_model is None:
|
||||||
|
shared.log.warning('Processing: op=update model not loaded')
|
||||||
|
return None
|
||||||
|
updated_model = sd_model
|
||||||
if sd_models.get_diffusers_task(sd_model) == sd_models.DiffusersTaskType.INPAINTING and getattr(p, 'image_mask', None) is None and p.task_args.get('image_mask', None) is None and getattr(p, 'mask', None) is None:
|
if sd_models.get_diffusers_task(sd_model) == sd_models.DiffusersTaskType.INPAINTING and getattr(p, 'image_mask', None) is None and p.task_args.get('image_mask', None) is None and getattr(p, 'mask', None) is None:
|
||||||
log.warning('Processing: mode=inpaint mask=None')
|
log.warning('Processing: mode=inpaint mask=None')
|
||||||
sd_model = sd_models.set_diffuser_pipe(sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE)
|
updated_model = sd_models.set_diffuser_pipe(sd_model, sd_models.DiffusersTaskType.IMAGE_2_IMAGE)
|
||||||
if shared.opts.cuda_compile_backend == "olive-ai":
|
if shared.opts.cuda_compile_backend == "olive-ai":
|
||||||
sd_model = olive_check_parameters_changed(p, is_refiner_enabled(p))
|
updated_model = olive_check_parameters_changed(p, is_refiner_enabled(p))
|
||||||
if sd_model.__class__.__name__ == "OnnxRawPipeline":
|
if sd_model.__class__.__name__ == "OnnxRawPipeline":
|
||||||
sd_model = preprocess_onnx_pipeline(p)
|
updated_model = preprocess_onnx_pipeline(p)
|
||||||
global orig_pipeline # pylint: disable=global-statement
|
global orig_pipeline # pylint: disable=global-statement
|
||||||
orig_pipeline = sd_model # processed ONNX pipeline should not be replaced with original pipeline.
|
orig_pipeline = updated_model # processed ONNX pipeline should not be replaced with original pipeline.
|
||||||
if getattr(sd_model, "current_attn_name", None) != shared.opts.cross_attention_optimization:
|
if getattr(updated_model, "current_attn_name", None) != shared.opts.cross_attention_optimization:
|
||||||
log.info(f"Setting attention optimization: {shared.opts.cross_attention_optimization}")
|
log.info(f"Setting attention optimization: {shared.opts.cross_attention_optimization}")
|
||||||
attention.set_diffusers_attention(sd_model)
|
attention.set_diffusers_attention(updated_model)
|
||||||
return sd_model
|
return updated_model
|
||||||
|
|
||||||
|
|
||||||
def validate_pipeline(p: processing.StableDiffusionProcessing):
|
def validate_pipeline(p: processing.StableDiffusionProcessing):
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,7 @@ class PromptEmbedder:
|
||||||
earlyout = self.checkcache(p)
|
earlyout = self.checkcache(p)
|
||||||
if earlyout:
|
if earlyout:
|
||||||
return
|
return
|
||||||
self.pipe = prepare_model(p.sd_model)
|
self.pipe = prepare_model(shared.sd_model)
|
||||||
if self.pipe is None:
|
if self.pipe is None:
|
||||||
log.error("Prompt encode: cannot find text encoder in model")
|
log.error("Prompt encode: cannot find text encoder in model")
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -146,15 +146,15 @@ def ratio_to_region(width: float, offset: float, n: int):
|
||||||
|
|
||||||
def apply_freeu(p):
|
def apply_freeu(p):
|
||||||
global state_enabled # pylint: disable=global-statement
|
global state_enabled # pylint: disable=global-statement
|
||||||
if hasattr(p.sd_model, 'enable_freeu'):
|
if hasattr(shared.sd_model, 'enable_freeu'):
|
||||||
if shared.opts.freeu_enabled:
|
if shared.opts.freeu_enabled:
|
||||||
freeu_device = get_fft_device()
|
freeu_device = get_fft_device()
|
||||||
if freeu_device != devices.cpu:
|
if freeu_device != devices.cpu:
|
||||||
p.extra_generation_params['FreeU'] = f'b1={shared.opts.freeu_b1} b2={shared.opts.freeu_b2} s1={shared.opts.freeu_s1} s2={shared.opts.freeu_s2}'
|
p.extra_generation_params['FreeU'] = f'b1={shared.opts.freeu_b1} b2={shared.opts.freeu_b2} s1={shared.opts.freeu_s1} s2={shared.opts.freeu_s2}'
|
||||||
p.sd_model.enable_freeu(s1=shared.opts.freeu_s1, s2=shared.opts.freeu_s2, b1=shared.opts.freeu_b1, b2=shared.opts.freeu_b2)
|
shared.sd_model.enable_freeu(s1=shared.opts.freeu_s1, s2=shared.opts.freeu_s2, b1=shared.opts.freeu_b1, b2=shared.opts.freeu_b2)
|
||||||
state_enabled = True
|
state_enabled = True
|
||||||
elif state_enabled:
|
elif state_enabled:
|
||||||
p.sd_model.disable_freeu()
|
shared.sd_model.disable_freeu()
|
||||||
state_enabled = False
|
state_enabled = False
|
||||||
if shared.opts.freeu_enabled and state_enabled:
|
if shared.opts.freeu_enabled and state_enabled:
|
||||||
log.info(f'Applying Free-U: b1={shared.opts.freeu_b1} b2={shared.opts.freeu_b2} s1={shared.opts.freeu_s1} s2={shared.opts.freeu_s2}')
|
log.info(f'Applying Free-U: b1={shared.opts.freeu_b1} b2={shared.opts.freeu_b2} s1={shared.opts.freeu_s1} s2={shared.opts.freeu_s2}')
|
||||||
|
|
|
||||||
|
|
@ -178,7 +178,7 @@ def split_attention(layer: nn.Module, tile_size: int=256, min_tile_size: int=128
|
||||||
|
|
||||||
def context_hypertile_vae(p):
|
def context_hypertile_vae(p):
|
||||||
from modules import shared
|
from modules import shared
|
||||||
if p.sd_model is None or not shared.opts.hypertile_vae_enabled:
|
if shared.sd_model is None or not shared.opts.hypertile_vae_enabled:
|
||||||
return nullcontext()
|
return nullcontext()
|
||||||
if shared.opts.cross_attention_optimization == 'Sub-quadratic':
|
if shared.opts.cross_attention_optimization == 'Sub-quadratic':
|
||||||
log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
|
log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
|
||||||
|
|
@ -188,7 +188,7 @@ def context_hypertile_vae(p):
|
||||||
error_reported = False
|
error_reported = False
|
||||||
set_resolution(p)
|
set_resolution(p)
|
||||||
max_h, max_w = 0, 0
|
max_h, max_w = 0, 0
|
||||||
vae = getattr(p.sd_model, "vae", None)
|
vae = getattr(shared.sd_model, "vae", None)
|
||||||
if height == 0 or width == 0:
|
if height == 0 or width == 0:
|
||||||
log.warning('Hypertile VAE disabled: resolution unknown')
|
log.warning('Hypertile VAE disabled: resolution unknown')
|
||||||
return nullcontext()
|
return nullcontext()
|
||||||
|
|
@ -207,7 +207,7 @@ def context_hypertile_vae(p):
|
||||||
|
|
||||||
def context_hypertile_unet(p):
|
def context_hypertile_unet(p):
|
||||||
from modules import shared
|
from modules import shared
|
||||||
if p.sd_model is None or not shared.opts.hypertile_unet_enabled:
|
if shared.sd_model is None or not shared.opts.hypertile_unet_enabled:
|
||||||
return nullcontext()
|
return nullcontext()
|
||||||
if shared.opts.cross_attention_optimization == 'Sub-quadratic' and not shared.cmd_opts.experimental:
|
if shared.opts.cross_attention_optimization == 'Sub-quadratic' and not shared.cmd_opts.experimental:
|
||||||
log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
|
log.warning('Hypertile UNet is not compatible with Sub-quadratic cross-attention optimization')
|
||||||
|
|
@ -216,7 +216,7 @@ def context_hypertile_unet(p):
|
||||||
error_reported = False
|
error_reported = False
|
||||||
set_resolution(p)
|
set_resolution(p)
|
||||||
max_h, max_w = 0, 0
|
max_h, max_w = 0, 0
|
||||||
unet = getattr(p.sd_model, "unet", None)
|
unet = getattr(shared.sd_model, "unet", None)
|
||||||
if height == 0 or width == 0:
|
if height == 0 or width == 0:
|
||||||
log.warning('Hypertile VAE disabled: resolution unknown')
|
log.warning('Hypertile VAE disabled: resolution unknown')
|
||||||
return nullcontext()
|
return nullcontext()
|
||||||
|
|
|
||||||
|
|
@ -478,9 +478,10 @@ def apply_balanced_offload(sd_model=None, exclude:list[str]=None, force:bool=Fal
|
||||||
return sd_model
|
return sd_model
|
||||||
if exclude is None:
|
if exclude is None:
|
||||||
exclude = []
|
exclude = []
|
||||||
t0 = time.time()
|
|
||||||
if sd_model.__class__.__name__ in balanced_offload_exclude:
|
if sd_model.__class__.__name__ in balanced_offload_exclude:
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
cached = True
|
cached = True
|
||||||
checkpoint_name = sd_model.sd_checkpoint_info.name if getattr(sd_model, "sd_checkpoint_info", None) is not None else sd_model.__class__.__name__
|
checkpoint_name = sd_model.sd_checkpoint_info.name if getattr(sd_model, "sd_checkpoint_info", None) is not None else sd_model.__class__.__name__
|
||||||
if force or (offload_hook_instance is None) or (offload_hook_instance.min_watermark != shared.opts.diffusers_offload_min_gpu_memory) or (offload_hook_instance.max_watermark != shared.opts.diffusers_offload_max_gpu_memory) or (checkpoint_name != offload_hook_instance.checkpoint_name):
|
if force or (offload_hook_instance is None) or (offload_hook_instance.min_watermark != shared.opts.diffusers_offload_min_gpu_memory) or (offload_hook_instance.max_watermark != shared.opts.diffusers_offload_max_gpu_memory) or (checkpoint_name != offload_hook_instance.checkpoint_name):
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from modules import shared, sd_models, ui_extra_networks, files_cache, modelstats
|
from modules import shared, ui_extra_networks, files_cache, modelstats
|
||||||
from modules.logger import log
|
from modules.logger import log
|
||||||
from modules.textual_inversion import Embedding
|
from modules.textual_inversion import Embedding
|
||||||
|
|
||||||
|
|
@ -12,10 +12,10 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||||
self.embeddings = []
|
self.embeddings = []
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
if sd_models.model_data.sd_model is None:
|
if not shared.sd_loaded:
|
||||||
return
|
return
|
||||||
if hasattr(sd_models.model_data.sd_model, 'embedding_db'):
|
if hasattr(shared.sd_model, 'embedding_db'):
|
||||||
sd_models.model_data.sd_model.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
shared.sd_model.embedding_db.load_textual_inversion_embeddings(force_reload=True)
|
||||||
|
|
||||||
def create_item(self, embedding: Embedding):
|
def create_item(self, embedding: Embedding):
|
||||||
record = None
|
record = None
|
||||||
|
|
@ -43,15 +43,15 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
|
||||||
return record
|
return record
|
||||||
|
|
||||||
def list_items(self):
|
def list_items(self):
|
||||||
if sd_models.model_data.sd_model is None:
|
if not shared.sd_loaded:
|
||||||
candidates = list(files_cache.list_files(shared.opts.embeddings_dir, ext_filter=['.pt', '.safetensors'], recursive=files_cache.not_hidden))
|
candidates = list(files_cache.list_files(shared.opts.embeddings_dir, ext_filter=['.pt', '.safetensors'], recursive=files_cache.not_hidden))
|
||||||
self.embeddings = [
|
self.embeddings = [
|
||||||
Embedding(vec=0, name=os.path.basename(embedding_path), filename=embedding_path)
|
Embedding(vec=0, name=os.path.basename(embedding_path), filename=embedding_path)
|
||||||
for embedding_path
|
for embedding_path
|
||||||
in candidates
|
in candidates
|
||||||
]
|
]
|
||||||
elif hasattr(sd_models.model_data.sd_model, 'embedding_db'):
|
elif hasattr(shared.sd_model, 'embedding_db'):
|
||||||
self.embeddings = list(sd_models.model_data.sd_model.embedding_db.word_embeddings.values())
|
self.embeddings = list(shared.sd_model.embedding_db.word_embeddings.values())
|
||||||
else:
|
else:
|
||||||
self.embeddings = []
|
self.embeddings = []
|
||||||
self.embeddings = sorted(self.embeddings, key=lambda emb: emb.filename)
|
self.embeddings = sorted(self.embeddings, key=lambda emb: emb.filename)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue