diff --git a/README.md b/README.md index ee47b5f..998c0c2 100644 --- a/README.md +++ b/README.md @@ -12,8 +12,8 @@ You can also install it manually by running the following command from within th git clone https://github.com/AlUlkesh/sd_save_intermediate_images/ extensions/sd_save_intermediate_images -## Limitations -Does not work with DDIM, PLMS or UNIPC +## Samplers +Works with all a1111 samplers ## Output diff --git a/scripts/sd_save_intermediate_images.py b/scripts/sd_save_intermediate_images.py index 63062f4..89fea1b 100644 --- a/scripts/sd_save_intermediate_images.py +++ b/scripts/sd_save_intermediate_images.py @@ -12,11 +12,9 @@ from modules import paths from modules import scripts from modules import script_callbacks from modules.processing import Processed, process_images, fix_seed, create_infotext -try: - from modules.sd_samplers_kdiffusion import KDiffusionSampler - from modules.sd_samplers_common import sample_to_image -except ImportError: - from modules.sd_samplers import KDiffusionSampler, sample_to_image +from modules.sd_samplers_kdiffusion import KDiffusionSampler +from modules.sd_samplers_timesteps import CompVisSampler as TimestepsSampler, samplers_timesteps +from modules.sd_samplers_common import sample_to_image from modules.images import save_image, FilenameGenerator, get_next_sequence_number from modules.shared import opts, state, cmd_opts @@ -27,7 +25,9 @@ import gradio as gr; gr.__version__ # replace: \1, ssii_add_last_frames, ssii_add_first_frames # plus debug -orig_callback_state = KDiffusionSampler.callback_state +orig_callback_state_KDiffusionSampler = KDiffusionSampler.callback_state +orig_callback_state_TimestepsSampler = TimestepsSampler.callback_state +orig_callback_state = None ui_config_backup = os.path.join(scripts.basedir(), "ui-config_backup.json") video_bat_mode = "" ui_items = { @@ -368,6 +368,12 @@ def hr_active_check(p): hr_active = False return hr_active +def is_TimestepsSampler(sampler_name): + for sampler in samplers_timesteps: + if sampler[0] == sampler_name: + return True + return False + class Script(scripts.Script): def title(self): return "Save intermediate images during the sampling process" @@ -580,6 +586,12 @@ class Script(scripts.Script): """ callback_state runs after each processing step """ + logger.debug(f"sampler_name: {p.sampler_name}") + if is_TimestepsSampler(p.sampler_name): + orig_callback_state = orig_callback_state_TimestepsSampler + else: + orig_callback_state = orig_callback_state_KDiffusionSampler + current_step = d["i"] hr = hr_check(p) @@ -652,9 +664,15 @@ class Script(scripts.Script): if ssii_intermediate_type == "According to Live preview subject setting" and index == 0: image = state.current_image elif ssii_intermediate_type == "Noisy": - image = sample_to_image(d["x"], index=index) + if d["x"] is None: + image = state.current_image + else: + image = sample_to_image(d["x"], index=index) else: - image = sample_to_image(d["denoised"], index=index) + if d["denoised"] is None: + image = state.current_image + else: + image = sample_to_image(d["denoised"], index=index) logger.debug(f"ssii_intermediate_type, ssii_every_n, ssii_start_at_n, ssii_stop_at_n: {ssii_intermediate_type}, {ssii_every_n}, {ssii_start_at_n}, {ssii_stop_at_n}") logger.debug(f"Step, abs_step, hr, hr_active: {current_step}, {abs_step}, {hr}, {hr_active}") @@ -753,10 +771,12 @@ class Script(scripts.Script): return orig_callback_state(self, d) setattr(KDiffusionSampler, "callback_state", callback_state) + setattr(TimestepsSampler, "callback_state", callback_state) def postprocess(self, p, processed, ssii_is_active, ssii_final_save, ssii_intermediate_type, ssii_every_n, ssii_start_at_n, ssii_stop_at_n, ssii_mode, ssii_video_format, ssii_mp4_parms, ssii_video_fps, ssii_add_first_frames, ssii_add_last_frames, ssii_smooth, ssii_seconds, ssii_lores, ssii_hires, ssii_ffmpeg_bat, ssii_bat_only, ssii_debug): logger.debug(f"func: {sys._getframe(0).f_code.co_name}") - setattr(KDiffusionSampler, "callback_state", orig_callback_state) + setattr(KDiffusionSampler, "callback_state", orig_callback_state_KDiffusionSampler) + setattr(TimestepsSampler, "callback_state", orig_callback_state_TimestepsSampler) # Make video for last batch_count make_video(p, ssii_is_active, ssii_final_save, ssii_intermediate_type, ssii_every_n, ssii_start_at_n, ssii_stop_at_n, ssii_mode, ssii_video_format, ssii_mp4_parms, ssii_video_fps, ssii_add_first_frames, ssii_add_last_frames, ssii_smooth, ssii_seconds, ssii_lores, ssii_hires, ssii_ffmpeg_bat, ssii_bat_only, ssii_debug)