add color grading to processing

Signed-off-by: vladmandic <mandic00@live.com>
pull/4703/head
vladmandic 2026-03-23 08:44:07 +01:00
parent 3261e2eeae
commit 09b9ae32c1
13 changed files with 74 additions and 38 deletions

View File

@ -27,7 +27,14 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
- [Anima Preview-v2](https://huggingface.co/circlestone-labs/Anima) - [Anima Preview-v2](https://huggingface.co/circlestone-labs/Anima)
- **Image manipulation** - **Image manipulation**
- new **color grading** module - new **color grading** module
- update **latent corrections** *(former HDR Corrections)* and expand allowed models apply basic corrections to your images: brightness,contrast,saturation,shadows,highlights
move to professional photo corrections: hue,gamma,sharpness,temperature
correct tone: shadows,midtones,highlights
add effects: vignette,grain
apply professional lut-table using .cube file
*hint* color grading is available as step during generate or as processing item for already existing images
- update **latent corrections** *(former HDR Corrections)*
expand allowed models
- add support for [spandrel](https://github.com/chaiNNer-org/spandrel) - add support for [spandrel](https://github.com/chaiNNer-org/spandrel)
**upscaling** engine with suport for new upscaling model families **upscaling** engine with suport for new upscaling model families
- add two new ai upscalers: *RealPLKSR NomosWebPhoto* and *RealPLKSR AnimeSharpV2* - add two new ai upscalers: *RealPLKSR NomosWebPhoto* and *RealPLKSR AnimeSharpV2*

View File

@ -12,7 +12,6 @@
## Internal ## Internal
- Integrate: [Depth3D](https://github.com/vladmandic/sd-extension-depth3d) - Integrate: [Depth3D](https://github.com/vladmandic/sd-extension-depth3d)
- Feature: Color grading in processing
- Feature: RIFE update - Feature: RIFE update
- Feature: RIFE in processing - Feature: RIFE in processing
- Feature: SeedVR2 in processing - Feature: SeedVR2 in processing
@ -146,3 +145,5 @@ TODO: Investigate which models are diffusers-compatible and prioritize!
- modules/modular_guiders.py:65:58: W0511: TODO: guiders - modules/modular_guiders.py:65:58: W0511: TODO: guiders
- processing: remove duplicate mask params - processing: remove duplicate mask params
- resize image: enable full VAE mode for resize-latent - resize image: enable full VAE mode for resize-latent
modules/sd_samplers_diffusers.py:353:31: W0511: TODO enso-required (fixme)

@ -1 +1 @@
Subproject commit d3f63ee8c3b6220f290e5fa54dc172a772b8c108 Subproject commit 0861ae00f2ad057a914ca82e45fe6635dde7417e

View File

@ -360,7 +360,7 @@ def process_samples(p: StableDiffusionProcessing, samples):
split_tone_balance=getattr(p, 'grading_split_tone_balance', 0.5), split_tone_balance=getattr(p, 'grading_split_tone_balance', 0.5),
vignette=getattr(p, 'grading_vignette', 0.0), vignette=getattr(p, 'grading_vignette', 0.0),
grain=getattr(p, 'grading_grain', 0.0), grain=getattr(p, 'grading_grain', 0.0),
lut_file=getattr(p, 'grading_lut_file', ''), lut_cube_file=getattr(p, 'grading_lut_file', ''),
lut_strength=getattr(p, 'grading_lut_strength', 1.0), lut_strength=getattr(p, 'grading_lut_strength', 1.0),
) )
if processing_grading.is_active(grading_params): if processing_grading.is_active(grading_params):

View File

@ -66,7 +66,7 @@ class GradingParams:
vignette: float = 0.0 vignette: float = 0.0
grain: float = 0.0 grain: float = 0.0
# lut # lut
lut_file: str = "" lut_cube_file: str = ""
lut_strength: float = 1.0 lut_strength: float = 1.0
def __post_init__(self): def __post_init__(self):
@ -179,17 +179,17 @@ def _apply_color_temp(img: torch.Tensor, kelvin: float) -> torch.Tensor:
return (img * scales).clamp(0, 1) return (img * scales).clamp(0, 1)
def _apply_lut(image: Image.Image, lut_file: str, strength: float) -> Image.Image: def _apply_lut(image: Image.Image, lut_cube_file: str, strength: float) -> Image.Image:
"""Apply .cube LUT file via pillow-lut-tools.""" """Apply .cube LUT file via pillow-lut-tools."""
if not lut_file or not os.path.isfile(lut_file): if not lut_cube_file or not os.path.isfile(lut_cube_file):
return image return image
pillow_lut = _ensure_pillow_lut() pillow_lut = _ensure_pillow_lut()
try: try:
cube = pillow_lut.load_cube_file(lut_file) cube = pillow_lut.load_cube_file(lut_cube_file)
if strength != 1.0: if strength != 1.0:
cube = pillow_lut.amplify_lut(cube, strength) cube = pillow_lut.amplify_lut(cube, strength)
result = image.filter(cube) result = image.filter(cube)
debug(f'Grading LUT: file={os.path.basename(lut_file)} strength={strength}') debug(f'Grading LUT: file={os.path.basename(lut_cube_file)} strength={strength}')
return result return result
except Exception as e: except Exception as e:
log.error(f'Grading LUT: {e}') log.error(f'Grading LUT: {e}')
@ -198,8 +198,8 @@ def _apply_lut(image: Image.Image, lut_file: str, strength: float) -> Image.Imag
def grade_image(image: Image.Image, params: GradingParams) -> Image.Image: def grade_image(image: Image.Image, params: GradingParams) -> Image.Image:
"""Full grading pipeline: PIL -> GPU tensor -> kornia ops -> PIL.""" """Full grading pipeline: PIL -> GPU tensor -> kornia ops -> PIL."""
log.debug(f"Grading: params={params}")
kornia = _ensure_kornia() kornia = _ensure_kornia()
debug(f'Grading: params={params}')
arr = np.array(image).astype(np.float32) / 255.0 arr = np.array(image).astype(np.float32) / 255.0
tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
tensor = tensor.to(device=devices.device, dtype=devices.dtype) tensor = tensor.to(device=devices.device, dtype=devices.dtype)
@ -246,7 +246,7 @@ def grade_image(image: Image.Image, params: GradingParams) -> Image.Image:
result = Image.fromarray(arr) result = Image.fromarray(arr)
# LUT applied last (CPU, via pillow-lut-tools) # LUT applied last (CPU, via pillow-lut-tools)
if params.lut_file: if params.lut_cube_file:
result = _apply_lut(result, params.lut_file, params.lut_strength) result = _apply_lut(result, params.lut_cube_file, params.lut_strength)
return result return result

View File

@ -76,9 +76,14 @@ class ScriptPostprocessingRunner:
script.controls = wrap_call(script.ui, script.filename, "ui") script.controls = wrap_call(script.ui, script.filename, "ui")
if script.controls is None: if script.controls is None:
script.controls = {} script.controls = {}
for control in script.controls.values(): if isinstance(script.controls, list) or isinstance(script.controls, tuple):
control.custom_script_source = os.path.basename(script.filename) for control in script.controls:
inputs += list(script.controls.values()) control.custom_script_source = os.path.basename(script.filename)
inputs += script.controls
else:
for control in script.controls.values():
control.custom_script_source = os.path.basename(script.filename)
inputs += list(script.controls.values())
script.args_to = len(inputs) script.args_to = len(inputs)
def scripts_in_preferred_order(self): def scripts_in_preferred_order(self):
@ -109,11 +114,16 @@ class ScriptPostprocessingRunner:
for script in self.scripts_in_preferred_order(): for script in self.scripts_in_preferred_order():
jobid = shared.state.begin(script.name) jobid = shared.state.begin(script.name)
script_args = args[script.args_from:script.args_to] script_args = args[script.args_from:script.args_to]
process_args = {} process_args = []
for (name, _component), value in zip(script.controls.items(), script_args, strict=False): process_kwargs = {}
process_args[name] = value if isinstance(script.controls, list) or isinstance(script.controls, tuple):
log.debug(f'Process: script="{script.name}" args={process_args}') for _control, value in zip(script.controls, script_args, strict=False):
script.process(pp, **process_args) process_args.append(value)
else:
for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
process_kwargs[name] = value
log.debug(f'Process: script="{script.name}" args={process_args} kwargs={process_kwargs}')
script.process(pp, *process_args, **process_kwargs)
shared.state.end(jobid) shared.state.end(jobid)
def create_args_for_run(self, scripts_args): def create_args_for_run(self, scripts_args):
@ -139,9 +149,14 @@ class ScriptPostprocessingRunner:
continue continue
jobid = shared.state.begin(script.name) jobid = shared.state.begin(script.name)
script_args = args[script.args_from:script.args_to] script_args = args[script.args_from:script.args_to]
process_args = {} process_args = []
for (name, _component), value in zip(script.controls.items(), script_args, strict=False): process_kwargs = {}
process_args[name] = value if isinstance(script.controls, list) or isinstance(script.controls, tuple):
log.debug(f'Postprocess: script={script.name} args={process_args}') for _control, value in zip(script.controls, script_args, strict=False):
script.postprocess(filenames, **process_args) process_args.append(value)
else:
for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
process_kwargs[name] = value
log.debug(f'Postprocess: script={script.name} args={process_args} kwargs={process_kwargs}')
script.postprocess(filenames, *process_args, **process_kwargs)
shared.state.end(jobid) shared.state.end(jobid)

View File

@ -17,7 +17,6 @@ def update_model_hashes():
unets = {} unets = {}
for k, v in sd_unet.unet_dict.items(): for k, v in sd_unet.unet_dict.items():
unets[k] = sd_checkpoint.CheckpointInfo(name=k, filename=v, model_type='unet') unets[k] = sd_checkpoint.CheckpointInfo(name=k, filename=v, model_type='unet')
print('HERE3', unets[k])
yield from sd_models.update_model_hashes(unets, model_type='unet') yield from sd_models.update_model_hashes(unets, model_type='unet')
yield from sd_models.update_model_hashes(model_type='checkpoint') yield from sd_models.update_model_hashes(model_type='checkpoint')

View File

@ -197,7 +197,7 @@ def create_color_inputs(tab):
grading_hue = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0, label='Hue', elem_id=f"{tab}_grading_hue") grading_hue = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0, label='Hue', elem_id=f"{tab}_grading_hue")
grading_gamma = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label='Gamma', elem_id=f"{tab}_grading_gamma") grading_gamma = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label='Gamma', elem_id=f"{tab}_grading_gamma")
grading_sharpness = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=0, label='Sharpness', elem_id=f"{tab}_grading_sharpness") grading_sharpness = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=0, label='Sharpness', elem_id=f"{tab}_grading_sharpness")
grading_color_temp = gr.Slider(minimum=2000, maximum=12000, step=100, value=6500, label='Color temp (K)', elem_id=f"{tab}_grading_color_temp") grading_color_temp = gr.Slider(minimum=2000, maximum=12000, step=100, value=6500, label='Color temp', elem_id=f"{tab}_grading_color_temp")
with gr.Group(): with gr.Group():
gr.HTML('<h3>Tone</h3>') gr.HTML('<h3>Tone</h3>')
with gr.Row(elem_id=f"{tab}_grading_tone_row"): with gr.Row(elem_id=f"{tab}_grading_tone_row"):
@ -211,7 +211,7 @@ def create_color_inputs(tab):
with gr.Row(elem_id=f"{tab}_grading_split_row"): with gr.Row(elem_id=f"{tab}_grading_split_row"):
grading_shadows_tint = gr.ColorPicker(label="Shadows tint", value="#000000", elem_id=f"{tab}_grading_shadows_tint") grading_shadows_tint = gr.ColorPicker(label="Shadows tint", value="#000000", elem_id=f"{tab}_grading_shadows_tint")
grading_highlights_tint = gr.ColorPicker(label="Highlights tint", value="#ffffff", elem_id=f"{tab}_grading_highlights_tint") grading_highlights_tint = gr.ColorPicker(label="Highlights tint", value="#ffffff", elem_id=f"{tab}_grading_highlights_tint")
grading_split_tone_balance = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label='Balance', elem_id=f"{tab}_grading_split_tone_balance") grading_split_tone_balance = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label='Split tone balance', elem_id=f"{tab}_grading_split_tone_balance")
with gr.Group(): with gr.Group():
gr.HTML('<h3>Effects</h3>') gr.HTML('<h3>Effects</h3>')
with gr.Row(elem_id=f"{tab}_grading_effects_row"): with gr.Row(elem_id=f"{tab}_grading_effects_row"):
@ -220,9 +220,9 @@ def create_color_inputs(tab):
with gr.Group(): with gr.Group():
gr.HTML('<h3>LUT</h3>') gr.HTML('<h3>LUT</h3>')
with gr.Row(elem_id=f"{tab}_grading_lut_row"): with gr.Row(elem_id=f"{tab}_grading_lut_row"):
grading_lut_file = gr.File(label='LUT .cube file', file_types=['.cube'], elem_id=f"{tab}_grading_lut_file") grading_lut_cube_file = gr.File(label='LUT .cube file', file_types=['.cube'], elem_id=f"{tab}_grading_lut_file")
grading_lut_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=1.0, label='LUT strength', elem_id=f"{tab}_grading_lut_strength") grading_lut_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=1.0, label='LUT strength', elem_id=f"{tab}_grading_lut_strength")
return grading_brightness, grading_contrast, grading_saturation, grading_hue, grading_gamma, grading_sharpness, grading_color_temp, grading_shadows, grading_midtones, grading_highlights, grading_clahe_clip, grading_clahe_grid, grading_shadows_tint, grading_highlights_tint, grading_split_tone_balance, grading_vignette, grading_grain, grading_lut_file, grading_lut_strength return grading_brightness, grading_contrast, grading_saturation, grading_hue, grading_gamma, grading_sharpness, grading_color_temp, grading_shadows, grading_midtones, grading_highlights, grading_clahe_clip, grading_clahe_grid, grading_shadows_tint, grading_highlights_tint, grading_split_tone_balance, grading_vignette, grading_grain, grading_lut_cube_file, grading_lut_strength
def create_sampler_and_steps_selection(choices, tabname, default_steps:int=20): def create_sampler_and_steps_selection(choices, tabname, default_steps:int=20):

