sd-webui-riffusion/scripts/riffusion.py

654 lines
21 KiB
Python

# All the audio related code was shamelessly copied from the original author(s):
# https://github.com/hmartiro/riffusion-inference
import io
import typing as T
import os
import numpy as np
from PIL import Image, ImageDraw
from scipy.io import wavfile
import torch
import torchaudio
import gradio as gr
from modules import scripts, script_callbacks
from modules.images import FilenameGenerator
from modules.processing import process_images
import os
import modules.shared as shared
from pedalboard.io import AudioFile
import glob
from datetime import datetime
import wave
import platform
from statistics import mean, median
# Root directory of this extension as reported by WebUI's scripts module;
# the default audio output folder ("outputs") is created under it.
base_dir = scripts.basedir()
# Number of inline audio players built in the txt2img/img2img UI; at most
# this many files from the last batch can be shown inline.
MAX_BATCH_SIZE = 8
class RiffusionScript(scripts.Script):
    """WebUI script that turns generated spectrogram images into audio.

    When enabled, every image produced by ``process_images`` is decoded
    from a mel spectrogram back into a waveform (Griffin-Lim), written as a
    WAV file and then transcoded to MP3 via pedalboard. The MP3 paths and
    labels of the most recent run are kept in class attributes so the
    inline audio players in the UI can be refreshed.
    """

    # MP3 paths produced by the most recent run (class-level, shared by all
    # instances).
    last_generated_files = []
    # Display labels for ``last_generated_files``, index-for-index.
    last_generated_labels = []

    def title(self):
        """Return the name shown in WebUI's script dropdown."""
        return "Riffusion Audio Generator"

    def process_wav(self, wav_file, preserve_wav=False):
        """Transcode ``wav_file`` to an MP3 stored next to it.

        The MP3 path is ``wav_file`` with its extension replaced by ``.mp3``
        and is appended to ``last_generated_files`` once it has been
        written. The WAV file is deleted unless ``preserve_wav`` is true.
        """
        with AudioFile(wav_file) as f:
            audio = f.read(f.frames)
            samplerate = f.samplerate
        # Only swap the extension; str.replace(".wav", ...) would also rewrite
        # ".wav" occurring inside directory names.
        filename = os.path.splitext(wav_file)[0] + ".mp3"
        # audio.shape[0] is passed as the channel count (assumes pedalboard
        # returns (channels, frames) arrays — per pedalboard docs).
        with AudioFile(filename, "w", samplerate, audio.shape[0]) as f:
            f.write(audio)
        # Record the file only after it actually exists.
        RiffusionScript.last_generated_files.append(filename)
        if not preserve_wav:
            os.remove(wav_file)

    def ui(self, is_img2img):
        """Build the script's controls.

        Returns the components whose values WebUI passes to ``run`` in order:
        enable checkbox, preserve-WAV checkbox, output path textbox, the
        refresh button and ``MAX_BATCH_SIZE`` audio players.
        """
        path = os.path.join(base_dir, "outputs")
        with gr.Row():
            riffusion_enabled = gr.Checkbox(label="Riffusion enabled", value=True)
            save_wav = gr.Checkbox(label="Preserve Original WAV", value=False)
        output_path = gr.Textbox(label="Output path", value=path)

        def update_audio_players():
            """Return exactly one update per player: show the last batch, hide the rest."""
            # Only MAX_BATCH_SIZE players exist, so a larger batch is truncated
            # to keep the number of updates equal to the number of outputs.
            # zip() also guards against files/labels getting out of step.
            shown = list(
                zip(
                    RiffusionScript.last_generated_files,
                    RiffusionScript.last_generated_labels,
                )
            )[:MAX_BATCH_SIZE]
            updates = [
                gr.Audio.update(value=audio_file, visible=True, label=label)
                for audio_file, label in shown
            ]
            # pad with empty updates
            updates.extend(
                gr.Audio.update(value=None, visible=False)
                for _ in range(len(shown), MAX_BATCH_SIZE)
            )
            return updates

        # create MAX_BATCH_SIZE audio players, and hide the unnecessary ones
        audio_players = [
            gr.Audio(
                label=f"Audio Player {i}",
                visible=False,
                value=None,
                interactive=False,
            )
            for i in range(MAX_BATCH_SIZE)
        ]

        show_audio_button = gr.Button(
            "Refresh Inline Audio (Last Batch)",
            variant="primary",
        )
        show_audio_button.click(
            fn=update_audio_players,
            inputs=[],
            outputs=audio_players,
        )

        hide_audio_button = gr.Button("Hide Inline Audio")
        hide_audio_button.click(
            fn=lambda: [
                gr.Audio.update(value=None, visible=False)
                for _ in range(MAX_BATCH_SIZE)
            ],
            inputs=[],
            outputs=audio_players,
        )

        return [
            riffusion_enabled,
            save_wav,
            output_path,
            show_audio_button,
            *audio_players,
        ]

    def play_input_as_sound(self):
        """Placeholder; intentionally does nothing."""
        pass

    def run(self, p, riffusion_enabled, save_wav, output_path, btn, *audio_players):
        """Run normal image generation and, if enabled, convert each image to MP3.

        Files are named from WebUI's filename pattern
        ``[job_timestamp]-[seed]-[prompt_spaces]-<index>`` inside
        ``output_path``. The class-level ``last_generated_*`` lists are reset
        and refilled for the inline players. ``btn`` and ``audio_players``
        are the UI components' values and are not used here.
        """
        if not riffusion_enabled:
            return process_images(p)

        print("Generating Riffusion mp3")
        proc = process_images(p)
        RiffusionScript.last_generated_labels = []
        RiffusionScript.last_generated_files = []
        # create the output directory if it doesn't exist yet
        os.makedirs(output_path, exist_ok=True)

        for i, image in enumerate(proc.images):
            wav_bytes, _duration_s = self.wav_bytes_from_spectrogram_image(image)
            namegen = FilenameGenerator(p, p.seed, p.prompt, image)
            name = namegen.apply(f"[job_timestamp]-[seed]-[prompt_spaces]-{i}")
            filename = os.path.join(output_path, f"{name}.wav")
            with open(filename, "wb") as f:
                f.write(wav_bytes.getbuffer())
            self.process_wav(filename, preserve_wav=save_wav)
            RiffusionScript.last_generated_labels.append(
                namegen.apply(f"[seed]-[prompt_spaces]-{i}")
            )
        return proc

    def wav_bytes_from_spectrogram_image(
        self,
        image: Image.Image,
    ) -> T.Tuple[io.BytesIO, float]:
        """
        Reconstruct a WAV audio clip from a spectrogram image. Also returns the duration in seconds.

        The image height is used as the number of mel bins. The returned
        buffer holds a 44.1 kHz int16 WAV and is positioned at offset 0.
        """
        max_volume = 50
        power_for_image = 0.25
        Sxx = self.spectrogram_from_image(
            image, max_volume=max_volume, power_for_image=power_for_image
        )

        sample_rate = 44100  # [Hz]
        clip_duration_ms = 5000  # [ms]

        bins_per_image = image.height
        n_mels = image.height

        # FFT parameters
        window_duration_ms = 100  # [ms]
        padded_duration_ms = 400  # [ms]
        step_size_ms = 10  # [ms]

        # Derived parameters. clip_duration_ms is in milliseconds, so convert
        # to seconds before multiplying by the sample rate.
        num_samples = int(
            image.width / float(bins_per_image) * clip_duration_ms / 1000.0 * sample_rate
        )
        n_fft = int(padded_duration_ms / 1000.0 * sample_rate)
        hop_length = int(step_size_ms / 1000.0 * sample_rate)
        win_length = int(window_duration_ms / 1000.0 * sample_rate)

        samples = self.waveform_from_spectrogram(
            Sxx=Sxx,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            num_samples=num_samples,
            sample_rate=sample_rate,
            mel_scale=True,
            n_mels=n_mels,
            max_mel_iters=200,
            num_griffin_lim_iters=32,
        )

        # Saturate before the int16 cast: out-of-range floats would otherwise
        # wrap around (platform-dependent) and produce loud glitches.
        int16_info = np.iinfo(np.int16)
        pcm = np.clip(samples, int16_info.min, int16_info.max).astype(np.int16)

        wav_bytes = io.BytesIO()
        wavfile.write(wav_bytes, sample_rate, pcm)
        wav_bytes.seek(0)

        duration_s = float(len(samples)) / sample_rate

        return wav_bytes, duration_s

    def spectrogram_from_image(
        self, image: Image.Image, max_volume: float = 50, power_for_image: float = 0.25
    ) -> np.ndarray:
        """
        Compute a spectrogram magnitude array from a spectrogram image.

        Only the first (red) channel is used: it is flipped vertically so row
        0 is the lowest frequency, inverted (darker pixels are louder),
        rescaled to ``[0, max_volume]`` and raised to ``1 / power_for_image``.

        TODO(hayk): Add image_from_spectrogram and call this out as the reverse.
        """
        # Single-channel/palette images would give a 2-D array; normalise to
        # RGB so the channel indexing below works (channel 0 of an RGB/RGBA
        # image is unchanged by this).
        if image.mode not in ("RGB", "RGBA"):
            image = image.convert("RGB")

        # Convert to a numpy array of floats
        data = np.array(image).astype(np.float32)

        # Flip Y and take a single channel
        data = data[::-1, :, 0]

        # Invert
        data = 255 - data

        # Rescale to max volume
        data = data * max_volume / 255

        # Reverse the power curve
        data = np.power(data, 1 / power_for_image)

        return data

    def spectrogram_from_waveform(
        self,
        waveform: np.ndarray,
        sample_rate: int,
        n_fft: int,
        hop_length: int,
        win_length: int,
        mel_scale: bool = True,
        n_mels: int = 512,
    ) -> np.ndarray:
        """
        Compute a magnitude spectrogram from a waveform.

        With ``mel_scale`` the magnitudes are projected onto ``n_mels`` HTK
        mel bins between 0 Hz and 10 kHz.
        """
        spectrogram_func = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            power=None,
            hop_length=hop_length,
            win_length=win_length,
        )

        waveform_tensor = torch.from_numpy(waveform.astype(np.float32)).reshape(1, -1)
        Sxx_complex = spectrogram_func(waveform_tensor).numpy()[0]

        Sxx_mag = np.abs(Sxx_complex)

        if mel_scale:
            mel_scaler = torchaudio.transforms.MelScale(
                n_mels=n_mels,
                sample_rate=sample_rate,
                f_min=0,
                f_max=10000,
                n_stft=n_fft // 2 + 1,
                norm=None,
                mel_scale="htk",
            )

            Sxx_mag = mel_scaler(torch.from_numpy(Sxx_mag)).numpy()

        return Sxx_mag

    def waveform_from_spectrogram(
        self,
        Sxx: np.ndarray,
        n_fft: int,
        hop_length: int,
        win_length: int,
        num_samples: int,
        sample_rate: int,
        mel_scale: bool = True,
        n_mels: int = 512,
        max_mel_iters: int = 200,
        num_griffin_lim_iters: int = 32,
        # Evaluated once at class-definition time: CPU on macOS (as before)
        # or when CUDA is not available, otherwise the first CUDA device.
        device: str = (
            "cpu"
            if platform.system() == "Darwin" or not torch.cuda.is_available()
            else "cuda:0"
        ),
    ) -> np.ndarray:
        """
        Reconstruct a waveform from a spectrogram.

        This is an approximate inverse of spectrogram_from_waveform, using the Griffin-Lim algorithm
        to approximate the phase. ``num_samples`` is currently not used.
        The result is returned as a numpy array on the CPU.
        """
        Sxx_torch = torch.from_numpy(Sxx).to(device)

        # TODO(hayk): Make this a class that caches the two things

        if mel_scale:
            mel_inv_scaler = torchaudio.transforms.InverseMelScale(
                n_mels=n_mels,
                sample_rate=sample_rate,
                f_min=0,
                f_max=10000,
                n_stft=n_fft // 2 + 1,
                norm=None,
                mel_scale="htk",
                max_iter=max_mel_iters,
            ).to(device)

            Sxx_torch = mel_inv_scaler(Sxx_torch)

        griffin_lim = torchaudio.transforms.GriffinLim(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            power=1.0,
            n_iter=num_griffin_lim_iters,
        ).to(device)

        waveform = griffin_lim(Sxx_torch).cpu().numpy()

        return waveform
