Added visualization for beat finder
parent
abab50528b
commit
77dfe01f0c
|
|
@ -5,7 +5,7 @@ import io
|
|||
import typing as T
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageDraw
|
||||
from scipy.io import wavfile
|
||||
import torch
|
||||
import torchaudio
|
||||
|
|
@ -344,16 +344,11 @@ def convert_audio(
|
|||
crop_width: int,
|
||||
rhythm:int,
|
||||
band_start: float,
|
||||
band_end: float,
|
||||
band_length: float,
|
||||
threshold_offset: float,
|
||||
ignore_range: float) -> None:
|
||||
|
||||
images = []
|
||||
|
||||
globs = map(lambda x: x.strip(), file_regex.split(","))
|
||||
|
||||
for g in globs:
|
||||
images.extend(glob.glob(os.path.join(image_dir, g)))
|
||||
images = get_image(file_regex, image_dir)
|
||||
|
||||
print(f"Found {len(images)} images in {image_dir}, pattern {file_regex}")
|
||||
output_files = []
|
||||
|
|
@ -363,13 +358,12 @@ def convert_audio(
|
|||
if crop_method == "Fixed":
|
||||
width = crop_width
|
||||
elif crop_method.startswith("Beat Finder") and (not crop_method.endswith("(Once)") or i == 0):
|
||||
width = find_cutoff(image_file, rhythm, band_start, band_end, threshold_offset, ignore_range)
|
||||
width = find_cutoff(image_file, rhythm, band_start, band_length, threshold_offset, ignore_range)["cutoff"]
|
||||
print("Cutoff found at:", width)
|
||||
|
||||
output_files.append(convert_audio_image(image, image_file, image_dir, width))
|
||||
|
||||
if join_images and len(output_files) > 1:
|
||||
output_files.sort()
|
||||
data = []
|
||||
outfile = os.path.join(
|
||||
image_dir,
|
||||
|
|
@ -387,26 +381,100 @@ def convert_audio(
|
|||
|
||||
print(f"Converted {len(images)} images to audio")
|
||||
|
||||
def find_cutoff(image, rhythm = 4, band_start = 0.25, band_end = 0.75, threshold_offset = 0.75, ignore_range = 0.05):
|
||||
def test_beat_finder(image_dir: str,
                     file_regex: str,
                     rhythm: int,
                     band_start: float,
                     band_length: float,
                     threshold_offset: float,
                     ignore_range: float) -> T.Optional[Image.Image]:
    """Render a visualization of the beat finder for the first matching image.

    The first path returned by ``get_image(file_regex, image_dir)`` is opened,
    pasted onto a white canvas that is 256 pixels taller than the image, and
    annotated with the values reported by ``find_cutoff``:

    * black bars in the extra strip, one per column of ``one_line``
      (taller for lower values), plus a green horizontal threshold line --
      both only when the level range is non-zero;
    * a translucent blue rectangle between ``register_start`` and
      ``register_end``;
    * a red vertical line at every index in ``above_threshold``;
    * an orange vertical line at ``cutoff`` when one was found.

    Args:
        image_dir: Directory that is searched for images.
        file_regex: Comma-separated glob patterns, forwarded to ``get_image``.
        rhythm: Forwarded to ``find_cutoff``.
        band_start: Forwarded to ``find_cutoff``.
        band_length: Forwarded to ``find_cutoff``.
        threshold_offset: Forwarded to ``find_cutoff``.
        ignore_range: Forwarded to ``find_cutoff``.

    Returns:
        The annotated image, or ``None`` when no file matches.
    """
    images = get_image(file_regex, image_dir)
    if not images:
        return None

    # Image.open keeps the underlying file open; convert() returns an
    # independent in-memory image, so the source can be closed right away.
    with Image.open(images[0]) as source:
        image_file = source.convert("RGB")

    graph_height = 256  # extra strip below the image for the level graph
    output = Image.new(
        mode="RGB",
        size=(image_file.width, image_file.height + graph_height),
        color=(255, 255, 255),
    )
    output.paste(image_file)

    beat_finder_result = find_cutoff(
        image_file, rhythm, band_start, band_length, threshold_offset, ignore_range
    )

    # "RGBA" drawing mode lets the translucent rectangle blend with the image.
    draw = ImageDraw.Draw(output, "RGBA")
    register_rect = [
        (0, beat_finder_result["register_start"]),
        (image_file.width, beat_finder_result["register_end"]),
    ]

    level_range = beat_finder_result["level_range"]
    if level_range > 0:
        # A flat band has no range to scale by, so the graph is skipped.
        max_level = beat_finder_result["max_level"]
        scale = 255 / level_range

        for column, value in enumerate(beat_finder_result["one_line"]):
            line_height = (max_level - value) * scale
            draw.line(
                [(column, output.height), (column, output.height - line_height)],
                fill=(0, 0, 0),
            )

        scaled_threshold = (max_level - beat_finder_result["threshold"]) * scale
        threshold_y = output.height - scaled_threshold
        draw.line([(0, threshold_y), (output.width, threshold_y)], fill=(0, 255, 0))

    draw.rectangle(register_rect, fill=(0, 0, 255, 64))

    for beat in beat_finder_result["above_threshold"]:
        draw.line([(beat, 0), (beat, output.height)], fill=(255, 0, 0), width=2)

    cutoff = beat_finder_result["cutoff"]
    if cutoff is not None:
        draw.line([(cutoff, 0), (cutoff, output.height)], fill=(255, 165, 0), width=2)

    return output
|
||||
|
||||
|
||||
def get_image(file_regex, image_dir):
    """Return the sorted paths in ``image_dir`` that match any glob pattern.

    Args:
        file_regex: Comma-separated glob patterns (despite the name, these are
            passed to ``glob.glob``, not to a regex engine). Whitespace around
            each pattern is stripped and empty patterns are ignored.
        image_dir: Directory the patterns are resolved against.

    Returns:
        A sorted list of matching paths, each listed once even when several
        patterns match the same file.
    """
    matches = set()
    for pattern in (part.strip() for part in file_regex.split(",")):
        if not pattern:
            # Joining an empty pattern yields "image_dir/", which glob would
            # report as a match and callers would then try to open as an image.
            continue
        matches.update(glob.glob(os.path.join(image_dir, pattern)))

    return sorted(matches)
|
||||
|
||||
def find_cutoff(image, rhythm = 4, band_start = 0.25, band_length = 0.75, threshold_offset = 0.75, ignore_range = 0.05):
|
||||
|
||||
result = { "cutoff": None }
|
||||
|
||||
ignore_distance = image.width * ignore_range
|
||||
register_start = image.height * band_start
|
||||
register_lenght = image.height * band_end
|
||||
register_start = int(image.height * band_start)
|
||||
register_lenght = int(image.height * band_length)
|
||||
register_end = register_start + register_lenght
|
||||
|
||||
band = image.crop((0, register_start, image.width, register_start + register_lenght))
|
||||
result["register_start"] = register_start
|
||||
result["register_end"] = register_end
|
||||
|
||||
band = image.crop((0, register_start, image.width, register_end))
|
||||
gray_image = band.convert('L')
|
||||
one_line = list(gray_image.resize((gray_image.width, 1)).getdata())
|
||||
one_pixel = list(gray_image.resize((1, 1)).getdata())
|
||||
average = one_pixel[0]
|
||||
threshold = min(one_line) + (max(one_line) - min(one_line)) * threshold_offset
|
||||
|
||||
result["one_line"] = one_line
|
||||
|
||||
#one_pixel = list(gray_image.resize((1, 1)).getdata())
|
||||
#average = one_pixel[0]
|
||||
min_level = min(one_line)
|
||||
max_level = max(one_line)
|
||||
level_range = max_level - min_level
|
||||
result["min_level"] = min_level
|
||||
result["max_level"] = max_level
|
||||
result["level_range"] = level_range
|
||||
threshold = min_level + level_range * threshold_offset
|
||||
result["threshold"] = threshold
|
||||
above_threshold = []
|
||||
for i, value in enumerate(one_line):
|
||||
if value < threshold and (len(above_threshold) == 0 or i - above_threshold[-1] > ignore_distance):
|
||||
above_threshold.append(i)
|
||||
|
||||
result["above_threshold"] = above_threshold
|
||||
|
||||
if len(above_threshold) == 0:
|
||||
print("Failed to find beats")
|
||||
return None
|
||||
return result
|
||||
|
||||
print("Beats found:", above_threshold)
|
||||
|
||||
|
|
@ -417,16 +485,25 @@ def find_cutoff(image, rhythm = 4, band_start = 0.25, band_end = 0.75, threshold
|
|||
|
||||
distances.append(value - above_threshold[i - 1])
|
||||
|
||||
if (len(distances) == 0):
|
||||
return result
|
||||
|
||||
result["distances"] = distances
|
||||
|
||||
distance = median(distances)
|
||||
result["distance"] = distance
|
||||
|
||||
print("Interval:", distance)
|
||||
|
||||
beat_count = int(len(above_threshold) / rhythm) * rhythm
|
||||
result["beat_count"] = beat_count
|
||||
if beat_count == 0:
|
||||
print("Missmatching rhythm")
|
||||
return None
|
||||
return result
|
||||
|
||||
cutoff = int(beat_count * distance)
|
||||
return cutoff
|
||||
result["cutoff"] = cutoff
|
||||
return result
|
||||
|
||||
def on_ui_tabs():
|
||||
with gr.Blocks() as riffusion_ui:
|
||||
|
|
@ -476,33 +553,59 @@ def on_ui_tabs():
|
|||
interactive=True,
|
||||
)
|
||||
|
||||
band_start = gr.Number(
|
||||
band_start = gr.Slider(
|
||||
label="Band start",
|
||||
min=0,
|
||||
max=0.99,
|
||||
step=0.05,
|
||||
value=0.25,
|
||||
precision=2,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
band_end = gr.Number(
|
||||
label="Band end",
|
||||
band_length = gr.Slider(
|
||||
label="Band length",
|
||||
min=0.01,
|
||||
max=1,
|
||||
step=0.05,
|
||||
value=0.5,
|
||||
precision=2,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
threshold_offset = gr.Number(
|
||||
label="Threshold offset",
|
||||
threshold_offset = gr.Slider(
|
||||
label="Threshold",
|
||||
min=0.01,
|
||||
max=1,
|
||||
step=0.05,
|
||||
value=0.1,
|
||||
precision=2,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
ignore_range = gr.Number(
|
||||
ignore_range = gr.Slider(
|
||||
label="Ignore range",
|
||||
min=0,
|
||||
max=1,
|
||||
step=0.05,
|
||||
value=0.1,
|
||||
precision=23,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column():
|
||||
test_beat_finder_button = gr.Button(
|
||||
"Test", label="Test", variant="primary"
|
||||
)
|
||||
beat_finder_image = gr.Image(type="pil")
|
||||
test_beat_finder_button.click(
|
||||
on_beat_finder_test_click,
|
||||
inputs=[
|
||||
image_directory,
|
||||
file_regex,
|
||||
rhythm,
|
||||
band_start,
|
||||
band_length,
|
||||
threshold_offset,
|
||||
ignore_range
|
||||
],
|
||||
outputs=[beat_finder_image],
|
||||
)
|
||||
|
||||
crop_method.change(on_crop_method_change, crop_method, [fixed_block,beat_finder_block])
|
||||
with gr.Column(variant="panel"):
|
||||
|
|
@ -520,7 +623,7 @@ def on_ui_tabs():
|
|||
crop_width,
|
||||
rhythm,
|
||||
band_start,
|
||||
band_end,
|
||||
band_length,
|
||||
threshold_offset,
|
||||
ignore_range
|
||||
],
|
||||
|
|
@ -529,9 +632,22 @@ def on_ui_tabs():
|
|||
gr.HTML(value="<p>Converts all images in a folder to audio</p>")
|
||||
return ((riffusion_ui, "Riffusion", "riffusion_ui"),)
|
||||
|
||||
def on_beat_finder_test_click(image_dir: str,
                              file_regex: str,
                              rhythm: int,
                              band_start: float,
                              band_length: float,
                              threshold_offset: float,
                              ignore_range: float):
    """Click handler for the beat finder "Test" button.

    Passes every argument straight through to ``test_beat_finder`` and wraps
    the image it returns in a ``gr.update`` for the output component.
    """
    preview = test_beat_finder(
        image_dir,
        file_regex,
        rhythm,
        band_start,
        band_length,
        threshold_offset,
        ignore_range,
    )
    return gr.update(value=preview)
|
||||
|
||||
|
||||
def on_crop_method_change(crop_method):
    """Build the visibility updates emitted when the crop method changes.

    Returns a pair of ``gr.update`` objects: the first is visible only for
    the "Fixed" method, the second only for methods whose name starts with
    "Beat Finder".
    """
    show_fixed = crop_method == "Fixed"
    show_beat_finder = crop_method.startswith("Beat Finder")
    return gr.update(visible=show_fixed), gr.update(visible=show_beat_finder)
|
||||
|
||||
script_callbacks.on_ui_tabs(on_ui_tabs)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue