# automatic/scripts/regional_prompting.py
# 68 lines, 3.3 KiB, Python

# https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#regional-prompting-pipeline
# https://github.com/huggingface/diffusers/blob/main/examples/community/regional_prompting_stable_diffusion.py
import gradio as gr
from modules import shared, devices, scripts, processing, sd_models
class Script(scripts.Script):
    """Regional prompting via the diffusers community pipeline.

    For the duration of a single `run` call the loaded model is swapped for the
    `regional_prompting_stable_diffusion` community pipeline and the
    `prompt_attention` option is forced to 'Fixed attention'; both (and the
    model dtype) are restored afterwards, even if generation fails.
    """

    def title(self):
        """Return the display name of this script."""
        return 'Regional prompting'

    def show(self, is_img2img): # pylint: disable=unused-argument
        """Return whether the script is listed in the UI.

        The script is currently disabled, so this always returns False.
        NOTE(review): confirm whether it is meant to stay disabled; the condition
        it would use once re-enabled is kept as a comment below.
        """
        return False
        # return not is_img2img if shared.backend == shared.Backend.DIFFUSERS else False

    def change(self, mode):
        """Toggle visibility of the mode-specific inputs.

        Returns gradio updates for ``[grid, threshold]`` (the output order wired
        up in ``ui``): grid sections are visible for Columns/Rows modes, prompt
        thresholds for Prompt/Prompt EX modes.
        """
        return [gr.update(visible='Col' in mode or 'Row' in mode), gr.update(visible='Prompt' in mode)]

    def ui(self, _is_img2img):
        """Build the script UI.

        Returns:
            (mode, grid, power, threshold) components, in the order `run` receives them.
        """
        with gr.Row():
            gr.HTML('<a href="https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#regional-prompting-pipeline">&nbsp Regional prompting</a>')
        with gr.Row():
            mode = gr.Radio(label='Mode', choices=['None', 'Prompt', 'Prompt EX', 'Columns', 'Rows'], value='None')
        with gr.Row():
            power = gr.Slider(label='Power', minimum=0, maximum=1, value=1.0, step=0.01)
            # hidden until a mode that uses them is selected, see `change`
            threshold = gr.Textbox('', label='Prompt thresholds:', visible=False)
            grid = gr.Text('', label='Grid sections:', visible=False)
        mode.change(fn=self.change, inputs=[mode], outputs=[grid, threshold])
        return mode, grid, power, threshold

    def run(self, p: processing.StableDiffusionProcessing, mode, grid, power, threshold): # pylint: disable=arguments-differ
        """Generate images with the regional prompting pipeline.

        Args:
            p: processing job; its `task_args` gain an `rp_args` entry.
            mode: one of the `ui` radio choices; 'None' (or None) disables the script.
            grid: grid sections text, used by the Columns/Rows modes.
            power: value passed to the pipeline as `rp_args['power']`.
            threshold: prompt thresholds text, used by the Prompt modes.

        Returns:
            The `Processed` result, or None if the script is inactive or the
            pipeline could not be set up.
        """
        if mode is None or mode == 'None':
            return None
        # the community pipeline is only attempted for 'sd' model types
        if shared.sd_model_type != 'sd':
            shared.log.error(f'Regional prompting: incorrect base model: {shared.sd_model.__class__.__name__}')
            return None
        # backup pipeline and params so they can be restored after processing
        orig_pipeline = shared.sd_model
        orig_dtype = devices.dtype
        orig_prompt_attention = shared.opts.prompt_attention
        # create pipeline
        shared.sd_model = sd_models.switch_pipe('regional_prompting_stable_diffusion', shared.sd_model)
        if shared.sd_model.__class__.__name__ != 'RegionalPromptingStableDiffusionPipeline': # switch failed
            shared.log.error(f'Regional prompting: not a regional prompting pipeline: {shared.sd_model.__class__.__name__}')
            shared.sd_model = orig_pipeline
            return None
        try:
            sd_models.set_diffuser_options(shared.sd_model)
            shared.opts.data['prompt_attention'] = 'Fixed attention' # this pipeline is not compatible with embeds
            processing.fix_seed(p)
            # set pipeline specific params, note that standard params are applied when applicable
            mode_lower = mode.lower()
            rp_args = {
                'mode': mode_lower,
                'power': power,
            }
            if 'prompt' in mode_lower:
                rp_args['th'] = threshold
            else:
                rp_args['div'] = grid
            p.task_args = { **p.task_args, 'rp_args': rp_args }
            # run pipeline
            shared.log.debug(f'Regional: args={p.task_args}')
            processed: processing.Processed = processing.process_images(p) # runs processing using main loop
        finally:
            # restore pipeline and params even if setup or processing raised,
            # otherwise the swapped pipeline/option would leak into later jobs
            shared.opts.data['prompt_attention'] = orig_prompt_attention
            shared.sd_model = orig_pipeline
            shared.sd_model.to(orig_dtype)
        return processed