def convert_audio_file(image, output_dir, crop_width = None):
    """Convert the spectrogram image at path ``image`` to a WAV in ``output_dir``.

    ``crop_width`` optionally keeps only the leftmost columns (see
    ``convert_audio_image``). Returns the path of the written WAV file.
    """
    # Close the underlying file handle once the conversion is done.
    with Image.open(image) as image_file:
        return convert_audio_image(image, image_file, output_dir, crop_width)
def convert_audio_image(image, image_file, output_dir, crop_width = None):
    """Decode an already-opened spectrogram image and save it as a WAV file.

    ``image`` is the source path and is only used to derive the output name
    (``<stem>.wav`` inside ``output_dir``); ``image_file`` is the opened
    image. If ``crop_width`` is given and narrower than the image, only the
    leftmost ``crop_width`` columns are converted. Returns the WAV path.
    """
    should_crop = crop_width is not None and crop_width < image_file.width
    if should_crop:
        image_file = image_file.crop((0, 0, crop_width, image_file.height))

    stem, _ext = os.path.splitext(os.path.basename(image))
    wav_path = os.path.join(output_dir, stem + ".wav")

    converter = RiffusionScript()
    wav_bytes, _duration = converter.wav_bytes_from_spectrogram_image(image_file)
    with open(wav_path, "wb") as out:
        out.write(wav_bytes.getbuffer())
    return wav_path