View File

@ -9,7 +9,6 @@ Installed via apply_patch() during pipeline loading.
import os import os
import time import time
import torch
from modules import shared, sd_models from modules import shared, sd_models
from modules.logger import log from modules.logger import log
from modules.lora import network, network_lokr, lora_convert from modules.lora import network, network_lokr, lora_convert
@ -180,7 +179,7 @@ def apply_patch():
lora_state_dict won't detect them as AI toolkit format. This patch checks for lora_state_dict won't detect them as AI toolkit format. This patch checks for
bare keys after the original returns and adds the prefix + re-runs conversion. bare keys after the original returns and adds the prefix + re-runs conversion.
""" """
global patched global patched # pylint: disable=global-statement
if patched: if patched:
return return
patched = True patched = True

15
scripts/color_grading.py Normal file
View File

@ -0,0 +1,15 @@
from modules import scripts_postprocessing, ui_sections, processing_grading
class ScriptPostprocessingColorGrading(scripts_postprocessing.ScriptPostprocessing):
name = "Color Grading"
def ui(self):
ui_controls = ui_sections.create_color_inputs('process')
ui_controls_dict = {control.label.replace(" ", "_").replace(".", "").lower(): control for control in ui_controls}
return ui_controls_dict
def process(self, pp: scripts_postprocessing.PostprocessedImage, *args, **kwargs): # pylint: disable=arguments-differ
grading_params = processing_grading.GradingParams(*args, **kwargs)
if processing_grading.is_active(grading_params):
pp.image = processing_grading.grade_image(pp.image, grading_params)

