forge classic

pull/464/head
hako-mikan 2025-06-22 00:25:16 +09:00
parent 352c0047ec
commit 8ba24888f1
5 changed files with 30 additions and 17 deletions

View File

@ -36,6 +36,7 @@ def is_installed(pip_package):
requirements = [
"diffusers==0.31.0",
"scikit-learn",
"accelerate"
]
for module in requirements:

View File

@ -18,7 +18,7 @@ import scripts.A1111.network_oft as network_oft
import torch
from typing import Union
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
from modules import shared, devices, sd_models, errors, scripts, sd_hijack, launch_utils
import modules.textual_inversion.textual_inversion as textual_inversion
class QkvLinear(torch.nn.Linear):
@ -34,8 +34,8 @@ module_types = [
network_glora.ModuleTypeGLora(),
network_oft.ModuleTypeOFT(),
]
from modules.ui import versions_html
forge = "forge" in versions_html()
forge = launch_utils.git_tag()[0:2] == "f2"
re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
@ -676,7 +676,10 @@ def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
def process_network_files(names: list[str] | None = None):
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
try:
candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
except:
pass
for filename in candidates:
if os.path.isdir(filename):
continue

View File

@ -71,7 +71,6 @@ class GenParamGetter(scripts.Script):
return components
def compare_components_with_ids(components: list[gr.Blocks], ids: list[int]):
try:
return len(components) == len(ids) and all(component._id == _id for component, _id in zip(components, ids))
except:
@ -80,8 +79,10 @@ class GenParamGetter(scripts.Script):
def get_params_components(demo: gr.Blocks, app):
for _id, _is_txt2img in zip([GenParamGetter.txt2img_gen_button._id, GenParamGetter.img2img_gen_button._id], [True, False]):
if hasattr(demo,"dependencies"):
dependencies: list[dict] = [x for x in demo.dependencies if x["trigger"] == "click" and _id in x["targets"]]
#dependencies: list[dict] = [x for x in demo.dependencies if x["trigger"] == "click" and _id in x["targets"]]
dependencies: list[dict] = [x for x in demo.dependencies if _id in x["targets"]]
g4 = False
else:
dependencies: list[dict] = [x for x in demo.config["dependencies"] if x["targets"][0][1] == "click" and _id in x["targets"][0]]
g4 = True
@ -91,6 +92,7 @@ class GenParamGetter(scripts.Script):
for d in dependencies:
if len(d["outputs"]) == 4:
dependency = d
print("GenParamsGetter detected!")
if g4:
params = [demo.blocks[x] for x in dependency['inputs']]
@ -102,6 +104,8 @@ class GenParamGetter(scripts.Script):
else:
components.img2img_params = params
else:
if dependency is None:continue
params = [params for params in demo.fns if GenParamGetter.compare_components_with_ids(params.inputs, dependency["inputs"])]
if _is_txt2img:
@ -111,7 +115,6 @@ class GenParamGetter(scripts.Script):
components.txt2img_params = params[0].inputs
else:
components.img2img_params = params[0].inputs
if not GenParamGetter.events_assigned:
with demo:

View File

@ -1187,7 +1187,7 @@ def simggen(s_prompt,s_nprompt,s_steps,s_sampler,s_cfg,s_seed,s_w,s_h,s_batch_si
elif wantedv and wantedv in paramsnames:return txt2imgparams[paramsnames.index(wantedv)]
else:return None
sampler_index = g("Sampling method")
sampler_index = g("Sampling method","Sampling Method")
if type(sampler_index) is str:
sampler_name = sampler_index
else:
@ -1206,7 +1206,7 @@ def simggen(s_prompt,s_nprompt,s_steps,s_sampler,s_cfg,s_seed,s_w,s_h,s_batch_si
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
prompt=g("Prompt"),
styles=g("Styles"),
negative_prompt=g('Negative prompt'),
negative_prompt=g('Negative prompt','Negative Prompt'),
seed=g("Seed","Initial seed"),
subseed=g("Variation seed"),
subseed_strength=g("Variation strength"),
@ -1214,9 +1214,9 @@ def simggen(s_prompt,s_nprompt,s_steps,s_sampler,s_cfg,s_seed,s_w,s_h,s_batch_si
seed_resize_from_w=g("Resize seed from width"),
seed_enable_extras=g("Extra"),
sampler_name=sampler_name,
batch_size=g("Batch size"),
n_iter=g("Batch count"),
steps=g("Sampling steps"),
batch_size=g("Batch size","Batch Size"),
n_iter=g("Batch count","Batch Count"),
steps=g("Sampling steps","Sampling Steps"),
cfg_scale=g("CFG Scale"),
width=g("Width"),
height=g("Height"),

View File

@ -15,7 +15,7 @@ import numpy as np
import safetensors.torch
import scripts.mergers.components as components
import torch
from modules import extra_networks, scripts, sd_models, lowvram
from modules import extra_networks, scripts, sd_models, launch_utils
from modules.ui import create_refresh_button
from safetensors.torch import load_file, save_file
from scripts.kohyas import extract_lora_from_models as ext
@ -23,9 +23,8 @@ from scripts.A1111 import networks as nets
from scripts.mergers.model_util import filenamecutter, savemodel
from scripts.mergers.mergers import extract_super, unload_forge, q_dequantize, q_quantize, qdtyper, prefixer, BLOCKIDFLUX
from tqdm import tqdm
from modules.ui import versions_html
forge = "forge" in versions_html()
forge = launch_utils.git_tag()[0:2] == "f2"
selectable = []
pchanged = False
@ -314,7 +313,7 @@ def makelora(model_a,model_b,dim,saveto,settings,alpha,beta,save_precision,calc_
except:
currentinfo = None
lowvram.module_in_gpu = None #web-uiのバグ対策
lowvramdealer() #web-uiのバグ対策
checkpoint_info = sd_models.get_closet_checkpoint_match(model_a)
load_model(checkpoint_info)
@ -809,7 +808,7 @@ def pluslora(lnames,loraratios,settings,output,model,save_precision,calc_precisi
elif "model_" in skey:
keychanger[skey.split("model_",1)[1]] = key
lowvram.module_in_gpu = None #web-uiのバグ対策
lowvramdealer() #web-uiのバグ対策
if is15:
if shared.sd_model is not None:
@ -1668,6 +1667,13 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
return key
def lowvramdealer():
try:
from modules import lowvram
lowvram.module_in_gpu = None #web-uiのバグ対策
except:
pass
def get_flux_blocks(key):
if "vae" in key:
return "VAE"