Added visualization for beat finder

pull/26/head
jahu00 2023-04-29 23:48:26 +02:00
parent abab50528b
commit 77dfe01f0c
1 changed files with 147 additions and 31 deletions

View File

@ -5,7 +5,7 @@ import io
import typing as T
import os
import numpy as np
from PIL import Image
from PIL import Image, ImageDraw
from scipy.io import wavfile
import torch
import torchaudio
@ -344,16 +344,11 @@ def convert_audio(
crop_width: int,
rhythm:int,
band_start: float,
band_end: float,
band_length: float,
threshold_offset: float,
ignore_range: float) -> None:
images = []
globs = map(lambda x: x.strip(), file_regex.split(","))
for g in globs:
images.extend(glob.glob(os.path.join(image_dir, g)))
images = get_image(file_regex, image_dir)
print(f"Found {len(images)} images in {image_dir}, pattern {file_regex}")
output_files = []
@ -363,13 +358,12 @@ def convert_audio(
if crop_method == "Fixed":
width = crop_width
elif crop_method.startswith("Beat Finder") and (not crop_method.endswith("(Once)") or i == 0):
width = find_cutoff(image_file, rhythm, band_start, band_end, threshold_offset, ignore_range)
width = find_cutoff(image_file, rhythm, band_start, band_length, threshold_offset, ignore_range)["cutoff"]
print("Cutoff found at:", width)
output_files.append(convert_audio_image(image, image_file, image_dir, width))
if join_images and len(output_files) > 1:
output_files.sort()
data = []
outfile = os.path.join(
image_dir,
@ -387,26 +381,100 @@ def convert_audio(
print(f"Converted {len(images)} images to audio")
def find_cutoff(image, rhythm = 4, band_start = 0.25, band_end = 0.75, threshold_offset = 0.75, ignore_range = 0.05):
def test_beat_finder(image_dir: str,
file_regex: str,
rhythm:int,
band_start: float,
band_length: float,
threshold_offset: float,
ignore_range: float) -> Image:
images = get_image(file_regex, image_dir)
if (len(images) == 0):
return None
image = images[0]
image_file = Image.open(image).convert("RGB")
output = Image.new(mode="RGB", size=(image_file.width, image_file.height + 256), color=(255,255,255))
output.paste(image_file)
beat_finder_result = find_cutoff(image_file, rhythm, band_start, band_length, threshold_offset, ignore_range)
draw = ImageDraw.Draw(output, "RGBA")
register_rect = [(0, beat_finder_result["register_start"]), (image_file.width, beat_finder_result["register_end"])]
level_range = beat_finder_result["level_range"]
if level_range > 0:
min_level = beat_finder_result["min_level"]
max_level = beat_finder_result["max_level"]
scale = 255 / level_range
for i, value in enumerate(beat_finder_result["one_line"]):
line_height = (max_level - value) * scale
draw.line([(i, output.height), (i, output.height - line_height)], fill=(0,0,0))
pass
scaled_threshold = (max_level - beat_finder_result["threshold"]) * scale
draw.line([(0, output.height - scaled_threshold), (output.width, output.height - scaled_threshold)], fill=(0,255,0))
draw.rectangle(register_rect, fill=(0, 0, 255, 64))
for beat in beat_finder_result["above_threshold"]:
draw.line([(beat, 0), (beat, output.height)], fill=(255,0,0), width= 2)
cutoff = beat_finder_result["cutoff"]
if cutoff != None:
draw.line([(cutoff, 0), (cutoff, output.height)], fill=(255,165,0), width= 2)
del draw
return output
def get_image(file_regex, image_dir):
images = []
globs = map(lambda x: x.strip(), file_regex.split(","))
for g in globs:
images.extend(glob.glob(os.path.join(image_dir, g)))
images.sort()
return images
def find_cutoff(image, rhythm = 4, band_start = 0.25, band_length = 0.75, threshold_offset = 0.75, ignore_range = 0.05):
result = { "cutoff": None }
ignore_distance = image.width * ignore_range
register_start = image.height * band_start
register_lenght = image.height * band_end
register_start = int(image.height * band_start)
register_lenght = int(image.height * band_length)
register_end = register_start + register_lenght
band = image.crop((0, register_start, image.width, register_start + register_lenght))
result["register_start"] = register_start
result["register_end"] = register_end
band = image.crop((0, register_start, image.width, register_end))
gray_image = band.convert('L')
one_line = list(gray_image.resize((gray_image.width, 1)).getdata())
one_pixel = list(gray_image.resize((1, 1)).getdata())
average = one_pixel[0]
threshold = min(one_line) + (max(one_line) - min(one_line)) * threshold_offset
result["one_line"] = one_line
#one_pixel = list(gray_image.resize((1, 1)).getdata())
#average = one_pixel[0]
min_level = min(one_line)
max_level = max(one_line)
level_range = max_level - min_level
result["min_level"] = min_level
result["max_level"] = max_level
result["level_range"] = level_range
threshold = min_level + level_range * threshold_offset
result["threshold"] = threshold
above_threshold = []
for i, value in enumerate(one_line):
if value < threshold and (len(above_threshold) == 0 or i - above_threshold[-1] > ignore_distance):
above_threshold.append(i)
result["above_threshold"] = above_threshold
if len(above_threshold) == 0:
print("Failed to find beats")
return None
return result
print("Beats found:", above_threshold)
@ -417,16 +485,25 @@ def find_cutoff(image, rhythm = 4, band_start = 0.25, band_end = 0.75, threshold
distances.append(value - above_threshold[i - 1])
if (len(distances) == 0):
return result
result["distances"] = distances
distance = median(distances)
result["distance"] = distance
print("Interval:", distance)
beat_count = int(len(above_threshold) / rhythm) * rhythm
result["beat_count"] = beat_count
if beat_count == 0:
print("Missmatching rhythm")
return None
return result
cutoff = int(beat_count * distance)
return cutoff
result["cutoff"] = cutoff
return result
def on_ui_tabs():
with gr.Blocks() as riffusion_ui:
@ -476,33 +553,59 @@ def on_ui_tabs():
interactive=True,
)
band_start = gr.Number(
band_start = gr.Slider(
label="Band start",
min=0,
max=0.99,
step=0.05,
value=0.25,
precision=2,
interactive=True,
)
band_end = gr.Number(
label="Band end",
band_length = gr.Slider(
label="Band length",
min=0.01,
max=1,
step=0.05,
value=0.5,
precision=2,
interactive=True,
)
with gr.Row():
threshold_offset = gr.Number(
label="Threshold offset",
threshold_offset = gr.Slider(
label="Threshold",
min=0.01,
max=1,
step=0.05,
value=0.1,
precision=2,
interactive=True,
)
ignore_range = gr.Number(
ignore_range = gr.Slider(
label="Ignore range",
min=0,
max=1,
step=0.05,
value=0.1,
precision=23,
interactive=True,
)
with gr.Column():
test_beat_finder_button = gr.Button(
"Test", label="Test", variant="primary"
)
beat_finder_image = gr.Image(type="pil")
test_beat_finder_button.click(
on_beat_finder_test_click,
inputs=[
image_directory,
file_regex,
rhythm,
band_start,
band_length,
threshold_offset,
ignore_range
],
outputs=[beat_finder_image],
)
crop_method.change(on_crop_method_change, crop_method, [fixed_block,beat_finder_block])
with gr.Column(variant="panel"):
@ -520,7 +623,7 @@ def on_ui_tabs():
crop_width,
rhythm,
band_start,
band_end,
band_length,
threshold_offset,
ignore_range
],
@ -529,9 +632,22 @@ def on_ui_tabs():
gr.HTML(value="<p>Converts all images in a folder to audio</p>")
return ((riffusion_ui, "Riffusion", "riffusion_ui"),)
def on_beat_finder_test_click(image_dir: str,
file_regex: str,
rhythm:int,
band_start: float,
band_length: float,
threshold_offset: float,
ignore_range: float):
image = test_beat_finder(image_dir, file_regex, rhythm, band_start, band_length, threshold_offset, ignore_range)
return gr.update(value=image)
def on_crop_method_change(crop_method):
fixed_visible = True if crop_method == "Fixed" else False
beat_finder_visible = True if crop_method.startswith("Beat Finder") else False
return gr.update(visible=fixed_visible), gr.update(visible=beat_finder_visible)
script_callbacks.on_ui_tabs(on_ui_tabs)