forge classic
parent
352c0047ec
commit
8ba24888f1
|
|
@ -36,6 +36,7 @@ def is_installed(pip_package):
|
|||
requirements = [
|
||||
"diffusers==0.31.0",
|
||||
"scikit-learn",
|
||||
"accelerate"
|
||||
]
|
||||
|
||||
for module in requirements:
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ import scripts.A1111.network_oft as network_oft
|
|||
import torch
|
||||
from typing import Union
|
||||
|
||||
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
|
||||
from modules import shared, devices, sd_models, errors, scripts, sd_hijack, launch_utils
|
||||
import modules.textual_inversion.textual_inversion as textual_inversion
|
||||
|
||||
class QkvLinear(torch.nn.Linear):
|
||||
|
|
@ -34,8 +34,8 @@ module_types = [
|
|||
network_glora.ModuleTypeGLora(),
|
||||
network_oft.ModuleTypeOFT(),
|
||||
]
|
||||
from modules.ui import versions_html
|
||||
forge = "forge" in versions_html()
|
||||
|
||||
forge = launch_utils.git_tag()[0:2] == "f2"
|
||||
|
||||
re_digits = re.compile(r"\d+")
|
||||
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
|
||||
|
|
@ -676,7 +676,10 @@ def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
|||
|
||||
def process_network_files(names: list[str] | None = None):
|
||||
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
|
||||
candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
|
||||
try:
|
||||
candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
|
||||
except:
|
||||
pass
|
||||
for filename in candidates:
|
||||
if os.path.isdir(filename):
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -71,7 +71,6 @@ class GenParamGetter(scripts.Script):
|
|||
return components
|
||||
|
||||
def compare_components_with_ids(components: list[gr.Blocks], ids: list[int]):
|
||||
|
||||
try:
|
||||
return len(components) == len(ids) and all(component._id == _id for component, _id in zip(components, ids))
|
||||
except:
|
||||
|
|
@ -80,8 +79,10 @@ class GenParamGetter(scripts.Script):
|
|||
def get_params_components(demo: gr.Blocks, app):
|
||||
for _id, _is_txt2img in zip([GenParamGetter.txt2img_gen_button._id, GenParamGetter.img2img_gen_button._id], [True, False]):
|
||||
if hasattr(demo,"dependencies"):
|
||||
dependencies: list[dict] = [x for x in demo.dependencies if x["trigger"] == "click" and _id in x["targets"]]
|
||||
#dependencies: list[dict] = [x for x in demo.dependencies if x["trigger"] == "click" and _id in x["targets"]]
|
||||
dependencies: list[dict] = [x for x in demo.dependencies if _id in x["targets"]]
|
||||
g4 = False
|
||||
|
||||
else:
|
||||
dependencies: list[dict] = [x for x in demo.config["dependencies"] if x["targets"][0][1] == "click" and _id in x["targets"][0]]
|
||||
g4 = True
|
||||
|
|
@ -91,6 +92,7 @@ class GenParamGetter(scripts.Script):
|
|||
for d in dependencies:
|
||||
if len(d["outputs"]) == 4:
|
||||
dependency = d
|
||||
print("GenParamsGetter detected!")
|
||||
|
||||
if g4:
|
||||
params = [demo.blocks[x] for x in dependency['inputs']]
|
||||
|
|
@ -102,6 +104,8 @@ class GenParamGetter(scripts.Script):
|
|||
else:
|
||||
components.img2img_params = params
|
||||
else:
|
||||
if dependency is None:continue
|
||||
|
||||
params = [params for params in demo.fns if GenParamGetter.compare_components_with_ids(params.inputs, dependency["inputs"])]
|
||||
|
||||
if _is_txt2img:
|
||||
|
|
@ -111,7 +115,6 @@ class GenParamGetter(scripts.Script):
|
|||
components.txt2img_params = params[0].inputs
|
||||
else:
|
||||
components.img2img_params = params[0].inputs
|
||||
|
||||
|
||||
if not GenParamGetter.events_assigned:
|
||||
with demo:
|
||||
|
|
|
|||
|
|
@ -1187,7 +1187,7 @@ def simggen(s_prompt,s_nprompt,s_steps,s_sampler,s_cfg,s_seed,s_w,s_h,s_batch_si
|
|||
elif wantedv and wantedv in paramsnames:return txt2imgparams[paramsnames.index(wantedv)]
|
||||
else:return None
|
||||
|
||||
sampler_index = g("Sampling method")
|
||||
sampler_index = g("Sampling method","Sampling Method")
|
||||
if type(sampler_index) is str:
|
||||
sampler_name = sampler_index
|
||||
else:
|
||||
|
|
@ -1206,7 +1206,7 @@ def simggen(s_prompt,s_nprompt,s_steps,s_sampler,s_cfg,s_seed,s_w,s_h,s_batch_si
|
|||
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
|
||||
prompt=g("Prompt"),
|
||||
styles=g("Styles"),
|
||||
negative_prompt=g('Negative prompt'),
|
||||
negative_prompt=g('Negative prompt','Negative Prompt'),
|
||||
seed=g("Seed","Initial seed"),
|
||||
subseed=g("Variation seed"),
|
||||
subseed_strength=g("Variation strength"),
|
||||
|
|
@ -1214,9 +1214,9 @@ def simggen(s_prompt,s_nprompt,s_steps,s_sampler,s_cfg,s_seed,s_w,s_h,s_batch_si
|
|||
seed_resize_from_w=g("Resize seed from width"),
|
||||
seed_enable_extras=g("Extra"),
|
||||
sampler_name=sampler_name,
|
||||
batch_size=g("Batch size"),
|
||||
n_iter=g("Batch count"),
|
||||
steps=g("Sampling steps"),
|
||||
batch_size=g("Batch size","Batch Size"),
|
||||
n_iter=g("Batch count","Batch Count"),
|
||||
steps=g("Sampling steps","Sampling Steps"),
|
||||
cfg_scale=g("CFG Scale"),
|
||||
width=g("Width"),
|
||||
height=g("Height"),
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import numpy as np
|
|||
import safetensors.torch
|
||||
import scripts.mergers.components as components
|
||||
import torch
|
||||
from modules import extra_networks, scripts, sd_models, lowvram
|
||||
from modules import extra_networks, scripts, sd_models, launch_utils
|
||||
from modules.ui import create_refresh_button
|
||||
from safetensors.torch import load_file, save_file
|
||||
from scripts.kohyas import extract_lora_from_models as ext
|
||||
|
|
@ -23,9 +23,8 @@ from scripts.A1111 import networks as nets
|
|||
from scripts.mergers.model_util import filenamecutter, savemodel
|
||||
from scripts.mergers.mergers import extract_super, unload_forge, q_dequantize, q_quantize, qdtyper, prefixer, BLOCKIDFLUX
|
||||
from tqdm import tqdm
|
||||
from modules.ui import versions_html
|
||||
|
||||
forge = "forge" in versions_html()
|
||||
forge = launch_utils.git_tag()[0:2] == "f2"
|
||||
|
||||
selectable = []
|
||||
pchanged = False
|
||||
|
|
@ -314,7 +313,7 @@ def makelora(model_a,model_b,dim,saveto,settings,alpha,beta,save_precision,calc_
|
|||
except:
|
||||
currentinfo = None
|
||||
|
||||
lowvram.module_in_gpu = None #web-uiのバグ対策
|
||||
lowvramdealer() #web-uiのバグ対策
|
||||
|
||||
checkpoint_info = sd_models.get_closet_checkpoint_match(model_a)
|
||||
load_model(checkpoint_info)
|
||||
|
|
@ -809,7 +808,7 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
|
|||
elif "model_" in skey:
|
||||
keychanger[skey.split("model_",1)[1]] = key
|
||||
|
||||
lowvram.module_in_gpu = None #web-uiのバグ対策
|
||||
lowvramdealer() #web-uiのバグ対策
|
||||
|
||||
if is15:
|
||||
if shared.sd_model is not None:
|
||||
|
|
@ -1668,6 +1667,13 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
|
|||
|
||||
return key
|
||||
|
||||
def lowvramdealer():
|
||||
try:
|
||||
from modules import lowvram
|
||||
lowvram.module_in_gpu = None #web-uiのバグ対策
|
||||
except:
|
||||
pass
|
||||
|
||||
def get_flux_blocks(key):
|
||||
if "vae" in key:
|
||||
return "VAE"
|
||||
|
|
|
|||
Loading…
Reference in New Issue