import importlib

from modules.scripts import PostprocessImageArgs, scripts_postprocessing

from scripts.roop_utils.models_utils import get_models, get_face_checkpoints

from scripts import (roop_globals, roop_logging, faceswap_settings, faceswap_tab)
from scripts.roop_swapping import swapper
from scripts.roop_utils import imgutils
from scripts.roop_utils import models_utils
from scripts.roop_postprocessing import upscaling

# Reload all the modules when using "apply and restart"
importlib.reload(swapper)
importlib.reload(roop_logging)
importlib.reload(roop_globals)
importlib.reload(imgutils)
importlib.reload(upscaling)
importlib.reload(faceswap_settings)
importlib.reload(models_utils)

import base64
import io
import os
from dataclasses import dataclass, fields
from pprint import pformat
from typing import Dict, List, Set, Tuple, Union, Optional

import dill as pickle
import gradio as gr
import modules.scripts as scripts
from modules import script_callbacks, scripts
import torch
from insightface.app.common import Face
from modules import processing, scripts, shared
from modules.images import save_image, image_grid
from modules.processing import (Processed, StableDiffusionProcessing,
                                StableDiffusionProcessingImg2Img)
from modules.shared import cmd_opts, opts, state
from PIL import Image

from scripts.roop_utils.imgutils import (pil_to_cv2, convert_to_sd)

from scripts.roop_logging import logger
from scripts.roop_globals import VERSION_FLAG
from scripts.roop_postprocessing.postprocessing_options import PostProcessingOptions
from scripts.roop_postprocessing.postprocessing import enhance_image

import modules

EXTENSION_PATH = os.path.join("extensions", "sd-webui-roop")


@dataclass
class FaceSwapUnitSettings:
    # The image given in reference
    source_img: Union[Image.Image, str]
    # The checkpoint file
    source_face: str
    # The batch source images
    _batch_files: gr.components.File
    # Will blend faces if True
    blend_faces: bool
    # Enable this unit
    enable: bool
    # Use same gender filtering
    same_gender: bool

    # If True, discard images with low similarity
    check_similarity: bool

    # Minimum similarity against the used face (reference, batch or checkpoint)
    min_sim: float
    # Minimum similarity against the reference (reference or checkpoint if checkpoint is given)
    min_ref_sim: float
    # The face index to use for swapping
    _faces_index: int
    # Swap in the source image in img2img (before processing)
    swap_in_source: bool
    # Swap in the generated image in img2img (always on for txt2img)
    swap_in_generated: bool

    @staticmethod
    def get_unit_configuration(unit: int, components):
        fields_count = len(fields(FaceSwapUnitSettings))
        return FaceSwapUnitSettings(
            *components[unit * fields_count : unit * fields_count + fields_count]
        )

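    # How the flat UI component list maps onto units (illustrative, derived from
    # get_unit_configuration above): with the 12 fields declared in this dataclass,
    # unit 0 is built from components[0:12], unit 1 from components[12:24], and so on.
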
    @property
    def faces_index(self):
        """
        Convert _faces_index from str to int
        """
        faces_index = {
            int(x) for x in self._faces_index.strip(",").split(",") if x.isnumeric()
        }
        if len(faces_index) == 0:
            return {0}

        return faces_index

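    # Illustrative parsing examples for the property above:
    #   "0,2"       -> {0, 2}
    #   "1,"        -> {1}
    #   "" or "a,b" -> {0}   (non-numeric entries are ignored; empty input falls back to face 0)
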
    @property
    def batch_files(self):
        """
        Return empty array instead of None for batch files
        """
        return self._batch_files or []

    @property
    def reference_face(self):
        """
        Extract the reference face (only once; stored for the rest of processing).
        The reference face is the checkpoint, the source image or the first image in the batch, in that order.
        """
        if not hasattr(self, "_reference_face"):
            if self.source_face and self.source_face != "None":
                with open(self.source_face, "rb") as file:
                    try:
                        logger.info(f"loading pickle {file.name}")
                        face = Face(pickle.load(file))
                        self._reference_face = face
                    except Exception as e:
                        logger.error("Failed to load checkpoint: %s", e)
                        # Fall back to None so the attribute always exists after this call
                        self._reference_face = None
            elif self.source_img is not None:
                if isinstance(self.source_img, str):  # source_img is a base64 string
                    if 'base64,' in self.source_img:  # check if the base64 string has a data URL scheme
                        base64_data = self.source_img.split('base64,')[-1]
                        img_bytes = base64.b64decode(base64_data)
                    else:
                        # if no data URL scheme, just decode
                        img_bytes = base64.b64decode(self.source_img)
                    self.source_img = Image.open(io.BytesIO(img_bytes))
                source_img = pil_to_cv2(self.source_img)
                self._reference_face = swapper.get_or_default(swapper.get_faces(source_img), 0, None)
            else:
                logger.error("You need at least one face")
                self._reference_face = None

        return self._reference_face

    @property
    def faces(self):
        """
        Extract all faces (including the reference face) to provide an array of faces.
        Only processed once.
        """
        if self.batch_files is not None and not hasattr(self, "_faces"):
            self._faces = [self.reference_face] if self.reference_face is not None else []
            for file in self.batch_files:
                img = Image.open(file.name)
                face = swapper.get_or_default(swapper.get_faces(pil_to_cv2(img)), 0, None)
                if face is not None:
                    self._faces.append(face)
        return self._faces

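    # Note: only the first detected face of each batch image is kept
    # (get_or_default(..., 0, None)); images where no face is found are skipped silently.
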
    @property
    def blended_faces(self):
        """
        Blend the faces using the mean of all embeddings
        """
        if not hasattr(self, "_blended_faces"):
            self._blended_faces = swapper.blend_faces(self.faces)
        return self._blended_faces


script_callbacks.on_ui_tabs(faceswap_tab.on_ui_tabs)


class FaceSwapScript(scripts.Script):

    @property
    def units_count(self):
        return opts.data.get("roop_units_count", 3)

    @property
    def upscaled_swapper(self):
        return opts.data.get("roop_upscaled_swapper", False)

    @property
    def enabled(self):
        return any([u.enable for u in self.units]) and not shared.state.interrupted

    @property
    def model(self):
        model = opts.data.get("roop_model", None)
        if model is None:
            models = get_models()
            model = models[0] if len(models) else None
        logger.info("Trying to use model: %s", model)
        if model is None or not os.path.isfile(model):
            logger.error("The model %s cannot be found or loaded", model)
            raise FileNotFoundError("No faceswap model found. Please add it to the roop directory.")
        return model

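    # get_models() presumably lists the swap model files bundled with the extension
    # (the roop directory mentioned in the error above); the first one found is used
    # when no explicit "roop_model" option is set.
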
    @property
    def keep_original_images(self):
        return opts.data.get("roop_keep_original", False)

    def title(self):
        return "roop"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def faceswap_unit_ui(self, is_img2img, unit_num=1):
        with gr.Tab(f"Face {unit_num}"):
            with gr.Column():
                gr.Markdown(
                    """Reference is an image. First face will be extracted.
                    First face of batch sources will be extracted and used as input (or blended if blend is activated).""")
                with gr.Row():
                    img = gr.components.Image(type="pil", label="Reference")
                    batch_files = gr.components.File(
                        type="file",
                        file_count="multiple",
                        label="Batch Source Images",
                        optional=True,
                    )
                gr.Markdown(
                    """Face checkpoint built with the checkpoint builder in tools. Will overwrite the reference image.""")
                with gr.Row():
                    face = gr.inputs.Dropdown(
                        choices=get_face_checkpoints(),
                        label="Face Checkpoint (takes precedence over the reference face)",
                    )
                    refresh = gr.Button(value='↻', variant='tool')

                    def refresh_fn(selected):
                        return gr.Dropdown.update(value=selected, choices=get_face_checkpoints())

                    refresh.click(fn=refresh_fn, inputs=face, outputs=face)

                with gr.Row():
                    enable = gr.Checkbox(False, placeholder="enable", label="Enable")
                    same_gender = gr.Checkbox(
                        False, placeholder="Same Gender", label="Same Gender"
                    )
                    blend_faces = gr.Checkbox(
                        True, placeholder="Blend Faces", label="Blend Faces ((Source|Checkpoint)+References = 1)"
                    )
                gr.Markdown("""Discard images with low similarity or no faces:""")
                check_similarity = gr.Checkbox(False, placeholder="discard", label="Check similarity")
                min_sim = gr.Slider(0, 1, 0, step=0.01, label="Min similarity")
                min_ref_sim = gr.Slider(
                    0, 1, 0, step=0.01, label="Min reference similarity"
                )
                faces_index = gr.Textbox(
                    value="0",
                    placeholder="Which face to swap (comma separated), start from 0 (by gender if same_gender is enabled)",
                    label="Comma separated face number(s)",
                )
                gr.Markdown("""Configure swapping. Swapping can occur before img2img, after it, or both:""", visible=is_img2img)
                swap_in_source = gr.Checkbox(
                    False,
                    placeholder="Swap face in source image",
                    label="Swap in source image (must be blended)",
                    visible=is_img2img,
                )
                swap_in_generated = gr.Checkbox(
                    True,
                    placeholder="Swap face in generated image",
                    label="Swap in generated image",
                    visible=is_img2img,
                )
        return [
            img,
            face,
            batch_files,
            blend_faces,
            enable,
            same_gender,
            check_similarity,
            min_sim,
            min_ref_sim,
            faces_index,
            swap_in_source,
            swap_in_generated,
        ]

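    # The list returned above must stay in the declared field order of
    # FaceSwapUnitSettings: get_unit_configuration() re-binds these values
    # positionally from the flat argument list Gradio passes to the script.
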
    def ui(self, is_img2img):
        with gr.Accordion(f"Roop {VERSION_FLAG}", open=False):
            components = []
            for i in range(1, self.units_count + 1):
                components += self.faceswap_unit_ui(is_img2img, i)
            upscaler = faceswap_tab.upscaler_ui()
        return components + upscaler

    def before_process(self, p: StableDiffusionProcessing, *components):
        self.units: List[FaceSwapUnitSettings] = []
        for i in range(0, self.units_count):
            self.units += [FaceSwapUnitSettings.get_unit_configuration(i, components)]

        for i, u in enumerate(self.units):
            logger.debug("%s, %s", pformat(i), pformat(u))

        len_conf: int = len(fields(FaceSwapUnitSettings))
        shift: int = self.units_count * len_conf
        self.postprocess_options = PostProcessingOptions(
            *components[shift : shift + len(fields(PostProcessingOptions))]
        )
        logger.debug("%s", pformat(self.postprocess_options))

        if isinstance(p, StableDiffusionProcessingImg2Img):
            if any([u.enable for u in self.units]):
                init_images = p.init_images
                for i, unit in enumerate(self.units):
                    if unit.enable and unit.swap_in_source:
                        (init_images, result_infos) = self.process_images_unit(unit, init_images)
                        logger.info(f"unit {i+1}> processed init images: {len(init_images)}, {len(result_infos)}")

                p.init_images = init_images

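    # Layout of *components consumed in before_process: the first
    # units_count * len(fields(FaceSwapUnitSettings)) values describe the face units,
    # followed by the values for PostProcessingOptions coming from the upscaler UI.
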
    def postprocess_batch(self, p, *args, **kwargs):
        if self.enabled:
            if self.keep_original_images:
                batch_index = kwargs.pop('batch_number', 0)
                torch_images: torch.Tensor = kwargs["images"]
                pil_images = imgutils.torch_to_pil(torch_images)
                self._orig_images = pil_images
                for img in pil_images:
                    if p.outpath_samples and opts.samples_save:
                        save_image(img, p.outpath_samples, "", p.seeds[batch_index], p.prompts[batch_index], opts.samples_format, p=p, suffix="-before-swap")

        return

    def process_image_unit(self, unit: FaceSwapUnitSettings, image, info=None) -> Tuple[Optional[Image.Image], Optional[str]]:
        if unit.enable:
            if convert_to_sd(image):
                return (image, info)
            if not unit.blend_faces:
                src_faces = unit.faces
                logger.info(f"will generate {len(src_faces)} images")
            else:
                logger.info("blend all faces together")
                src_faces = [unit.blended_faces]
            for i, src_face in enumerate(src_faces):
                logger.info(f"Process face {i}")
                result: swapper.ImageResult = swapper.swap_face(
                    unit.reference_face if unit.reference_face is not None else src_face,
                    src_face,
                    image,
                    faces_index=unit.faces_index,
                    model=self.model,
                    same_gender=unit.same_gender,
                    upscaled_swapper=self.upscaled_swapper
                )
                # Accept the swap if the similarity check is disabled, or if every computed
                # similarity is non-zero and meets the configured thresholds.
                similarity_ok = (
                    result.similarity
                    and all(x != 0 and x >= unit.min_sim for x in result.similarity.values())
                    and all(x != 0 and x >= unit.min_ref_sim for x in result.ref_similarity.values())
                )
                if not unit.check_similarity or similarity_ok:
                    return (result.image, f"{info}, similarity = {result.similarity}, ref_similarity = {result.ref_similarity}")
                else:
                    logger.warning(
                        f"skip, similarity too low, sim = {result.similarity} (target {unit.min_sim}) ref sim = {result.ref_similarity} (target = {unit.min_ref_sim})"
                    )
        return (None, None)

    def process_images_unit(self, unit: FaceSwapUnitSettings, images: List[Image.Image], infos=None) -> Tuple[List[Image.Image], List[str]]:
        if unit.enable:
            result_images: List[Image.Image] = []
            result_infos: List[str] = []
            if not infos:
                infos = [None] * len(images)
            for i, (img, info) in enumerate(zip(images, infos)):
                (result_image, result_info) = self.process_image_unit(unit, img, info)
                if result_image is not None and result_info is not None:
                    result_images.append(result_image)
                    result_infos.append(result_info)
            logger.info(f"{len(result_images)} images processed")
            return (result_images, result_infos)
        return (images, infos)

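    # Note: when the similarity check rejects a swap, process_image_unit returns
    # (None, None) and that image is dropped from the returned batch, so the output
    # list may be shorter than the input list.
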
    def postprocess_image(self, p, script_pp: PostprocessImageArgs, *args):
        if self.enabled:
            img: Image.Image = script_pp.image
            infos = ""
            if any([u.enable for u in self.units]):
                for i, unit in enumerate(self.units):
                    if unit.enable:
                        img, info = self.process_image_unit(image=img, unit=unit, info="")
                        logger.info(f"unit {i+1}> processed")
                        infos += info or ""
                        if img is None:
                            logger.error("Failed to process image - Switch back to original image")
                            img = script_pp.image
            try:
                if self.postprocess_options is not None:
                    img = enhance_image(img, self.postprocess_options)
            except Exception as e:
                logger.error("Failed to upscale : %s", e)

            pp = scripts_postprocessing.PostprocessedImage(img)
            pp.info = {"face.similarity": infos}
            p.extra_generation_params.update(pp.info)
            script_pp.image = pp.image

    def postprocess(self, p: StableDiffusionProcessing, processed: Processed, *args):
        if self.enabled:
            images = processed.images[processed.index_of_first_image:]
            for i, img in enumerate(images):
                images[i] = processing.apply_overlay(img, p.paste_to, i % p.batch_size, p.overlay_images)

            processed.images = images

            if self.keep_original_images:
                if len(self._orig_images) > 1:
                    processed.images.append(image_grid(self._orig_images))
                processed.images += self._orig_images
                processed.infotexts += processed.infotexts  # duplicate infotexts