def convert_audio(
    image_dir: str,
    file_regex: str,
    join_images: bool,
    crop_method: str,
    crop_width: int,
    rhythm: int,
    band_start: float,
    band_length: float,
    threshold_offset: float,
    ignore_range: float,
) -> None:
    """Convert every image in ``image_dir`` matching ``file_regex`` to a WAV.

    ``file_regex`` is a comma-separated list of GLOB patterns. Depending on
    ``crop_method`` each image is converted whole ("None"), cropped to
    ``crop_width`` ("Fixed"), or cropped to the width found by
    ``find_cutoff`` — for the first image only ("... (Once)") or for every
    image. WAV files are written into ``image_dir``; when ``join_images`` is
    set and more than one file was produced, they are also concatenated
    into ``<timestamp>_joined.wav`` using the first file's WAV parameters.
    """
    images = get_image(file_regex, image_dir)
    print(f"Found {len(images)} images in {image_dir}, pattern {file_regex}")
    # A cleared dropdown may deliver None; treat it like "None".
    method = crop_method or "None"
    output_files = []
    width = None
    for i, image in enumerate(images):
        # Close each image file after converting it instead of leaking handles.
        with Image.open(image) as image_file:
            if method == "Fixed":
                width = crop_width
            elif method.startswith("Beat Finder") and (not method.endswith("(Once)") or i == 0):
                width = find_cutoff(
                    image_file, rhythm, band_start, band_length, threshold_offset, ignore_range
                )["cutoff"]
                print("Cutoff found at:", width)
            # In "(Once)" mode the width found for the first image is reused.
            output_files.append(convert_audio_image(image, image_file, image_dir, width))

    if join_images and len(output_files) > 1:
        outfile = os.path.join(
            image_dir,
            f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_joined.wav",
        )
        # Stream each file into the output rather than buffering them all.
        with wave.open(outfile, "wb") as output:
            for index, wav in enumerate(output_files):
                with wave.open(wav, "rb") as source:
                    if index == 0:
                        output.setparams(source.getparams())
                    output.writeframes(source.readframes(source.getnframes()))
    print(f"Converted {len(images)} images to audio")
