diff --git a/scripts/riffusion.py b/scripts/riffusion.py index 64948f0..acbed8e 100644 --- a/scripts/riffusion.py +++ b/scripts/riffusion.py @@ -5,7 +5,7 @@ import io import typing as T import os import numpy as np -from PIL import Image +from PIL import Image, ImageDraw from scipy.io import wavfile import torch import torchaudio @@ -344,16 +344,11 @@ def convert_audio( crop_width: int, rhythm:int, band_start: float, - band_end: float, + band_length: float, threshold_offset: float, ignore_range: float) -> None: - images = [] - - globs = map(lambda x: x.strip(), file_regex.split(",")) - - for g in globs: - images.extend(glob.glob(os.path.join(image_dir, g))) + images = get_image(file_regex, image_dir) print(f"Found {len(images)} images in {image_dir}, pattern {file_regex}") output_files = [] @@ -363,13 +358,12 @@ def convert_audio( if crop_method == "Fixed": width = crop_width elif crop_method.startswith("Beat Finder") and (not crop_method.endswith("(Once)") or i == 0): - width = find_cutoff(image_file, rhythm, band_start, band_end, threshold_offset, ignore_range) + width = find_cutoff(image_file, rhythm, band_start, band_length, threshold_offset, ignore_range)["cutoff"] print("Cutoff found at:", width) output_files.append(convert_audio_image(image, image_file, image_dir, width)) if join_images and len(output_files) > 1: - output_files.sort() data = [] outfile = os.path.join( image_dir, @@ -387,26 +381,100 @@ def convert_audio( print(f"Converted {len(images)} images to audio") -def find_cutoff(image, rhythm = 4, band_start = 0.25, band_end = 0.75, threshold_offset = 0.75, ignore_range = 0.05): +def test_beat_finder(image_dir: str, + file_regex: str, + rhythm:int, + band_start: float, + band_length: float, + threshold_offset: float, + ignore_range: float) -> Image: + + images = get_image(file_regex, image_dir) + if (len(images) == 0): + return None + image = images[0] + image_file = Image.open(image).convert("RGB") + output = Image.new(mode="RGB", size=(image_file.width, image_file.height + 256), color=(255,255,255)) + output.paste(image_file) + beat_finder_result = find_cutoff(image_file, rhythm, band_start, band_length, threshold_offset, ignore_range) + draw = ImageDraw.Draw(output, "RGBA") + register_rect = [(0, beat_finder_result["register_start"]), (image_file.width, beat_finder_result["register_end"])] + + level_range = beat_finder_result["level_range"] + if level_range > 0: + min_level = beat_finder_result["min_level"] + max_level = beat_finder_result["max_level"] + scale = 255 / level_range + for i, value in enumerate(beat_finder_result["one_line"]): + line_height = (max_level - value) * scale + draw.line([(i, output.height), (i, output.height - line_height)], fill=(0,0,0)) + pass + + scaled_threshold = (max_level - beat_finder_result["threshold"]) * scale + draw.line([(0, output.height - scaled_threshold), (output.width, output.height - scaled_threshold)], fill=(0,255,0)) + + draw.rectangle(register_rect, fill=(0, 0, 255, 64)) + for beat in beat_finder_result["above_threshold"]: + draw.line([(beat, 0), (beat, output.height)], fill=(255,0,0), width= 2) + + cutoff = beat_finder_result["cutoff"] + if cutoff != None: + draw.line([(cutoff, 0), (cutoff, output.height)], fill=(255,165,0), width= 2) + + del draw + + return output + + +def get_image(file_regex, image_dir): + images = [] + globs = map(lambda x: x.strip(), file_regex.split(",")) + for g in globs: + images.extend(glob.glob(os.path.join(image_dir, g))) + + images.sort() + + return images + +def find_cutoff(image, rhythm = 4, band_start = 0.25, band_length = 0.75, threshold_offset = 0.75, ignore_range = 0.05): + + result = { "cutoff": None } + ignore_distance = image.width * ignore_range - register_start = image.height * band_start - register_lenght = image.height * band_end + register_start = int(image.height * band_start) + register_lenght = int(image.height * band_length) + register_end = register_start + register_lenght - band = image.crop((0, register_start, image.width, register_start + register_lenght)) + result["register_start"] = register_start + result["register_end"] = register_end + + band = image.crop((0, register_start, image.width, register_end)) gray_image = band.convert('L') one_line = list(gray_image.resize((gray_image.width, 1)).getdata()) - one_pixel = list(gray_image.resize((1, 1)).getdata()) - average = one_pixel[0] - threshold = min(one_line) + (max(one_line) - min(one_line)) * threshold_offset + + result["one_line"] = one_line + + #one_pixel = list(gray_image.resize((1, 1)).getdata()) + #average = one_pixel[0] + min_level = min(one_line) + max_level = max(one_line) + level_range = max_level - min_level + result["min_level"] = min_level + result["max_level"] = max_level + result["level_range"] = level_range + threshold = min_level + level_range * threshold_offset + result["threshold"] = threshold above_threshold = [] for i, value in enumerate(one_line): if value < threshold and (len(above_threshold) == 0 or i - above_threshold[-1] > ignore_distance): above_threshold.append(i) + result["above_threshold"] = above_threshold + if len(above_threshold) == 0: print("Failed to find beats") - return None + return result print("Beats found:", above_threshold) @@ -417,16 +485,25 @@ def find_cutoff(image, rhythm = 4, band_start = 0.25, band_end = 0.75, threshold distances.append(value - above_threshold[i - 1]) + if (len(distances) == 0): + return result + + result["distances"] = distances + distance = median(distances) + result["distance"] = distance + print("Interval:", distance) beat_count = int(len(above_threshold) / rhythm) * rhythm + result["beat_count"] = beat_count if beat_count == 0: print("Missmatching rhythm") - return None + return result cutoff = int(beat_count * distance) - return cutoff + result["cutoff"] = cutoff + return result def on_ui_tabs(): with gr.Blocks() as riffusion_ui: @@ -476,33 +553,59 @@ def on_ui_tabs(): interactive=True, ) - band_start = gr.Number( + band_start = gr.Slider( label="Band start", + min=0, + max=0.99, + step=0.05, value=0.25, - precision=2, interactive=True, ) - band_end = gr.Number( - label="Band end", + band_length = gr.Slider( + label="Band length", + min=0.01, + max=1, + step=0.05, value=0.5, - precision=2, interactive=True, ) with gr.Row(): - threshold_offset = gr.Number( - label="Threshold offset", + threshold_offset = gr.Slider( + label="Threshold", + min=0.01, + max=1, + step=0.05, value=0.1, - precision=2, interactive=True, ) - ignore_range = gr.Number( + ignore_range = gr.Slider( label="Ignore range", + min=0, + max=1, + step=0.05, value=0.1, - precision=23, interactive=True, ) + with gr.Column(): + test_beat_finder_button = gr.Button( + "Test", label="Test", variant="primary" + ) + beat_finder_image = gr.Image(type="pil") + test_beat_finder_button.click( + on_beat_finder_test_click, + inputs=[ + image_directory, + file_regex, + rhythm, + band_start, + band_length, + threshold_offset, + ignore_range + ], + outputs=[beat_finder_image], + ) crop_method.change(on_crop_method_change, crop_method, [fixed_block,beat_finder_block]) with gr.Column(variant="panel"): @@ -520,7 +623,7 @@ def on_ui_tabs(): crop_width, rhythm, band_start, - band_end, + band_length, threshold_offset, ignore_range ], @@ -529,9 +632,22 @@ def on_ui_tabs(): gr.HTML(value="

Converts all images in a folder to audio

") return ((riffusion_ui, "Riffusion", "riffusion_ui"),) +def on_beat_finder_test_click(image_dir: str, + file_regex: str, + rhythm:int, + band_start: float, + band_length: float, + threshold_offset: float, + ignore_range: float): + + image = test_beat_finder(image_dir, file_regex, rhythm, band_start, band_length, threshold_offset, ignore_range) + return gr.update(value=image) + + def on_crop_method_change(crop_method): fixed_visible = True if crop_method == "Fixed" else False beat_finder_visible = True if crop_method.startswith("Beat Finder") else False return gr.update(visible=fixed_visible), gr.update(visible=beat_finder_visible) script_callbacks.on_ui_tabs(on_ui_tabs) +