import dataclasses
import functools
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F

from lib_neutral_prompt import global_state, hijacker, neutral_prompt_parser, prompt_parser_hijack, cfg_denoiser_hijack, ui, xyz_grid
from modules import scripts, processing, shared, script_callbacks


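# Entry point of the extension: an always-on script that renders the accordion
# UI, mirrors its settings into global_state, and installs the hijacks below.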
class NeutralPromptScript(scripts.Script):
    def __init__(self):
        self.accordion_interface = None
        self._is_img2img = False

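    # The webui assigns is_img2img on script instances; intercept the write so
    # the accordion interface is created lazily, exactly once.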
    @property
    def is_img2img(self):
        return self._is_img2img

    @is_img2img.setter
    def is_img2img(self, is_img2img):
        self._is_img2img = is_img2img
        if self.accordion_interface is None:
            self.accordion_interface = ui.AccordionInterface(self.elem_id)

    def title(self) -> str:
        return "Neutral Prompt"

    def show(self, is_img2img: bool):
        return scripts.AlwaysVisible

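    # Composable Lora must be hijacked before set_rendered() is called below,
    # since hijack_composable_lora() is a no-op once the UI counts as rendered.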
    def ui(self, is_img2img: bool):
        self.hijack_composable_lora(is_img2img)

        self.accordion_interface.arrange_components(is_img2img)
        self.accordion_interface.connect_events(is_img2img)
        self.infotext_fields = self.accordion_interface.get_infotext_fields()
        self.paste_field_names = self.accordion_interface.get_paste_field_names()
        self.accordion_interface.set_rendered()
        return self.accordion_interface.get_components()

    def process(self, p: processing.StableDiffusionProcessing, *args):
        args = self.accordion_interface.unpack_processing_args(*args)

        self.update_global_state(args)
        if global_state.is_enabled:
            p.extra_generation_params.update(self.accordion_interface.get_extra_generation_params(args))

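        # reset the module-level sampling step counter at the start of each job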
        global sampling_step
        sampling_step = 0

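    # Mirror UI args into global_state. Values injected by the XYZ grid
    # (tagged with is_xyz) take precedence and are consumed here; plain UI
    # values are only written on the first job of a batch.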
    def update_global_state(self, args: Dict):
        if shared.state.job_no == 0:
            global_state.is_enabled = shared.opts.data.get('neutral_prompt_enabled', True)

        for k, v in args.items():
            if not hasattr(global_state, k):
                continue

            current = getattr(global_state, k)
            if getattr(current, 'is_xyz', False):
                current.is_xyz = False
                args[k] = current
                continue

            if shared.state.job_no > 0:
                continue

            setattr(global_state, k, v)

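    # Composable Lora parses prompts itself, so wrap its process() to hand it
    # transpiled standard prompts instead of neutral prompt syntax.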
    def hijack_composable_lora(self, is_img2img):
        if self.accordion_interface.is_rendered:
            return

        lora_script = None
        script_runner = scripts.scripts_img2img if is_img2img else scripts.scripts_txt2img

        for script in script_runner.alwayson_scripts:
            if script.title().lower() == "composable lora":
                lora_script = script
                break

        if lora_script is not None:
            lora_script.process = functools.partial(composable_lora_process_hijack, original_function=lora_script.process)


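# Wrapper installed over Composable Lora's process(): temporarily swap in the
# transpiled prompts, call through, then restore the originals.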
def composable_lora_process_hijack(p: processing.StableDiffusionProcessing, *args, original_function, **kwargs):
    if not global_state.is_enabled:
        return original_function(p, *args, **kwargs)

    exprs = prompt_parser_hijack.parse_prompts(p.all_prompts)
    all_prompts, p.all_prompts = p.all_prompts, prompt_parser_hijack.transpile_exprs(exprs)
    res = original_function(p, *args, **kwargs)
    # restore original prompts
    p.all_prompts = all_prompts
    return res


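# hook this extension's options into the built-in X/Y/Z grid script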
xyz_grid.patch()


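# Payload passed to the prompt AST visitors: the denoiser's latent batch and
# the (cond index, weight) pairs belonging to one prompt.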
@dataclasses.dataclass
class CombinePreNoiseArgs:
    x_out: torch.Tensor
    cond_indices: List[Tuple[int, float]]


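# Before each denoiser call, warp every cond latent from global canvas space
# into its prompt's local space using the inverse affine transforms gathered
# by GlobalToLocalAffineVisitor.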
def on_cfg_denoiser(params: script_callbacks.CFGDenoiserParams):
    if not global_state.is_enabled:
        return

    for prompt, cond_indices in zip(global_state.prompt_exprs, global_state.batch_cond_indices):
        args = CombinePreNoiseArgs(params.x, cond_indices)
        inv_transforms = prompt.accept(GlobalToLocalAffineVisitor(), args, 0)
        for cond_index, _ in cond_indices:
            params.x[cond_index] = apply_affine_transform(params.x[cond_index], inv_transforms[cond_index])


script_callbacks.on_cfg_denoiser(on_cfg_denoiser)


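# Walks a prompt AST and returns, per cond index, the 2x3 affine matrix that
# maps global canvas coordinates into that leaf's local coordinates, composing
# nested composite transforms along the way.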
class GlobalToLocalAffineVisitor:
    def visit_leaf_prompt(
        self,
        that: neutral_prompt_parser.LeafPrompt,
        args: CombinePreNoiseArgs,
        index: int,
    ) -> Dict[int, torch.Tensor]:
        cond_index = args.cond_indices[index][0]
        if that.local_transform is not None:
            transform = torch.linalg.inv(torch.vstack([that.local_transform, torch.tensor([0.0, 0.0, 1.0])]))[:-1]
        else:
            transform = torch.eye(3)[:-1]
        return {cond_index: transform}

    def visit_composite_prompt(
        self,
        that: neutral_prompt_parser.CompositePrompt,
        args: CombinePreNoiseArgs,
        index: int,
    ) -> Dict[int, torch.Tensor]:
        inv_transforms = {}

        for child in that.children:
            inv_transforms.update(child.accept(GlobalToLocalAffineVisitor(), args, index))
            index += child.accept(neutral_prompt_parser.FlatSizeVisitor())

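        # Fold this node's inverse transform into each child's 2x3 matrix.
        # The children are lifted to homogeneous 3x3 form before multiplying;
        # a direct 3x3 @ 2x3 product would raise a shape mismatch.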
        if that.local_transform is not None:
            that_inv_transform = torch.linalg.inv(torch.vstack([that.local_transform, torch.tensor([0.0, 0.0, 1.0])]))
            for inv_transform in inv_transforms.values():
                inv_transform[:] = (that_inv_transform @ torch.vstack([inv_transform, torch.tensor([0.0, 0.0, 1.0])]))[:-1]

        return inv_transforms


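# Resample a (C, H, W) tensor with a 2x3 affine in F.affine_grid's normalized
# [-1, 1] coordinate convention. The off-diagonal terms are rescaled by the
# aspect ratio so rotations stay angle-true on non-square tensors.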
def apply_affine_transform(tensor, affine):
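    # match the input's device/dtype and copy, so the in-place aspect-ratio
    # adjustment below does not mutate the caller's matrix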
    affine = affine.to(device=tensor.device, dtype=tensor.dtype).clone()
    aspect_ratio = tensor.shape[-2] / tensor.shape[-1]
    affine[0, 1] *= aspect_ratio
    affine[1, 0] /= aspect_ratio

    grid = F.affine_grid(affine.unsqueeze(0), tensor.unsqueeze(0).size(), align_corners=False)
    transformed_tensors = F.grid_sample(tensor.unsqueeze(0), grid, align_corners=False)
    return transformed_tensors.squeeze(0)