def test_beat_finder(image_dir: str,
                     file_regex: str,
                     rhythm: int,
                     band_start: float,
                     band_length: float,
                     threshold_offset: float,
                     ignore_range: float) -> T.Optional[Image.Image]:
    """Render a beat-finder diagnostic for the first matching image.

    The returned RGB image is the spectrogram with a 256 px white strip
    below it holding a bar chart of the per-column band levels. Overlays:
    the analysed band (translucent blue), the threshold (green line), each
    detected beat (red lines) and the chosen cutoff (orange line, if any).
    Returns ``None`` when no file matches ``file_regex``.
    """
    images = get_image(file_regex, image_dir)
    if not images:
        return None
    # convert() returns an independent copy, so the source file can be closed.
    with Image.open(images[0]) as source:
        image_file = source.convert("RGB")

    output = Image.new(mode="RGB", size=(image_file.width, image_file.height + 256), color=(255, 255, 255))
    output.paste(image_file)
    beat_finder_result = find_cutoff(image_file, rhythm, band_start, band_length, threshold_offset, ignore_range)
    draw = ImageDraw.Draw(output, "RGBA")
    register_rect = [(0, beat_finder_result["register_start"]), (image_file.width, beat_finder_result["register_end"])]

    level_range = beat_finder_result["level_range"]
    if level_range > 0:
        max_level = beat_finder_result["max_level"]
        # Map levels onto the bottom strip; lower (darker) levels get taller bars.
        scale = 255 / level_range
        for x, value in enumerate(beat_finder_result["one_line"]):
            bar_height = (max_level - value) * scale
            draw.line([(x, output.height), (x, output.height - bar_height)], fill=(0, 0, 0))
        scaled_threshold = (max_level - beat_finder_result["threshold"]) * scale
        draw.line([(0, output.height - scaled_threshold), (output.width, output.height - scaled_threshold)], fill=(0, 255, 0))

    draw.rectangle(register_rect, fill=(0, 0, 255, 64))
    for beat in beat_finder_result["above_threshold"]:
        draw.line([(beat, 0), (beat, output.height)], fill=(255, 0, 0), width=2)
    cutoff = beat_finder_result["cutoff"]
    if cutoff is not None:
        draw.line([(cutoff, 0), (cutoff, output.height)], fill=(255, 165, 0), width=2)
    del draw
    return output
