add color grading to processing

Signed-off-by: vladmandic <mandic00@live.com>
pull/4703/head
vladmandic 2026-03-23 08:44:07 +01:00
parent 3261e2eeae
commit 09b9ae32c1
13 changed files with 74 additions and 38 deletions

View File

@ -27,7 +27,14 @@ But also many smaller quality-of-life improvements - for full details, see [Chan
- [Anima Preview-v2](https://huggingface.co/circlestone-labs/Anima)
- **Image manipulation**
- new **color grading** module
- update **latent corrections** *(former HDR Corrections)* and expand allowed models
apply basic corrections to your images: brightness, contrast, saturation, shadows, highlights
move to professional photo corrections: hue, gamma, sharpness, temperature
correct tone: shadows, midtones, highlights
add effects: vignette, grain
apply professional LUT table using a .cube file
*hint* color grading is available as a step during generation or as a processing item for already existing images
- update **latent corrections** *(former HDR Corrections)*
expand allowed models
- add support for [spandrel](https://github.com/chaiNNer-org/spandrel)
**upscaling** engine with support for new upscaling model families
- add two new ai upscalers: *RealPLKSR NomosWebPhoto* and *RealPLKSR AnimeSharpV2*

View File

@ -12,7 +12,6 @@
## Internal
- Integrate: [Depth3D](https://github.com/vladmandic/sd-extension-depth3d)
- Feature: Color grading in processing
- Feature: RIFE update
- Feature: RIFE in processing
- Feature: SeedVR2 in processing
@ -146,3 +145,5 @@ TODO: Investigate which models are diffusers-compatible and prioritize!
- modules/modular_guiders.py:65:58: W0511: TODO: guiders
- processing: remove duplicate mask params
- resize image: enable full VAE mode for resize-latent
modules/sd_samplers_diffusers.py:353:31: W0511: TODO enso-required (fixme)

@ -1 +1 @@
Subproject commit d3f63ee8c3b6220f290e5fa54dc172a772b8c108
Subproject commit 0861ae00f2ad057a914ca82e45fe6635dde7417e

View File

@ -360,7 +360,7 @@ def process_samples(p: StableDiffusionProcessing, samples):
split_tone_balance=getattr(p, 'grading_split_tone_balance', 0.5),
vignette=getattr(p, 'grading_vignette', 0.0),
grain=getattr(p, 'grading_grain', 0.0),
lut_file=getattr(p, 'grading_lut_file', ''),
lut_cube_file=getattr(p, 'grading_lut_file', ''),
lut_strength=getattr(p, 'grading_lut_strength', 1.0),
)
if processing_grading.is_active(grading_params):

View File

@ -66,7 +66,7 @@ class GradingParams:
vignette: float = 0.0
grain: float = 0.0
# lut
lut_file: str = ""
lut_cube_file: str = ""
lut_strength: float = 1.0
def __post_init__(self):
@ -179,17 +179,17 @@ def _apply_color_temp(img: torch.Tensor, kelvin: float) -> torch.Tensor:
return (img * scales).clamp(0, 1)
def _apply_lut(image: Image.Image, lut_file: str, strength: float) -> Image.Image:
def _apply_lut(image: Image.Image, lut_cube_file: str, strength: float) -> Image.Image:
"""Apply .cube LUT file via pillow-lut-tools."""
if not lut_file or not os.path.isfile(lut_file):
if not lut_cube_file or not os.path.isfile(lut_cube_file):
return image
pillow_lut = _ensure_pillow_lut()
try:
cube = pillow_lut.load_cube_file(lut_file)
cube = pillow_lut.load_cube_file(lut_cube_file)
if strength != 1.0:
cube = pillow_lut.amplify_lut(cube, strength)
result = image.filter(cube)
debug(f'Grading LUT: file={os.path.basename(lut_file)} strength={strength}')
debug(f'Grading LUT: file={os.path.basename(lut_cube_file)} strength={strength}')
return result
except Exception as e:
log.error(f'Grading LUT: {e}')
@ -198,8 +198,8 @@ def _apply_lut(image: Image.Image, lut_file: str, strength: float) -> Image.Imag
def grade_image(image: Image.Image, params: GradingParams) -> Image.Image:
"""Full grading pipeline: PIL -> GPU tensor -> kornia ops -> PIL."""
log.debug(f"Grading: params={params}")
kornia = _ensure_kornia()
debug(f'Grading: params={params}')
arr = np.array(image).astype(np.float32) / 255.0
tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
tensor = tensor.to(device=devices.device, dtype=devices.dtype)
@ -246,7 +246,7 @@ def grade_image(image: Image.Image, params: GradingParams) -> Image.Image:
result = Image.fromarray(arr)
# LUT applied last (CPU, via pillow-lut-tools)
if params.lut_file:
result = _apply_lut(result, params.lut_file, params.lut_strength)
if params.lut_cube_file:
result = _apply_lut(result, params.lut_cube_file, params.lut_strength)
return result

View File

@ -76,9 +76,14 @@ class ScriptPostprocessingRunner:
script.controls = wrap_call(script.ui, script.filename, "ui")
if script.controls is None:
script.controls = {}
for control in script.controls.values():
control.custom_script_source = os.path.basename(script.filename)
inputs += list(script.controls.values())
if isinstance(script.controls, list) or isinstance(script.controls, tuple):
for control in script.controls:
control.custom_script_source = os.path.basename(script.filename)
inputs += script.controls
else:
for control in script.controls.values():
control.custom_script_source = os.path.basename(script.filename)
inputs += list(script.controls.values())
script.args_to = len(inputs)
def scripts_in_preferred_order(self):
@ -109,11 +114,16 @@ class ScriptPostprocessingRunner:
for script in self.scripts_in_preferred_order():
jobid = shared.state.begin(script.name)
script_args = args[script.args_from:script.args_to]
process_args = {}
for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
process_args[name] = value
log.debug(f'Process: script="{script.name}" args={process_args}')
script.process(pp, **process_args)
process_args = []
process_kwargs = {}
if isinstance(script.controls, list) or isinstance(script.controls, tuple):
for _control, value in zip(script.controls, script_args, strict=False):
process_args.append(value)
else:
for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
process_kwargs[name] = value
log.debug(f'Process: script="{script.name}" args={process_args} kwargs={process_kwargs}')
script.process(pp, *process_args, **process_kwargs)
shared.state.end(jobid)
def create_args_for_run(self, scripts_args):
@ -139,9 +149,14 @@ class ScriptPostprocessingRunner:
continue
jobid = shared.state.begin(script.name)
script_args = args[script.args_from:script.args_to]
process_args = {}
for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
process_args[name] = value
log.debug(f'Postprocess: script={script.name} args={process_args}')
script.postprocess(filenames, **process_args)
process_args = []
process_kwargs = {}
if isinstance(script.controls, list) or isinstance(script.controls, tuple):
for _control, value in zip(script.controls, script_args, strict=False):
process_args.append(value)
else:
for (name, _component), value in zip(script.controls.items(), script_args, strict=False):
process_kwargs[name] = value
log.debug(f'Postprocess: script={script.name} args={process_args} kwargs={process_kwargs}')
script.postprocess(filenames, *process_args, **process_kwargs)
shared.state.end(jobid)

View File

@ -17,7 +17,6 @@ def update_model_hashes():
unets = {}
for k, v in sd_unet.unet_dict.items():
unets[k] = sd_checkpoint.CheckpointInfo(name=k, filename=v, model_type='unet')
print('HERE3', unets[k])
yield from sd_models.update_model_hashes(unets, model_type='unet')
yield from sd_models.update_model_hashes(model_type='checkpoint')

View File

@ -197,7 +197,7 @@ def create_color_inputs(tab):
grading_hue = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0, label='Hue', elem_id=f"{tab}_grading_hue")
grading_gamma = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.0, label='Gamma', elem_id=f"{tab}_grading_gamma")
grading_sharpness = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=0, label='Sharpness', elem_id=f"{tab}_grading_sharpness")
grading_color_temp = gr.Slider(minimum=2000, maximum=12000, step=100, value=6500, label='Color temp (K)', elem_id=f"{tab}_grading_color_temp")
grading_color_temp = gr.Slider(minimum=2000, maximum=12000, step=100, value=6500, label='Color temp', elem_id=f"{tab}_grading_color_temp")
with gr.Group():
gr.HTML('<h3>Tone</h3>')
with gr.Row(elem_id=f"{tab}_grading_tone_row"):
@ -211,7 +211,7 @@ def create_color_inputs(tab):
with gr.Row(elem_id=f"{tab}_grading_split_row"):
grading_shadows_tint = gr.ColorPicker(label="Shadows tint", value="#000000", elem_id=f"{tab}_grading_shadows_tint")
grading_highlights_tint = gr.ColorPicker(label="Highlights tint", value="#ffffff", elem_id=f"{tab}_grading_highlights_tint")
grading_split_tone_balance = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label='Balance', elem_id=f"{tab}_grading_split_tone_balance")
grading_split_tone_balance = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label='Split tone balance', elem_id=f"{tab}_grading_split_tone_balance")
with gr.Group():
gr.HTML('<h3>Effects</h3>')
with gr.Row(elem_id=f"{tab}_grading_effects_row"):
@ -220,9 +220,9 @@ def create_color_inputs(tab):
with gr.Group():
gr.HTML('<h3>LUT</h3>')
with gr.Row(elem_id=f"{tab}_grading_lut_row"):
grading_lut_file = gr.File(label='LUT .cube file', file_types=['.cube'], elem_id=f"{tab}_grading_lut_file")
grading_lut_cube_file = gr.File(label='LUT .cube file', file_types=['.cube'], elem_id=f"{tab}_grading_lut_file")
grading_lut_strength = gr.Slider(minimum=0.0, maximum=2.0, step=0.05, value=1.0, label='LUT strength', elem_id=f"{tab}_grading_lut_strength")
return grading_brightness, grading_contrast, grading_saturation, grading_hue, grading_gamma, grading_sharpness, grading_color_temp, grading_shadows, grading_midtones, grading_highlights, grading_clahe_clip, grading_clahe_grid, grading_shadows_tint, grading_highlights_tint, grading_split_tone_balance, grading_vignette, grading_grain, grading_lut_file, grading_lut_strength
return grading_brightness, grading_contrast, grading_saturation, grading_hue, grading_gamma, grading_sharpness, grading_color_temp, grading_shadows, grading_midtones, grading_highlights, grading_clahe_clip, grading_clahe_grid, grading_shadows_tint, grading_highlights_tint, grading_split_tone_balance, grading_vignette, grading_grain, grading_lut_cube_file, grading_lut_strength
def create_sampler_and_steps_selection(choices, tabname, default_steps:int=20):