View File

@ -124,7 +124,7 @@ def process(
# defines script for dual-mode usage # defines script for dual-mode usage
class Script(scripts.Script): class ScriptNudeNet(scripts.Script):
# see below for all available options and callbacks # see below for all available options and callbacks
# <https://github.com/vladmandic/automatic/blob/master/modules/scripts.py#L26> # <https://github.com/vladmandic/automatic/blob/master/modules/scripts.py#L26>
@ -148,7 +148,7 @@ class Script(scripts.Script):
# defines postprocessing script for dual-mode usage # defines postprocessing script for dual-mode usage
class ScriptPostprocessing(scripts_postprocessing.ScriptPostprocessing): class ScriptPostprocessingNudeNet(scripts_postprocessing.ScriptPostprocessing):
name = 'NudeNet' name = 'NudeNet'
order = 10000 order = 10000

View File

@ -2,8 +2,8 @@ import gradio as gr
from modules import video, scripts_postprocessing from modules import video, scripts_postprocessing
class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): class ScriptPostprocessingVideo(scripts_postprocessing.ScriptPostprocessing):
name = "Video" name = "Create Video"
def ui(self): def ui(self):
with gr.Accordion('Create video', open = False, elem_id="postprocess_video_accordion"): with gr.Accordion('Create video', open = False, elem_id="postprocess_video_accordion"):
@ -18,7 +18,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
] ]
with gr.Row(): with gr.Row():
gr.HTML("<span>&nbsp Video</span><br>") gr.HTML("<span>&nbsp Create video from generated images</span><br>")
with gr.Row(): with gr.Row():
video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None', elem_id="extras_video_type") video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None', elem_id="extras_video_type")
duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False, elem_id="extras_video_duration") duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False, elem_id="extras_video_duration")

View File

@ -95,7 +95,7 @@ def test_grading_params_defaults():
assert p.split_tone_balance == 0.5 assert p.split_tone_balance == 0.5
assert p.vignette == 0.0 assert p.vignette == 0.0
assert p.grain == 0.0 assert p.grain == 0.0
assert p.lut_file == "" assert p.lut_cube_file == ""
assert p.lut_strength == 1.0 assert p.lut_strength == 1.0
return True return True