From 09b9ae32c10267952c493b5d947c1670237eb9a1 Mon Sep 17 00:00:00 2001 From: vladmandic Date: Mon, 23 Mar 2026 08:44:07 +0100 Subject: [PATCH] add color grading to processing Signed-off-by: vladmandic --- CHANGELOG.md | 9 ++++++- TODO.md | 3 ++- extensions-builtin/sdnext-modernui | 2 +- modules/processing.py | 2 +- modules/processing_grading.py | 16 ++++++------ modules/scripts_postprocessing.py | 41 ++++++++++++++++++++---------- modules/ui_models.py | 1 - modules/ui_sections.py | 8 +++--- pipelines/flux/flux2_lora.py | 3 +-- scripts/color_grading.py | 15 +++++++++++ scripts/nudenet_ext.py | 4 +-- scripts/postprocessing_video.py | 6 ++--- test/test-grading.py | 2 +- 13 files changed, 74 insertions(+), 38 deletions(-) create mode 100644 scripts/color_grading.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 13fd02cbb..7c4cab49a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,14 @@ But also many smaller quality-of-life improvements - for full details, see [Chan - [Anima Preview-v2](https://huggingface.co/circlestone-labs/Anima) - **Image manipulation** - new **color grading** module - - update **latent corrections** *(former HDR Corrections)* and expand allowed models + apply basic corrections to your images: brightness,contrast,saturation,shadows,highlights + move to professional photo corrections: hue,gamma,sharpness,temperature + correct tone: shadows,midtones,highlights + add effects: vignette,grain + apply professional lut-table using .cube file + *hint* color grading is available as step during generate or as processing item for already existing images + - update **latent corrections** *(former HDR Corrections)* + expand allowed models - add support for [spandrel](https://github.com/chaiNNer-org/spandrel) **upscaling** engine with suport for new upscaling model families - add two new ai upscalers: *RealPLKSR NomosWebPhoto* and *RealPLKSR AnimeSharpV2* diff --git a/TODO.md b/TODO.md index 2c035e628..218e99b60 100644 --- a/TODO.md +++ b/TODO.md @@ -12,7 +12,6 @@ ## Internal - Integrate: [Depth3D](https://github.com/vladmandic/sd-extension-depth3d) -- Feature: Color grading in processing - Feature: RIFE update - Feature: RIFE in processing - Feature: SeedVR2 in processing @@ -146,3 +145,5 @@ TODO: Investigate which models are diffusers-compatible and prioritize! - modules/modular_guiders.py:65:58: W0511: TODO: guiders - processing: remove duplicate mask params - resize image: enable full VAE mode for resize-latent + +modules/sd_samplers_diffusers.py:353:31: W0511: TODO enso-required (fixme) diff --git a/extensions-builtin/sdnext-modernui b/extensions-builtin/sdnext-modernui index d3f63ee8c..0861ae00f 160000 --- a/extensions-builtin/sdnext-modernui +++ b/extensions-builtin/sdnext-modernui @@ -1 +1 @@ -Subproject commit d3f63ee8c3b6220f290e5fa54dc172a772b8c108 +Subproject commit 0861ae00f2ad057a914ca82e45fe6635dde7417e diff --git a/modules/processing.py b/modules/processing.py index 22ef90f62..e4119a9f8 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -360,7 +360,7 @@ def process_samples(p: StableDiffusionProcessing, samples): split_tone_balance=getattr(p, 'grading_split_tone_balance', 0.5), vignette=getattr(p, 'grading_vignette', 0.0), grain=getattr(p, 'grading_grain', 0.0), - lut_file=getattr(p, 'grading_lut_file', ''), + lut_cube_file=getattr(p, 'grading_lut_file', ''), lut_strength=getattr(p, 'grading_lut_strength', 1.0), ) if processing_grading.is_active(grading_params): diff --git a/modules/processing_grading.py b/modules/processing_grading.py index 003f93f07..c0fe0fe09 100644 --- a/modules/processing_grading.py +++ b/modules/processing_grading.py @@ -66,7 +66,7 @@ class GradingParams: vignette: float = 0.0 grain: float = 0.0 # lut - lut_file: str = "" + lut_cube_file: str = "" lut_strength: float = 1.0 def __post_init__(self): @@ -179,17 +179,17 @@ def _apply_color_temp(img: torch.Tensor, kelvin: float) -> torch.Tensor: return (img * scales).clamp(0, 1) -def _apply_lut(image: Image.Image, lut_file: str, strength: float) -> Image.Image: +def _apply_lut(image: Image.Image, lut_cube_file: str, strength: float) -> Image.Image: """Apply .cube LUT file via pillow-lut-tools.""" - if not lut_file or not os.path.isfile(lut_file): + if not lut_cube_file or not os.path.isfile(lut_cube_file): return image pillow_lut = _ensure_pillow_lut() try: - cube = pillow_lut.load_cube_file(lut_file) + cube = pillow_lut.load_cube_file(lut_cube_file) if strength != 1.0: cube = pillow_lut.amplify_lut(cube, strength) result = image.filter(cube) - debug(f'Grading LUT: file={os.path.basename(lut_file)} strength={strength}') + debug(f'Grading LUT: file={os.path.basename(lut_cube_file)} strength={strength}') return result except Exception as e: log.error(f'Grading LUT: {e}') @@ -198,8 +198,8 @@ def _apply_lut(image: Image.Image, lut_file: str, strength: float) -> Image.Imag def grade_image(image: Image.Image, params: GradingParams) -> Image.Image: """Full grading pipeline: PIL -> GPU tensor -> kornia ops -> PIL.""" + log.debug(f"Grading: params={params}") kornia = _ensure_kornia() - debug(f'Grading: params={params}') arr = np.array(image).astype(np.float32) / 255.0 tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) tensor = tensor.to(device=devices.device, dtype=devices.dtype) @@ -246,7 +246,7 @@ def grade_image(image: Image.Image, params: GradingParams) -> Image.Image: result = Image.fromarray(arr) # LUT applied last (CPU, via pillow-lut-tools) - if params.lut_file: - result = _apply_lut(result, params.lut_file, params.lut_strength) + if params.lut_cube_file: + result = _apply_lut(result, params.lut_cube_file, params.lut_strength) return result diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index 33726232d..70410826a 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -76,9 +76,14 @@ class ScriptPostprocessingRunner: script.controls = wrap_call(script.ui, script.filename, "ui") if script.controls is None: script.controls = {} - for control in script.controls.values(): - control.custom_script_source = os.path.basename(script.filename) - inputs += list(script.controls.values()) + if isinstance(script.controls, list) or isinstance(script.controls, tuple): + for control in script.controls: + control.custom_script_source = os.path.basename(script.filename) + inputs += script.controls + else: + for control in script.controls.values(): + control.custom_script_source = os.path.basename(script.filename) + inputs += list(script.controls.values()) script.args_to = len(inputs) def scripts_in_preferred_order(self): @@ -109,11 +114,16 @@ class ScriptPostprocessingRunner: for script in self.scripts_in_preferred_order(): jobid = shared.state.begin(script.name) script_args = args[script.args_from:script.args_to] - process_args = {} - for (name, _component), value in zip(script.controls.items(), script_args, strict=False): - process_args[name] = value - log.debug(f'Process: script="{script.name}" args={process_args}') - script.process(pp, **process_args) + process_args = [] + process_kwargs = {} + if isinstance(script.controls, list) or isinstance(script.controls, tuple): + for _control, value in zip(script.controls, script_args, strict=False): + process_args.append(value) + else: + for (name, _component), value in zip(script.controls.items(), script_args, strict=False): + process_kwargs[name] = value + log.debug(f'Process: script="{script.name}" args={process_args} kwargs={process_kwargs}') + script.process(pp, *process_args, **process_kwargs) shared.state.end(jobid) def create_args_for_run(self, scripts_args): @@ -139,9 +149,14 @@ class ScriptPostprocessingRunner: continue jobid = shared.state.begin(script.name) script_args = args[script.args_from:script.args_to] - process_args = {} - for (name, _component), value in zip(script.controls.items(), script_args, strict=False): - process_args[name] = value - log.debug(f'Postprocess: script={script.name} args={process_args}') - script.postprocess(filenames, **process_args) + process_args = [] + process_kwargs = {} + if isinstance(script.controls, list) or isinstance(script.controls, tuple): + for _control, value in zip(script.controls, script_args, strict=False): + process_args.append(value) + else: + for (name, _component), value in zip(script.controls.items(), script_args, strict=False): + process_kwargs[name] = value + log.debug(f'Postprocess: script={script.name} args={process_args} kwargs={process_kwargs}') + script.postprocess(filenames, *process_args, **process_kwargs) shared.state.end(jobid) diff --git a/modules/ui_models.py b/modules/ui_models.py index 2edfd0ade..d512cfe9c 100644 --- a/modules/ui_models.py +++ b/modules/ui_models.py @@ -17,7 +17,6 @@ def update_model_hashes(): unets = {} for k, v in sd_unet.unet_dict.items(): unets[k] = sd_checkpoint.CheckpointInfo(name=k, filename=v, model_type='unet') - print('HERE3', unets[k]) yield from sd_models.update_model_hashes(unets, model_type='unet') yield from sd_models.update_model_hashes(model_type='checkpoint') diff --git a/modules/ui_sections.py b/modules/ui_sections.py index 4a6b2d05a..d8cb25127 100644 --- a/modules/ui_sections.py +++ b/modules/ui_sections.py @@ -197,7 +197,7 @@ def create_color_inputs(tab): grading_hue = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0, label='Hue', elem_id=f"{tab}_grading_hue") grading_gamma = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label='Gamma', elem_id=f"{tab}_grading_gamma") grading_sharpness = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=0, label='Sharpness', elem_id=f"{tab}_grading_sharpness") - grading_color_temp = gr.Slider(minimum=2000, maximum=12000, step=100, value=6500, label='Color temp (K)', elem_id=f"{tab}_grading_color_temp") + grading_color_temp = gr.Slider(minimum=2000, maximum=12000, step=100, value=6500, label='Color temp', elem_id=f"{tab}_grading_color_temp") with gr.Group(): gr.HTML('

Tone

') with gr.Row(elem_id=f"{tab}_grading_tone_row"): @@ -211,7 +211,7 @@ def create_color_inputs(tab): with gr.Row(elem_id=f"{tab}_grading_split_row"): grading_shadows_tint = gr.ColorPicker(label="Shadows tint", value="#000000", elem_id=f"{tab}_grading_shadows_tint") grading_highlights_tint = gr.ColorPicker(label="Highlights tint", value="#ffffff", elem_id=f"{tab}_grading_highlights_tint") - grading_split_tone_balance = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label='Balance', elem_id=f"{tab}_grading_split_tone_balance") + grading_split_tone_balance = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label='Split tone balance', elem_id=f"{tab}_grading_split_tone_balance") with gr.Group(): gr.HTML('

Effects

') with gr.Row(elem_id=f"{tab}_grading_effects_row"): @@ -220,9 +220,9 @@ def create_color_inputs(tab): with gr.Group(): gr.HTML('

LUT

') with gr.Row(elem_id=f"{tab}_grading_lut_row"): - grading_lut_file = gr.File(label='LUT .cube file', file_types=['.cube'], elem_id=f"{tab}_grading_lut_file") + grading_lut_cube_file = gr.File(label='LUT .cube file', file_types=['.cube'], elem_id=f"{tab}_grading_lut_file") grading_lut_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=1.0, label='LUT strength', elem_id=f"{tab}_grading_lut_strength") - return grading_brightness, grading_contrast, grading_saturation, grading_hue, grading_gamma, grading_sharpness, grading_color_temp, grading_shadows, grading_midtones, grading_highlights, grading_clahe_clip, grading_clahe_grid, grading_shadows_tint, grading_highlights_tint, grading_split_tone_balance, grading_vignette, grading_grain, grading_lut_file, grading_lut_strength + return grading_brightness, grading_contrast, grading_saturation, grading_hue, grading_gamma, grading_sharpness, grading_color_temp, grading_shadows, grading_midtones, grading_highlights, grading_clahe_clip, grading_clahe_grid, grading_shadows_tint, grading_highlights_tint, grading_split_tone_balance, grading_vignette, grading_grain, grading_lut_cube_file, grading_lut_strength def create_sampler_and_steps_selection(choices, tabname, default_steps:int=20): diff --git a/pipelines/flux/flux2_lora.py b/pipelines/flux/flux2_lora.py index 048d60857..a2d433992 100644 --- a/pipelines/flux/flux2_lora.py +++ b/pipelines/flux/flux2_lora.py @@ -9,7 +9,6 @@ Installed via apply_patch() during pipeline loading. import os import time -import torch from modules import shared, sd_models from modules.logger import log from modules.lora import network, network_lokr, lora_convert @@ -180,7 +179,7 @@ def apply_patch(): lora_state_dict won't detect them as AI toolkit format. This patch checks for bare keys after the original returns and adds the prefix + re-runs conversion. """ - global patched + global patched # pylint: disable=global-statement if patched: return patched = True diff --git a/scripts/color_grading.py b/scripts/color_grading.py new file mode 100644 index 000000000..7ef518662 --- /dev/null +++ b/scripts/color_grading.py @@ -0,0 +1,15 @@ +from modules import scripts_postprocessing, ui_sections, processing_grading + + +class ScriptPostprocessingColorGrading(scripts_postprocessing.ScriptPostprocessing): + name = "Color Grading" + + def ui(self): + ui_controls = ui_sections.create_color_inputs('process') + ui_controls_dict = {control.label.replace(" ", "_").replace(".", "").lower(): control for control in ui_controls} + return ui_controls_dict + + def process(self, pp: scripts_postprocessing.PostprocessedImage, *args, **kwargs): # pylint: disable=arguments-differ + grading_params = processing_grading.GradingParams(*args, **kwargs) + if processing_grading.is_active(grading_params): + pp.image = processing_grading.grade_image(pp.image, grading_params) diff --git a/scripts/nudenet_ext.py b/scripts/nudenet_ext.py index be24ea1ed..3a735ba32 100644 --- a/scripts/nudenet_ext.py +++ b/scripts/nudenet_ext.py @@ -124,7 +124,7 @@ def process( # defines script for dual-mode usage -class Script(scripts.Script): +class ScriptNudeNet(scripts.Script): # see below for all available options and callbacks # @@ -148,7 +148,7 @@ class Script(scripts.Script): # defines postprocessing script for dual-mode usage -class ScriptPostprocessing(scripts_postprocessing.ScriptPostprocessing): +class ScriptPostprocessingNudeNet(scripts_postprocessing.ScriptPostprocessing): name = 'NudeNet' order = 10000 diff --git a/scripts/postprocessing_video.py b/scripts/postprocessing_video.py index 84836abca..d855be5d5 100644 --- a/scripts/postprocessing_video.py +++ b/scripts/postprocessing_video.py @@ -2,8 +2,8 @@ import gradio as gr from modules import video, scripts_postprocessing -class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): - name = "Video" +class ScriptPostprocessingVideo(scripts_postprocessing.ScriptPostprocessing): + name = "Create Video" def ui(self): with gr.Accordion('Create video', open = False, elem_id="postprocess_video_accordion"): @@ -18,7 +18,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): ] with gr.Row(): - gr.HTML("  Video
") + gr.HTML("  Create video from generated images
") with gr.Row(): video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None', elem_id="extras_video_type") duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False, elem_id="extras_video_duration") diff --git a/test/test-grading.py b/test/test-grading.py index 2018519c6..0334faba6 100644 --- a/test/test-grading.py +++ b/test/test-grading.py @@ -95,7 +95,7 @@ def test_grading_params_defaults(): assert p.split_tone_balance == 0.5 assert p.vignette == 0.0 assert p.grain == 0.0 - assert p.lut_file == "" + assert p.lut_cube_file == "" assert p.lut_strength == 1.0 return True