def get_image(file_regex, image_dir):
    """Return the sorted, de-duplicated files in ``image_dir`` matching any pattern.

    ``file_regex`` is a comma-separated list of GLOB patterns (not regular
    expressions), e.g. ``"*.jpg, *.png"``. Blank entries (such as one left
    by a trailing comma) are ignored — otherwise the directory itself would
    be "matched". Files matched by several patterns are returned once, and
    directories are skipped.
    """
    found = set()
    for pattern in (part.strip() for part in file_regex.split(",")):
        if not pattern:
            continue
        found.update(
            path
            for path in glob.glob(os.path.join(image_dir, pattern))
            if os.path.isfile(path)
        )
    return sorted(found)
def find_cutoff(image, rhythm = 4, band_start = 0.25, band_length = 0.75, threshold_offset = 0.75, ignore_range = 0.05):
    """Estimate a crop width covering a whole number of ``rhythm``-beat groups.

    A horizontal band of the image (``band_start`` and ``band_length`` are
    fractions of the image height) is averaged down to one gray level per
    column. Columns darker than ``min + range * threshold_offset`` count as
    beats (in these spectrograms darker pixels are louder). A new beat is
    only accepted when it is more than ``ignore_range * image.width``
    columns after the previous one. The cutoff is the largest multiple of
    ``rhythm`` not exceeding the number of beats, times the median spacing
    between beats.

    Returns a dict that always contains ``"cutoff"`` (``None`` when no
    usable cutoff was found) plus the intermediate values used by the
    beat-finder preview: ``register_start``, ``register_end``, ``one_line``,
    ``min_level``, ``max_level``, ``level_range``, ``threshold``,
    ``above_threshold`` and, when they could be computed, ``distances``,
    ``distance`` and ``beat_count``.
    """
    result = {"cutoff": None}
    ignore_distance = image.width * ignore_range

    height = image.height
    register_start = int(height * band_start)
    register_length = int(height * band_length)
    # Keep the band inside the image and at least one row tall: a zero-height
    # band leaves nothing to average, and rows past the bottom edge would be
    # padding rather than image data.
    register_start = min(max(register_start, 0), max(height - 1, 0))
    register_end = min(register_start + max(register_length, 1), height)
    result["register_start"] = register_start
    result["register_end"] = register_end

    band = image.crop((0, register_start, image.width, register_end))
    gray_image = band.convert('L')
    # Squash the band to a single row: one average gray level per column.
    one_line = list(gray_image.resize((gray_image.width, 1)).getdata())
    result["one_line"] = one_line

    min_level = min(one_line)
    max_level = max(one_line)
    level_range = max_level - min_level
    result["min_level"] = min_level
    result["max_level"] = max_level
    result["level_range"] = level_range

    threshold = min_level + level_range * threshold_offset
    result["threshold"] = threshold

    # NOTE: despite the key name, these are columns *below* the gray-level
    # threshold (i.e. louder); the key is kept for the preview code.
    above_threshold = []
    for i, value in enumerate(one_line):
        if value < threshold and (not above_threshold or i - above_threshold[-1] > ignore_distance):
            above_threshold.append(i)
    result["above_threshold"] = above_threshold
    if not above_threshold:
        print("Failed to find beats")
        return result
    print("Beats found:", above_threshold)

    distances = [later - earlier for earlier, later in zip(above_threshold, above_threshold[1:])]
    if not distances:
        return result
    result["distances"] = distances

    distance = median(distances)
    result["distance"] = distance
    print("Interval:", distance)

    # Guard against a zero/negative rhythm from the UI (division by zero).
    if rhythm is None or rhythm <= 0:
        print("Rhythm must be a positive number")
        return result

    beat_count = int(len(above_threshold) / rhythm) * rhythm
    result["beat_count"] = beat_count
    if beat_count == 0:
        print("Missmatching rhythm")
        return result

    result["cutoff"] = int(beat_count * distance)
    return result