View File

@ -9,7 +9,6 @@ Installed via apply_patch() during pipeline loading.
import os
import time
import torch
from modules import shared, sd_models
from modules.logger import log
from modules.lora import network, network_lokr, lora_convert
@ -180,7 +179,7 @@ def apply_patch():
lora_state_dict won't detect them as AI toolkit format. This patch checks for
bare keys after the original returns and adds the prefix + re-runs conversion.
"""
global patched
global patched # pylint: disable=global-statement
if patched:
return
patched = True

15
scripts/color_grading.py Normal file
View File

@ -0,0 +1,15 @@
from modules import scripts_postprocessing, ui_sections, processing_grading
class ScriptPostprocessingColorGrading(scripts_postprocessing.ScriptPostprocessing):
    """Postprocessing script that exposes the color-grading module for already-existing images."""
    name = "Color Grading"

    def ui(self):
        """Create the shared grading controls and key them by normalized label.

        Labels are lowercased with spaces replaced by underscores and dots removed,
        e.g. "Split tone balance" -> "split_tone_balance", "LUT .cube file" -> "lut_cube_file",
        so the resulting kwargs line up with GradingParams field names.
        """
        controls = ui_sections.create_color_inputs('process')
        keyed = {}
        for ctrl in controls:
            field_name = ctrl.label.replace(" ", "_").replace(".", "").lower()
            keyed[field_name] = ctrl
        return keyed

    def process(self, pp: scripts_postprocessing.PostprocessedImage, *args, **kwargs):  # pylint: disable=arguments-differ
        """Apply color grading to pp.image when any grading parameter is active."""
        params = processing_grading.GradingParams(*args, **kwargs)
        if not processing_grading.is_active(params):
            return
        pp.image = processing_grading.grade_image(pp.image, params)

View File

@ -124,7 +124,7 @@ def process(
# defines script for dual-mode usage
class Script(scripts.Script):
class ScriptNudeNet(scripts.Script):
# see below for all available options and callbacks
# <https://github.com/vladmandic/automatic/blob/master/modules/scripts.py#L26>
@ -148,7 +148,7 @@ class Script(scripts.Script):
# defines postprocessing script for dual-mode usage
class ScriptPostprocessing(scripts_postprocessing.ScriptPostprocessing):
class ScriptPostprocessingNudeNet(scripts_postprocessing.ScriptPostprocessing):
name = 'NudeNet'
order = 10000

View File

@ -2,8 +2,8 @@ import gradio as gr
from modules import video, scripts_postprocessing
class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
name = "Video"
class ScriptPostprocessingVideo(scripts_postprocessing.ScriptPostprocessing):
name = "Create Video"
def ui(self):
with gr.Accordion('Create video', open = False, elem_id="postprocess_video_accordion"):
@ -18,7 +18,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
]
with gr.Row():
gr.HTML("<span>&nbsp Video</span><br>")
gr.HTML("<span>&nbsp Create video from generated images</span><br>")
with gr.Row():
video_type = gr.Dropdown(label='Video file', choices=['None', 'GIF', 'PNG', 'MP4'], value='None', elem_id="extras_video_type")
duration = gr.Slider(label='Duration', minimum=0.25, maximum=10, step=0.25, value=2, visible=False, elem_id="extras_video_duration")

View File

@ -95,7 +95,7 @@ def test_grading_params_defaults():
assert p.split_tone_balance == 0.5
assert p.vignette == 0.0
assert p.grain == 0.0
assert p.lut_file == ""
assert p.lut_cube_file == ""
assert p.lut_strength == 1.0
return True