Support remaining samplers, #39

main
AlUlkesh 2023-10-09 22:15:59 +02:00
parent faf17b44df
commit fba7d9e9de
2 changed files with 31 additions and 11 deletions

View File

@ -12,8 +12,8 @@ You can also install it manually by running the following command from within th
git clone https://github.com/AlUlkesh/sd_save_intermediate_images/ extensions/sd_save_intermediate_images
## Limitations
Does not work with DDIM, PLMS or UNIPC
## Samplers
Works with all a1111 samplers
## Output

View File

@ -12,11 +12,9 @@ from modules import paths
from modules import scripts
from modules import script_callbacks
from modules.processing import Processed, process_images, fix_seed, create_infotext
try:
from modules.sd_samplers_kdiffusion import KDiffusionSampler
from modules.sd_samplers_common import sample_to_image
except ImportError:
from modules.sd_samplers import KDiffusionSampler, sample_to_image
from modules.sd_samplers_kdiffusion import KDiffusionSampler
from modules.sd_samplers_timesteps import CompVisSampler as TimestepsSampler, samplers_timesteps
from modules.sd_samplers_common import sample_to_image
from modules.images import save_image, FilenameGenerator, get_next_sequence_number
from modules.shared import opts, state, cmd_opts
@ -27,7 +25,9 @@ import gradio as gr; gr.__version__
# replace: \1, ssii_add_last_frames, ssii_add_first_frames
# plus debug
orig_callback_state = KDiffusionSampler.callback_state
orig_callback_state_KDiffusionSampler = KDiffusionSampler.callback_state
orig_callback_state_TimestepsSampler = TimestepsSampler.callback_state
orig_callback_state = None
ui_config_backup = os.path.join(scripts.basedir(), "ui-config_backup.json")
video_bat_mode = ""
ui_items = {
@ -368,6 +368,12 @@ def hr_active_check(p):
hr_active = False
return hr_active
def is_TimestepsSampler(sampler_name):
for sampler in samplers_timesteps:
if sampler[0] == sampler_name:
return True
return False
class Script(scripts.Script):
def title(self):
return "Save intermediate images during the sampling process"
@ -580,6 +586,12 @@ class Script(scripts.Script):
"""
callback_state runs after each processing step
"""
logger.debug(f"sampler_name: {p.sampler_name}")
if is_TimestepsSampler(p.sampler_name):
orig_callback_state = orig_callback_state_TimestepsSampler
else:
orig_callback_state = orig_callback_state_KDiffusionSampler
current_step = d["i"]
hr = hr_check(p)
@ -652,9 +664,15 @@ class Script(scripts.Script):
if ssii_intermediate_type == "According to Live preview subject setting" and index == 0:
image = state.current_image
elif ssii_intermediate_type == "Noisy":
image = sample_to_image(d["x"], index=index)
if d["x"] is None:
image = state.current_image
else:
image = sample_to_image(d["x"], index=index)
else:
image = sample_to_image(d["denoised"], index=index)
if d["denoised"] is None:
image = state.current_image
else:
image = sample_to_image(d["denoised"], index=index)
logger.debug(f"ssii_intermediate_type, ssii_every_n, ssii_start_at_n, ssii_stop_at_n: {ssii_intermediate_type}, {ssii_every_n}, {ssii_start_at_n}, {ssii_stop_at_n}")
logger.debug(f"Step, abs_step, hr, hr_active: {current_step}, {abs_step}, {hr}, {hr_active}")
@ -753,10 +771,12 @@ class Script(scripts.Script):
return orig_callback_state(self, d)
setattr(KDiffusionSampler, "callback_state", callback_state)
setattr(TimestepsSampler, "callback_state", callback_state)
def postprocess(self, p, processed, ssii_is_active, ssii_final_save, ssii_intermediate_type, ssii_every_n, ssii_start_at_n, ssii_stop_at_n, ssii_mode, ssii_video_format, ssii_mp4_parms, ssii_video_fps, ssii_add_first_frames, ssii_add_last_frames, ssii_smooth, ssii_seconds, ssii_lores, ssii_hires, ssii_ffmpeg_bat, ssii_bat_only, ssii_debug):
logger.debug(f"func: {sys._getframe(0).f_code.co_name}")
setattr(KDiffusionSampler, "callback_state", orig_callback_state)
setattr(KDiffusionSampler, "callback_state", orig_callback_state_KDiffusionSampler)
setattr(TimestepsSampler, "callback_state", orig_callback_state_TimestepsSampler)
# Make video for last batch_count
make_video(p, ssii_is_active, ssii_final_save, ssii_intermediate_type, ssii_every_n, ssii_start_at_n, ssii_stop_at_n, ssii_mode, ssii_video_format, ssii_mp4_parms, ssii_video_fps, ssii_add_first_frames, ssii_add_last_frames, ssii_smooth, ssii_seconds, ssii_lores, ssii_hires, ssii_ffmpeg_bat, ssii_bat_only, ssii_debug)