def on_ui_tabs():
    """Build the standalone "Riffusion" tab.

    The left panel collects the image folder, GLOB patterns and cropping
    options (fixed width, or the beat finder with a "Test" preview); the
    right panel holds the "Convert Folder" button that runs
    ``convert_audio`` with those settings.

    Returns a one-element tuple of ``(blocks, tab title, element id)``,
    which is handed to WebUI's ``on_ui_tabs`` callback registry.
    """
    with gr.Blocks() as riffusion_ui:
        with gr.Row():
            with gr.Column(variant="panel"):
                with gr.Row():
                    image_directory = gr.Textbox(
                        label="Image Directory",
                        placeholder="Directory containing your image files",
                        value="",
                        interactive=True,
                    )
                with gr.Row():
                    join_images = gr.Checkbox(
                        label="Also output single joined audio file (will be named <date>_joined.wav)",
                        value=True,
                        interactive=True,
                    )
                with gr.Row():
                    file_regex = gr.Textbox(
                        label="GLOB patterns (comma separated)",
                        value="*.jpg, *.png",
                        interactive=True,
                    )
                    crop_method = gr.Dropdown(
                        label="Crop method",
                        choices=["None", "Fixed", "Beat Finder (Once)", "Beat Finder (Every)"],
                        value="None",
                        interactive=True,
                        allow_custom_value=False
                    )
                # Only the block matching the selected crop method is shown
                # (see on_crop_method_change).
                with gr.Column(visible=False) as fixed_block:
                    crop_width = gr.Number(
                        label="Fixed width",
                        value=512,
                        precision=0,
                        interactive=True,
                    )
                with gr.Column(visible=False) as beat_finder_block:
                    # Gradio's Slider bounds are `minimum`/`maximum`; `min`/`max`
                    # are not Slider parameters and leave the default 0..100 range.
                    with gr.Row():
                        rhythm = gr.Number(
                            label="Rhythm",
                            value=4,
                            precision=0,
                            interactive=True,
                        )
                        band_start = gr.Slider(
                            label="Band start",
                            minimum=0,
                            maximum=0.99,
                            step=0.05,
                            value=0.25,
                            interactive=True,
                        )
                        band_length = gr.Slider(
                            label="Band length",
                            minimum=0.01,
                            maximum=1,
                            step=0.05,
                            value=0.5,
                            interactive=True,
                        )
                    with gr.Row():
                        threshold_offset = gr.Slider(
                            label="Threshold",
                            minimum=0.01,
                            maximum=1,
                            step=0.05,
                            value=0.1,
                            interactive=True,
                        )
                        ignore_range = gr.Slider(
                            label="Ignore range",
                            minimum=0,
                            maximum=1,
                            step=0.05,
                            value=0.1,
                            interactive=True,
                        )
                    with gr.Column():
                        test_beat_finder_button = gr.Button(
                            "Test", variant="primary"
                        )
                        beat_finder_image = gr.Image(type="pil")
                    test_beat_finder_button.click(
                        on_beat_finder_test_click,
                        inputs=[
                            image_directory,
                            file_regex,
                            rhythm,
                            band_start,
                            band_length,
                            threshold_offset,
                            ignore_range
                        ],
                        outputs=[beat_finder_image],
                    )
                crop_method.change(on_crop_method_change, crop_method, [fixed_block, beat_finder_block])
            with gr.Column(variant="panel"):
                with gr.Row():
                    convert_folder_btn = gr.Button(
                        "Convert Folder", variant="primary"
                    )
                    convert_folder_btn.click(
                        convert_audio,
                        inputs=[
                            image_directory,
                            file_regex,
                            join_images,
                            crop_method,
                            crop_width,
                            rhythm,
                            band_start,
                            band_length,
                            threshold_offset,
                            ignore_range
                        ],
                        outputs=[],
                    )
                gr.HTML(value="<p>Converts all images in a folder to audio</p>")
    return ((riffusion_ui, "Riffusion", "riffusion_ui"),)
def on_beat_finder_test_click(image_dir: str,
                              file_regex: str,
                              rhythm: int,
                              band_start: float,
                              band_length: float,
                              threshold_offset: float,
                              ignore_range: float):
    """Gradio handler for the beat-finder "Test" button.

    Builds the diagnostic preview via ``test_beat_finder`` and wraps it in
    an update for the output image component; the value is ``None`` when no
    file matched the patterns.
    """
    preview = test_beat_finder(
        image_dir,
        file_regex,
        rhythm,
        band_start,
        band_length,
        threshold_offset,
        ignore_range,
    )
    return gr.update(value=preview)
def on_crop_method_change(crop_method):
    """Show only the settings panel that belongs to the selected crop method.

    Returns two visibility updates, in order: the fixed-width block (shown
    for "Fixed") and the beat-finder block (shown for any "Beat Finder ..."
    option).
    """
    show_fixed = crop_method == "Fixed"
    show_beat_finder = crop_method.startswith("Beat Finder")
    return gr.update(visible=show_fixed), gr.update(visible=show_beat_finder)
# Register the "Riffusion" tab (folder conversion and beat finder) with WebUI
# when this script module is loaded.
script_callbacks.on_ui_tabs(on_ui_tabs)