clean code, fix extract
parent
00d9cc6f62
commit
a511214aaa
|
|
@ -0,0 +1,4 @@
|
|||
#!/bin/bash
|
||||
autoflake --in-place --remove-unused-variables -r --remove-all-unused-imports .
|
||||
mypy --install-types
|
||||
pre-commit run --all-files
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
import requests
|
||||
from PIL import Image
|
||||
from client_utils import (
|
||||
FaceSwapRequest,
|
||||
FaceSwapUnit,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ model_name = os.path.basename(model_url)
|
|||
model_path = os.path.join(models_dir, model_name)
|
||||
|
||||
|
||||
def download(url, path):
|
||||
def download(url: str, path: str) -> None:
|
||||
request = urllib.request.urlopen(url)
|
||||
total = int(request.headers.get("Content-Length", 0))
|
||||
with tqdm(
|
||||
|
|
|
|||
3
mypy.ini
3
mypy.ini
|
|
@ -4,4 +4,5 @@ disallow_any_generics = True
|
|||
disallow_untyped_calls = True
|
||||
disallow_untyped_defs = True
|
||||
ignore_missing_imports = True
|
||||
strict_optional = False
|
||||
strict_optional = False
|
||||
explicit_package_bases=True
|
||||
|
|
@ -1,4 +1,7 @@
|
|||
def preload(parser):
|
||||
from argparse import ArgumentParser
|
||||
|
||||
|
||||
def preload(parser: ArgumentParser) -> None:
|
||||
parser.add_argument(
|
||||
"--faceswaplab_loglevel",
|
||||
default="INFO",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from scripts.faceswaplab_settings import faceswaplab_settings
|
|||
from scripts.faceswaplab_ui import faceswaplab_tab, faceswaplab_unit_ui
|
||||
from scripts.faceswaplab_utils.models_utils import (
|
||||
get_current_model,
|
||||
get_face_checkpoints,
|
||||
)
|
||||
|
||||
from scripts import faceswaplab_globals
|
||||
|
|
@ -12,7 +11,6 @@ from scripts.faceswaplab_swapping import swapper
|
|||
from scripts.faceswaplab_utils import faceswaplab_logging, imgutils
|
||||
from scripts.faceswaplab_utils import models_utils
|
||||
from scripts.faceswaplab_postprocessing import upscaling
|
||||
import numpy as np
|
||||
|
||||
# Reload all the modules when using "apply and restart"
|
||||
# This is mainly done for development purposes
|
||||
|
|
@ -29,15 +27,13 @@ importlib.reload(faceswaplab_api)
|
|||
import os
|
||||
from dataclasses import fields
|
||||
from pprint import pformat
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import dill as pickle
|
||||
import gradio as gr
|
||||
import modules.scripts as scripts
|
||||
from modules import script_callbacks, scripts
|
||||
from insightface.app.common import Face
|
||||
from modules import scripts, shared
|
||||
from modules.images import save_image, image_grid
|
||||
from modules.images import save_image
|
||||
from modules.processing import (
|
||||
Processed,
|
||||
StableDiffusionProcessing,
|
||||
|
|
@ -46,7 +42,6 @@ from modules.processing import (
|
|||
from modules.shared import opts
|
||||
from PIL import Image
|
||||
|
||||
from scripts.faceswaplab_utils.imgutils import pil_to_cv2, check_against_nsfw
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger, save_img_debug
|
||||
from scripts.faceswaplab_globals import VERSION_FLAG
|
||||
from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||
|
|
@ -76,15 +71,15 @@ class FaceSwapScript(scripts.Script):
|
|||
super().__init__()
|
||||
|
||||
@property
|
||||
def units_count(self):
|
||||
def units_count(self) -> int:
|
||||
return opts.data.get("faceswaplab_units_count", 3)
|
||||
|
||||
@property
|
||||
def upscaled_swapper_in_generated(self):
|
||||
def upscaled_swapper_in_generated(self) -> bool:
|
||||
return opts.data.get("faceswaplab_upscaled_swapper", False)
|
||||
|
||||
@property
|
||||
def upscaled_swapper_in_source(self):
|
||||
def upscaled_swapper_in_source(self) -> bool:
|
||||
return opts.data.get("faceswaplab_upscaled_swapper_in_source", False)
|
||||
|
||||
@property
|
||||
|
|
@ -93,24 +88,24 @@ class FaceSwapScript(scripts.Script):
|
|||
return any([u.enable for u in self.units]) and not shared.state.interrupted
|
||||
|
||||
@property
|
||||
def keep_original_images(self):
|
||||
def keep_original_images(self) -> bool:
|
||||
return opts.data.get("faceswaplab_keep_original", False)
|
||||
|
||||
@property
|
||||
def swap_in_generated_units(self):
|
||||
def swap_in_generated_units(self) -> List[FaceSwapUnitSettings]:
|
||||
return [u for u in self.units if u.swap_in_generated and u.enable]
|
||||
|
||||
@property
|
||||
def swap_in_source_units(self):
|
||||
def swap_in_source_units(self) -> List[FaceSwapUnitSettings]:
|
||||
return [u for u in self.units if u.swap_in_source and u.enable]
|
||||
|
||||
def title(self):
|
||||
def title(self) -> str:
|
||||
return f"faceswaplab"
|
||||
|
||||
def show(self, is_img2img):
|
||||
def show(self, is_img2img: bool) -> bool:
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
def ui(self, is_img2img):
|
||||
def ui(self, is_img2img: bool) -> List[gr.components.Component]:
|
||||
with gr.Accordion(f"FaceSwapLab {VERSION_FLAG}", open=False):
|
||||
components = []
|
||||
for i in range(1, self.units_count + 1):
|
||||
|
|
@ -119,16 +114,9 @@ class FaceSwapScript(scripts.Script):
|
|||
# If the order is modified, the before_process should be changed accordingly.
|
||||
return components + upscaler
|
||||
|
||||
# def make_script_first(self,p: StableDiffusionProcessing) :
|
||||
# FIXME : not really useful, will only impact postprocessing (kept for further testing)
|
||||
# runner : scripts.ScriptRunner = p.scripts
|
||||
# alwayson = runner.alwayson_scripts
|
||||
# alwayson.pop(alwayson.index(self))
|
||||
# alwayson.insert(0, self)
|
||||
# print("Running in ", alwayson.index(self), "position")
|
||||
# logger.info("Running scripts : %s", pformat(runner.alwayson_scripts))
|
||||
|
||||
def read_config(self, p: StableDiffusionProcessing, *components):
|
||||
def read_config(
|
||||
self, p: StableDiffusionProcessing, *components: List[gr.components.Component]
|
||||
) -> None:
|
||||
# The order of processing for the components is important
|
||||
# The method first process faceswap units then postprocessing units
|
||||
|
||||
|
|
@ -148,14 +136,16 @@ class FaceSwapScript(scripts.Script):
|
|||
len_conf: int = len(fields(FaceSwapUnitSettings))
|
||||
shift: int = self.units_count * len_conf
|
||||
self.postprocess_options = PostProcessingOptions(
|
||||
*components[shift : shift + len(fields(PostProcessingOptions))]
|
||||
*components[shift : shift + len(fields(PostProcessingOptions))] # type: ignore
|
||||
)
|
||||
logger.debug("%s", pformat(self.postprocess_options))
|
||||
|
||||
if self.enabled:
|
||||
p.do_not_save_samples = not self.keep_original_images
|
||||
|
||||
def process(self, p: StableDiffusionProcessing, *components):
|
||||
def process(
|
||||
self, p: StableDiffusionProcessing, *components: List[gr.components.Component]
|
||||
) -> None:
|
||||
self.read_config(p, *components)
|
||||
|
||||
# If is instance of img2img, we check if face swapping in source is required.
|
||||
|
|
@ -175,7 +165,9 @@ class FaceSwapScript(scripts.Script):
|
|||
if new_inits is not None:
|
||||
p.init_images = [img[0] for img in new_inits]
|
||||
|
||||
def postprocess(self, p: StableDiffusionProcessing, processed: Processed, *args):
|
||||
def postprocess(
|
||||
self, p: StableDiffusionProcessing, processed: Processed, *args: List[Any]
|
||||
) -> None:
|
||||
if self.enabled:
|
||||
# Get the original images without the grid
|
||||
orig_images: List[Image.Image] = processed.images[
|
||||
|
|
|
|||
|
|
@ -1,52 +1,68 @@
|
|||
from PIL import Image
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, Body
|
||||
from fastapi.exceptions import HTTPException
|
||||
from modules.api.models import *
|
||||
from fastapi import FastAPI
|
||||
from modules.api import api
|
||||
from scripts.faceswaplab_api.faceswaplab_api_types import (
|
||||
FaceSwapUnit,
|
||||
FaceSwapRequest,
|
||||
FaceSwapResponse,
|
||||
)
|
||||
from scripts.faceswaplab_globals import VERSION_FLAG
|
||||
import gradio as gr
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
from scripts.faceswaplab_swapping import swapper
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import save_img_debug
|
||||
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
|
||||
from scripts.faceswaplab_utils.imgutils import (
|
||||
pil_to_cv2,
|
||||
check_against_nsfw,
|
||||
base64_to_pil,
|
||||
)
|
||||
from scripts.faceswaplab_utils.models_utils import get_current_model
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
def encode_to_base64(image):
|
||||
if type(image) is str:
|
||||
def encode_to_base64(image: Union[str, Image.Image, np.ndarray]) -> str:
|
||||
"""
|
||||
Encode an image to a base64 string.
|
||||
|
||||
The image can be a file path (str), a PIL Image, or a NumPy array.
|
||||
|
||||
Args:
|
||||
image (Union[str, Image.Image, np.ndarray]): The image to encode.
|
||||
|
||||
Returns:
|
||||
str: The base64-encoded image if successful, otherwise an empty string.
|
||||
"""
|
||||
if isinstance(image, str):
|
||||
return image
|
||||
elif type(image) is Image.Image:
|
||||
elif isinstance(image, Image.Image):
|
||||
return api.encode_pil_to_base64(image)
|
||||
elif type(image) is np.ndarray:
|
||||
elif isinstance(image, np.ndarray):
|
||||
return encode_np_to_base64(image)
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
def encode_np_to_base64(image):
|
||||
def encode_np_to_base64(image: np.ndarray) -> str:
|
||||
"""
|
||||
Encode a NumPy array to a base64 string.
|
||||
|
||||
The array is first converted to a PIL Image, then encoded.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): The NumPy array to encode.
|
||||
|
||||
Returns:
|
||||
str: The base64-encoded image.
|
||||
"""
|
||||
pil = Image.fromarray(image)
|
||||
return api.encode_pil_to_base64(pil)
|
||||
|
||||
|
||||
def faceswaplab_api(_: gr.Blocks, app: FastAPI):
|
||||
def faceswaplab_api(_: gr.Blocks, app: FastAPI) -> None:
|
||||
@app.get(
|
||||
"/faceswaplab/version",
|
||||
tags=["faceswaplab"],
|
||||
description="Get faceswaplab version",
|
||||
)
|
||||
async def version():
|
||||
async def version() -> Dict[str, str]:
|
||||
return {"version": VERSION_FLAG}
|
||||
|
||||
# use post as we consider the method non idempotent (which is debatable)
|
||||
|
|
|
|||
|
|
@ -1,16 +1,8 @@
|
|||
from scripts.faceswaplab_swapping import swapper
|
||||
import numpy as np
|
||||
from typing import List, Tuple
|
||||
import dill as pickle
|
||||
import gradio as gr
|
||||
from insightface.app.common import Face
|
||||
from PIL import Image
|
||||
from scripts.faceswaplab_utils.imgutils import (
|
||||
pil_to_cv2,
|
||||
check_against_nsfw,
|
||||
base64_to_pil,
|
||||
)
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
from pydantic import BaseModel, Field
|
||||
from scripts.faceswaplab_postprocessing.postprocessing_options import InpaintingWhen
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
import os
|
||||
|
||||
MODELS_DIR = os.path.abspath(os.path.join("models", "faceswaplab"))
|
||||
ANALYZER_DIR = os.path.abspath(os.path.join(MODELS_DIR, "analysers"))
|
||||
FACE_PARSER_DIR = os.path.abspath(os.path.join(MODELS_DIR, "parser"))
|
||||
|
||||
VERSION_FLAG = "v1.1.0"
|
||||
VERSION_FLAG: str = "v1.1.0"
|
||||
EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab")
|
||||
NSFW_SCORE = 0.7
|
||||
|
||||
# The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW.
|
||||
NSFW_SCORE_THRESHOLD: float = 0.7
|
||||
|
|
|
|||
|
|
@ -1,15 +1,11 @@
|
|||
from modules.face_restoration import FaceRestoration
|
||||
from modules.upscaler import UpscalerData
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from modules import shared
|
||||
from scripts.faceswaplab_utils import imgutils
|
||||
from modules import shared, processing, codeformer_model
|
||||
from modules import shared, processing
|
||||
from modules.processing import StableDiffusionProcessingImg2Img
|
||||
from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||
PostProcessingOptions,
|
||||
InpaintingWhen,
|
||||
)
|
||||
from modules import sd_models
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from modules.face_restoration import FaceRestoration
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
from PIL import Image
|
||||
from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||
PostProcessingOptions,
|
||||
InpaintingWhen,
|
||||
)
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from modules import shared, processing, codeformer_model
|
||||
from modules import codeformer_model
|
||||
|
||||
|
||||
def upscale_img(image: Image.Image, pp_options: PostProcessingOptions) -> Image.Image:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from modules import script_callbacks, shared
|
|||
import gradio as gr
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
def on_ui_settings() -> None:
|
||||
section = ("faceswaplab", "FaceSwapLab")
|
||||
models = get_models()
|
||||
shared.opts.add_option(
|
||||
|
|
|
|||
|
|
@ -40,7 +40,6 @@ please contact the contributor(s) of the work.
|
|||
|
||||
|
||||
import torch
|
||||
import cv2
|
||||
import os
|
||||
import torch
|
||||
from torch.hub import download_url_to_file, get_dir
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import copy
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Set, Tuple, Optional, Union
|
||||
from typing import Any, Dict, List, Set, Tuple, Optional
|
||||
|
||||
import cv2
|
||||
import insightface
|
||||
|
|
@ -13,7 +13,6 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|||
|
||||
from scripts.faceswaplab_swapping import upscaled_inswapper
|
||||
from scripts.faceswaplab_utils.imgutils import (
|
||||
cv2_to_pil,
|
||||
pil_to_cv2,
|
||||
check_against_nsfw,
|
||||
)
|
||||
|
|
@ -26,7 +25,7 @@ from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSetting
|
|||
providers = ["CPUExecutionProvider"]
|
||||
|
||||
|
||||
def cosine_similarity_face(face1, face2) -> float:
|
||||
def cosine_similarity_face(face1: Face, face2: Face) -> float:
|
||||
"""
|
||||
Calculates the cosine similarity between two face embeddings.
|
||||
|
||||
|
|
@ -92,7 +91,7 @@ class FaceModelException(Exception):
|
|||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def getAnalysisModel():
|
||||
def getAnalysisModel() -> insightface.app.FaceAnalysis:
|
||||
"""
|
||||
Retrieves the analysis model for face analysis.
|
||||
|
||||
|
|
@ -112,11 +111,11 @@ def getAnalysisModel():
|
|||
logger.error(
|
||||
"Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)"
|
||||
)
|
||||
raise FaceModelException("Loading of swapping model failed")
|
||||
raise FaceModelException("Loading of analysis model failed")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def getFaceSwapModel(model_path: str):
|
||||
def getFaceSwapModel(model_path: str) -> upscaled_inswapper.UpscaledINSwapper:
|
||||
"""
|
||||
Retrieves the face swap model and initializes it if necessary.
|
||||
|
||||
|
|
@ -135,13 +134,14 @@ def getFaceSwapModel(model_path: str):
|
|||
logger.error(
|
||||
"Loading of swapping model failed, please check the requirements (On Windows, download and install Visual Studio. During the install, make sure to include the Python and C++ packages.)"
|
||||
)
|
||||
raise FaceModelException("Loading of swapping model failed")
|
||||
|
||||
|
||||
def get_faces(
|
||||
img_data: np.ndarray,
|
||||
det_size=(640, 640),
|
||||
det_thresh: Optional[int] = None,
|
||||
sort_by_face_size=False,
|
||||
img_data: np.ndarray, # type: ignore
|
||||
det_size: Tuple[int, int] = (640, 640),
|
||||
det_thresh: Optional[float] = None,
|
||||
sort_by_face_size: bool = False,
|
||||
) -> List[Face]:
|
||||
"""
|
||||
Detects and retrieves faces from an image using an analysis model.
|
||||
|
|
@ -211,7 +211,7 @@ class ImageResult:
|
|||
"""
|
||||
|
||||
|
||||
def get_or_default(l, index, default):
|
||||
def get_or_default(l: List[Any], index: int, default: Any) -> Any:
|
||||
"""
|
||||
Retrieve the value at the specified index from the given list.
|
||||
If the index is out of bounds, return the default value instead.
|
||||
|
|
@ -227,7 +227,10 @@ def get_or_default(l, index, default):
|
|||
return l[index] if index < len(l) else default
|
||||
|
||||
|
||||
def get_faces_from_img_files(files):
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def get_faces_from_img_files(files: List[gr.File]) -> List[Optional[np.ndarray]]: # type: ignore
|
||||
"""
|
||||
Extracts faces from a list of image files.
|
||||
|
||||
|
|
@ -300,15 +303,15 @@ def blend_faces(faces: List[Face]) -> Face:
|
|||
|
||||
|
||||
def swap_face(
|
||||
reference_face: np.ndarray,
|
||||
source_face: np.ndarray,
|
||||
reference_face: np.ndarray, # type: ignore
|
||||
source_face: np.ndarray, # type: ignore
|
||||
target_img: Image.Image,
|
||||
model: str,
|
||||
faces_index: Set[int] = {0},
|
||||
same_gender=True,
|
||||
upscaled_swapper=False,
|
||||
compute_similarity=True,
|
||||
sort_by_face_size=False,
|
||||
same_gender: bool = True,
|
||||
upscaled_swapper: bool = False,
|
||||
compute_similarity: bool = True,
|
||||
sort_by_face_size: bool = False,
|
||||
) -> ImageResult:
|
||||
"""
|
||||
Swaps faces in the target image with the source face.
|
||||
|
|
@ -344,6 +347,7 @@ def swap_face(
|
|||
for i, swapped_face in enumerate(target_faces):
|
||||
logger.info(f"swap face {i}")
|
||||
if i in faces_index:
|
||||
# type : ignore
|
||||
result = face_swapper.get(
|
||||
result, swapped_face, source_face, upscale=upscaled_swapper
|
||||
)
|
||||
|
|
@ -385,13 +389,13 @@ def swap_face(
|
|||
|
||||
|
||||
def process_image_unit(
|
||||
model,
|
||||
model: str,
|
||||
unit: FaceSwapUnitSettings,
|
||||
image: Image.Image,
|
||||
info=None,
|
||||
upscaled_swapper=False,
|
||||
force_blend=False,
|
||||
) -> List:
|
||||
info: str = None,
|
||||
upscaled_swapper: bool = False,
|
||||
force_blend: bool = False,
|
||||
) -> List[Tuple[Image.Image, str]]:
|
||||
"""Process one image and return a List of (image, info) (one if blended, many if not).
|
||||
|
||||
Args:
|
||||
|
|
@ -472,12 +476,12 @@ def process_image_unit(
|
|||
|
||||
|
||||
def process_images_units(
|
||||
model,
|
||||
model: str,
|
||||
units: List[FaceSwapUnitSettings],
|
||||
images: List[Tuple[Optional[Image.Image], Optional[str]]],
|
||||
upscaled_swapper=False,
|
||||
force_blend=False,
|
||||
) -> Union[List, None]:
|
||||
upscaled_swapper: bool = False,
|
||||
force_blend: bool = False,
|
||||
) -> Optional[List[Tuple[Image.Image, str]]]:
|
||||
if len(units) == 0:
|
||||
logger.info("Finished processing image, return %s images", len(images))
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,17 +1,11 @@
|
|||
import cv2
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime
|
||||
from insightface.model_zoo.inswapper import INSwapper
|
||||
from insightface.utils import face_align
|
||||
from modules import codeformer_model, processing, scripts, shared
|
||||
from modules.face_restoration import FaceRestoration
|
||||
from modules.shared import cmd_opts, opts, state
|
||||
from modules import processing, shared
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import UpscalerData
|
||||
from onnx import numpy_helper
|
||||
from PIL import Image
|
||||
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
from scripts.faceswaplab_postprocessing import upscaling
|
||||
from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||
PostProcessingOptions,
|
||||
|
|
|
|||
|
|
@ -5,13 +5,12 @@ from pprint import pformat, pprint
|
|||
import dill as pickle
|
||||
import gradio as gr
|
||||
import modules.scripts as scripts
|
||||
import numpy as np
|
||||
import onnx
|
||||
import pandas as pd
|
||||
from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui
|
||||
from scripts.faceswaplab_ui.faceswaplab_upscaler_ui import upscaler_ui
|
||||
from insightface.app.common import Face
|
||||
from modules import script_callbacks, scripts
|
||||
from modules import scripts
|
||||
from PIL import Image
|
||||
from modules.shared import opts
|
||||
|
||||
|
|
@ -25,12 +24,13 @@ from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
|||
)
|
||||
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
|
||||
from dataclasses import fields
|
||||
from typing import List
|
||||
from typing import Any, List, Optional, Union
|
||||
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
|
||||
from scripts.faceswaplab_utils.models_utils import get_current_model
|
||||
import re
|
||||
|
||||
|
||||
def compare(img1, img2):
|
||||
def compare(img1: Image.Image, img2: Image.Image) -> Union[float, str]:
|
||||
if img1 is not None and img2 is not None:
|
||||
return swapper.compare_faces(img1, img2)
|
||||
|
||||
|
|
@ -40,19 +40,10 @@ def compare(img1, img2):
|
|||
def extract_faces(
|
||||
files,
|
||||
extract_path,
|
||||
face_restorer_name,
|
||||
face_restorer_visibility,
|
||||
codeformer_weight,
|
||||
upscaler_name,
|
||||
upscaler_scale,
|
||||
upscaler_visibility,
|
||||
inpainting_denoising_strengh,
|
||||
inpainting_prompt,
|
||||
inpainting_negative_prompt,
|
||||
inpainting_steps,
|
||||
inpainting_sampler,
|
||||
inpainting_when,
|
||||
*components: List[gr.components.Component],
|
||||
):
|
||||
postprocess_options = PostProcessingOptions(*components) # type: ignore
|
||||
|
||||
if not extract_path:
|
||||
tempfile.mkdtemp()
|
||||
if files is not None:
|
||||
|
|
@ -66,24 +57,16 @@ def extract_faces(
|
|||
bbox = face.bbox.astype(int)
|
||||
x_min, y_min, x_max, y_max = bbox
|
||||
face_image = img.crop((x_min, y_min, x_max, y_max))
|
||||
if face_restorer_name or face_restorer_visibility:
|
||||
scale = 1 if face_image.width > 512 else 512 // face_image.width
|
||||
if (
|
||||
postprocess_options.face_restorer_name
|
||||
or postprocess_options.restorer_visibility
|
||||
):
|
||||
postprocess_options.scale = (
|
||||
1 if face_image.width > 512 else 512 // face_image.width
|
||||
)
|
||||
face_image = enhance_image(
|
||||
face_image,
|
||||
PostProcessingOptions(
|
||||
face_restorer_name=face_restorer_name,
|
||||
restorer_visibility=face_restorer_visibility,
|
||||
codeformer_weight=codeformer_weight,
|
||||
upscaler_name=upscaler_name,
|
||||
upscale_visibility=upscaler_visibility,
|
||||
scale=scale,
|
||||
inpainting_denoising_strengh=inpainting_denoising_strengh,
|
||||
inpainting_prompt=inpainting_prompt,
|
||||
inpainting_steps=inpainting_steps,
|
||||
inpainting_negative_prompt=inpainting_negative_prompt,
|
||||
inpainting_when=inpainting_when,
|
||||
inpainting_sampler=inpainting_sampler,
|
||||
),
|
||||
postprocess_options,
|
||||
)
|
||||
path = tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=".png", dir=extract_path
|
||||
|
|
@ -95,7 +78,7 @@ def extract_faces(
|
|||
return None
|
||||
|
||||
|
||||
def analyse_faces(image, det_threshold=0.5):
|
||||
def analyse_faces(image: Image.Image, det_threshold: float = 0.5) -> str:
|
||||
try:
|
||||
faces = swapper.get_faces(imgutils.pil_to_cv2(image), det_thresh=det_threshold)
|
||||
result = ""
|
||||
|
|
@ -110,27 +93,40 @@ def analyse_faces(image, det_threshold=0.5):
|
|||
return "Analysis Failed"
|
||||
|
||||
|
||||
def build_face_checkpoint_and_save(batch_files, name):
|
||||
def sanitize_name(name: str) -> str:
|
||||
logger.debug(f"Sanitize name {name}")
|
||||
name = re.sub("[^A-Za-z0-9_. ]+", "", name)
|
||||
name = name.replace(" ", "_")
|
||||
logger.debug(f"Sanitized name {name[:255]}")
|
||||
return name[:255]
|
||||
|
||||
|
||||
def build_face_checkpoint_and_save(
|
||||
batch_files: gr.File, name: str
|
||||
) -> Optional[Image.Image]:
|
||||
"""
|
||||
Builds a face checkpoint, swaps faces, and saves the result to a file.
|
||||
Builds a face checkpoint using the provided image files, performs face swapping,
|
||||
and saves the result to a file. If a blended face is successfully obtained and the face swapping
|
||||
process succeeds, the resulting image is returned. Otherwise, None is returned.
|
||||
|
||||
Args:
|
||||
batch_files (list): List of image file paths.
|
||||
name (str): Name of the face checkpoint
|
||||
batch_files (list): List of image file paths used to create the face checkpoint.
|
||||
name (str): The name assigned to the face checkpoint.
|
||||
|
||||
Returns:
|
||||
PIL.Image.Image or None: Resulting swapped face image if successful, otherwise None.
|
||||
PIL.Image.Image or None: The resulting swapped face image if the process is successful; None otherwise.
|
||||
"""
|
||||
name = sanitize_name(name)
|
||||
batch_files = batch_files or []
|
||||
print("Build", name, [x.name for x in batch_files])
|
||||
logger.info("Build %s %s", name, [x.name for x in batch_files])
|
||||
faces = swapper.get_faces_from_img_files(batch_files)
|
||||
blended_face = swapper.blend_faces(faces)
|
||||
preview_path = os.path.join(
|
||||
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
|
||||
)
|
||||
faces_path = os.path.join(scripts.basedir(), "models", "faceswaplab", "faces")
|
||||
if not os.path.exists(faces_path):
|
||||
os.makedirs(faces_path)
|
||||
|
||||
os.makedirs(faces_path, exist_ok=True)
|
||||
|
||||
target_img = None
|
||||
if blended_face:
|
||||
|
|
@ -208,7 +204,9 @@ def explore_onnx_faceswap_model(model_path):
|
|||
return df
|
||||
|
||||
|
||||
def batch_process(files, save_path, *components):
|
||||
def batch_process(
|
||||
files, save_path, *components: List[gr.components.Component]
|
||||
) -> Optional[List[Image.Image]]:
|
||||
try:
|
||||
if save_path is not None:
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
|
@ -228,7 +226,7 @@ def batch_process(files, save_path, *components):
|
|||
len_conf: int = len(fields(FaceSwapUnitSettings))
|
||||
shift: int = units_count * len_conf
|
||||
postprocess_options = PostProcessingOptions(
|
||||
*components[shift : shift + len(fields(PostProcessingOptions))]
|
||||
*components[shift : shift + len(fields(PostProcessingOptions))] # type: ignore
|
||||
)
|
||||
logger.debug("%s", pformat(postprocess_options))
|
||||
|
||||
|
|
@ -247,7 +245,7 @@ def batch_process(files, save_path, *components):
|
|||
),
|
||||
)
|
||||
if len(swapped_images) > 0:
|
||||
current_images += [img for img, info in swapped_images]
|
||||
current_images += [img for img, _ in swapped_images]
|
||||
|
||||
logger.info("%s images generated", len(current_images))
|
||||
for i, img in enumerate(current_images):
|
||||
|
|
@ -269,7 +267,7 @@ def batch_process(files, save_path, *components):
|
|||
return None
|
||||
|
||||
|
||||
def tools_ui():
|
||||
def tools_ui() -> None:
|
||||
models = get_models()
|
||||
with gr.Tab("Tools"):
|
||||
with gr.Tab("Build"):
|
||||
|
|
@ -431,7 +429,7 @@ def tools_ui():
|
|||
)
|
||||
|
||||
|
||||
def on_ui_tabs():
|
||||
def on_ui_tabs() -> List[Any]:
|
||||
with gr.Blocks(analytics_enabled=False) as ui_faceswap:
|
||||
tools_ui()
|
||||
return [(ui_faceswap, "FaceSwapLab", "faceswaplab_tab")]
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import dill as pickle
|
|||
import gradio as gr
|
||||
from insightface.app.common import Face
|
||||
from PIL import Image
|
||||
from scripts.faceswaplab_utils.imgutils import pil_to_cv2, check_against_nsfw
|
||||
from scripts.faceswaplab_utils.imgutils import pil_to_cv2
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
from typing import List
|
||||
from scripts.faceswaplab_utils.models_utils import get_face_checkpoints
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def faceswap_unit_ui(is_img2img, unit_num=1, id_prefix="faceswaplab"):
|
||||
def faceswap_unit_ui(
|
||||
is_img2img: bool, unit_num: int = 1, id_prefix: str = "faceswaplab"
|
||||
) -> List[gr.components.Component]:
|
||||
with gr.Tab(f"Face {unit_num}"):
|
||||
with gr.Column():
|
||||
gr.Markdown(
|
||||
|
|
@ -37,7 +40,7 @@ def faceswap_unit_ui(is_img2img, unit_num=1, id_prefix="faceswaplab"):
|
|||
elem_id=f"{id_prefix}_face{unit_num}_refresh_checkpoints",
|
||||
)
|
||||
|
||||
def refresh_fn(selected):
|
||||
def refresh_fn(selected: str) -> None:
|
||||
return gr.Dropdown.update(
|
||||
value=selected, choices=get_face_checkpoints()
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,12 @@
|
|||
from typing import List
|
||||
import gradio as gr
|
||||
import modules
|
||||
from modules import shared, sd_models
|
||||
from modules.shared import cmd_opts, opts, state
|
||||
|
||||
import scripts.faceswaplab_postprocessing.upscaling as upscaling
|
||||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
from modules.shared import opts
|
||||
from scripts.faceswaplab_postprocessing.postprocessing_options import InpaintingWhen
|
||||
|
||||
|
||||
def upscaler_ui():
|
||||
def upscaler_ui() -> List[gr.components.Component]:
|
||||
with gr.Tab(f"Post-Processing"):
|
||||
gr.Markdown(
|
||||
"""Upscaling is performed on the whole image. Upscaling happens before face restoration."""
|
||||
|
|
@ -74,10 +73,8 @@ def upscaler_ui():
|
|||
)
|
||||
inpainting_when = gr.Dropdown(
|
||||
elem_id="faceswaplab_pp_inpainting_when",
|
||||
choices=[
|
||||
e.value for e in upscaling.InpaintingWhen.__members__.values()
|
||||
],
|
||||
value=[upscaling.InpaintingWhen.BEFORE_RESTORE_FACE.value],
|
||||
choices=[e.value for e in InpaintingWhen.__members__.values()],
|
||||
value=[InpaintingWhen.BEFORE_RESTORE_FACE.value],
|
||||
label="Enable/When",
|
||||
)
|
||||
inpainting_denoising_strength = gr.Slider(
|
||||
|
|
|
|||
|
|
@ -1,12 +1,24 @@
|
|||
import logging
|
||||
import copy
|
||||
import sys
|
||||
from typing import Any
|
||||
from modules import shared
|
||||
from PIL import Image
|
||||
from logging import LogRecord
|
||||
|
||||
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
COLORS = {
|
||||
"""
|
||||
A custom logging formatter that outputs logs with level names colored.
|
||||
|
||||
Class Attributes:
|
||||
COLORS (dict): A dictionary mapping logging level names to their corresponding color codes.
|
||||
|
||||
Inherits From:
|
||||
logging.Formatter
|
||||
"""
|
||||
|
||||
COLORS: dict[str, str] = {
|
||||
"DEBUG": "\033[0;36m", # CYAN
|
||||
"INFO": "\033[0;32m", # GREEN
|
||||
"WARNING": "\033[0;33m", # YELLOW
|
||||
|
|
@ -15,7 +27,21 @@ class ColoredFormatter(logging.Formatter):
|
|||
"RESET": "\033[0m", # RESET COLOR
|
||||
}
|
||||
|
||||
def format(self, record):
|
||||
def format(self, record: LogRecord) -> str:
|
||||
"""
|
||||
Format the specified record as text.
|
||||
|
||||
The record's attribute dictionary is used as the operand to a string
|
||||
formatting operation which yields the returned string. Before formatting
|
||||
the dictionary, a check is made to see if the format uses the levelname
|
||||
of the record. If it does, a colorized version is created and used.
|
||||
|
||||
Args:
|
||||
record (LogRecord): The log record to be formatted.
|
||||
|
||||
Returns:
|
||||
str: The formatted string which includes the colorized levelname.
|
||||
"""
|
||||
colored_record = copy.copy(record)
|
||||
levelname = colored_record.levelname
|
||||
seq = self.COLORS.get(levelname, self.COLORS["RESET"])
|
||||
|
|
@ -46,7 +72,24 @@ if logger.getEffectiveLevel() <= logging.DEBUG:
|
|||
DEBUG_DIR = tempfile.mkdtemp()
|
||||
|
||||
|
||||
def save_img_debug(img: Image.Image, message: str, *opts):
|
||||
def save_img_debug(img: Image.Image, message: str, *opts: Any) -> None:
|
||||
"""
|
||||
Saves an image to a temporary file if the logger's effective level is set to DEBUG or lower.
|
||||
After saving, it logs a debug message along with the file URI of the image.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
img : Image.Image
|
||||
The image to be saved.
|
||||
message : str
|
||||
The message to be logged.
|
||||
*opts : Any
|
||||
Additional arguments to be passed to the logger's debug method.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
if logger.getEffectiveLevel() <= logging.DEBUG:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
dir=DEBUG_DIR, delete=False, suffix=".png"
|
||||
|
|
|
|||
|
|
@ -1,35 +1,76 @@
|
|||
import io
|
||||
from typing import Optional
|
||||
from PIL import Image, ImageChops, ImageOps, ImageFilter
|
||||
from typing import List, Optional, Tuple, Union, Dict
|
||||
from PIL import Image
|
||||
import cv2
|
||||
import numpy as np
|
||||
from math import isqrt, ceil
|
||||
import torch
|
||||
from ifnude import detect
|
||||
from scripts.faceswaplab_globals import NSFW_SCORE
|
||||
from scripts.faceswaplab_globals import NSFW_SCORE_THRESHOLD
|
||||
from modules import processing
|
||||
import base64
|
||||
from collections import Counter
|
||||
|
||||
|
||||
def check_against_nsfw(img):
|
||||
shapes = []
|
||||
chunks = detect(img)
|
||||
def check_against_nsfw(img: Image.Image) -> bool:
|
||||
"""
|
||||
Check if an image exceeds the Not Safe for Work (NSFW) score.
|
||||
|
||||
Parameters:
|
||||
img (PIL.Image.Image): The image to be checked.
|
||||
|
||||
Returns:
|
||||
bool: True if any part of the image is considered NSFW, False otherwise.
|
||||
"""
|
||||
|
||||
shapes: List[bool] = []
|
||||
chunks: List[Dict[str, Union[int, float]]] = detect(img)
|
||||
|
||||
for chunk in chunks:
|
||||
shapes.append(chunk["score"] > NSFW_SCORE)
|
||||
shapes.append(chunk["score"] > NSFW_SCORE_THRESHOLD)
|
||||
|
||||
return any(shapes)
|
||||
|
||||
|
||||
def pil_to_cv2(pil_img):
|
||||
def pil_to_cv2(pil_img: Image.Image) -> np.ndarray: # type: ignore
|
||||
"""
|
||||
Convert a PIL Image into an OpenCV image (cv2).
|
||||
|
||||
Args:
|
||||
pil_img (PIL.Image.Image): An image in PIL format.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The input image converted to OpenCV format (BGR).
|
||||
"""
|
||||
return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||||
|
||||
|
||||
def cv2_to_pil(cv2_img):
|
||||
def cv2_to_pil(cv2_img: np.ndarray) -> Image.Image: # type: ignore
|
||||
"""
|
||||
Convert an OpenCV image (cv2) into a PIL Image.
|
||||
|
||||
Args:
|
||||
cv2_img (np.ndarray): An image in OpenCV format (BGR).
|
||||
|
||||
Returns:
|
||||
PIL.Image.Image: The input image converted to PIL format (RGB).
|
||||
"""
|
||||
return Image.fromarray(cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB))
|
||||
|
||||
|
||||
def torch_to_pil(images):
|
||||
def torch_to_pil(images: torch.Tensor) -> List[Image.Image]:
|
||||
"""
|
||||
Convert a numpy image or a batch of images to a PIL image.
|
||||
Converts a tensor image or a batch of tensor images to a PIL image or a list of PIL images.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
images : torch.Tensor
|
||||
A tensor representing an image or a batch of images.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
A list of PIL images.
|
||||
"""
|
||||
images = images.cpu().permute(0, 2, 3, 1).numpy()
|
||||
if images.ndim == 3:
|
||||
|
|
@ -39,9 +80,19 @@ def torch_to_pil(images):
|
|||
return pil_images
|
||||
|
||||
|
||||
def pil_to_torch(pil_images):
|
||||
def pil_to_torch(pil_images: Union[Image.Image, List[Image.Image]]) -> torch.Tensor:
|
||||
"""
|
||||
Convert a PIL image or a list of PIL images to a torch tensor or a batch of torch tensors.
|
||||
Converts a PIL image or a list of PIL images to a torch tensor or a batch of torch tensors.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pil_images : Union[Image.Image, List[Image.Image]]
|
||||
A PIL image or a list of PIL images.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
A tensor representing an image or a batch of images.
|
||||
"""
|
||||
if isinstance(pil_images, list):
|
||||
numpy_images = [np.array(image) for image in pil_images]
|
||||
|
|
@ -53,10 +104,7 @@ def pil_to_torch(pil_images):
|
|||
return torch_image
|
||||
|
||||
|
||||
from collections import Counter
|
||||
|
||||
|
||||
def create_square_image(image_list):
|
||||
def create_square_image(image_list: List[Image.Image]) -> Optional[Image.Image]:
|
||||
"""
|
||||
Creates a square image by combining multiple images in a grid pattern.
|
||||
|
||||
|
|
@ -108,16 +156,41 @@ def create_square_image(image_list):
|
|||
return None
|
||||
|
||||
|
||||
def create_mask(image, box_coords):
|
||||
# def create_mask(image : Image.Image, box_coords : Tuple[int, int, int, int]) -> Image.Image:
|
||||
# width, height = image.size
|
||||
# mask = Image.new("L", (width, height), 255)
|
||||
# x1, y1, x2, y2 = box_coords
|
||||
# for x in range(width):
|
||||
# for y in range(height):
|
||||
# if x1 <= x <= x2 and y1 <= y <= y2:
|
||||
# mask.putpixel((x, y), 255)
|
||||
# else:
|
||||
# mask.putpixel((x, y), 0)
|
||||
# return mask
|
||||
|
||||
|
||||
def create_mask(
|
||||
image: Image.Image, box_coords: Tuple[int, int, int, int]
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Create a binary mask for a given image and bounding box coordinates.
|
||||
|
||||
Args:
|
||||
image (PIL.Image.Image): The input image.
|
||||
box_coords (Tuple[int, int, int, int]): A tuple of 4 integers defining the bounding box.
|
||||
It follows the pattern (x1, y1, x2, y2), where (x1, y1) is the top-left coordinate of the
|
||||
box and (x2, y2) is the bottom-right coordinate of the box.
|
||||
|
||||
Returns:
|
||||
PIL.Image.Image: A binary mask of the same size as the input image, where pixels within
|
||||
the bounding box are white (255) and pixels outside the bounding box are black (0).
|
||||
"""
|
||||
width, height = image.size
|
||||
mask = Image.new("L", (width, height), 255)
|
||||
mask = Image.new("L", (width, height), 0)
|
||||
x1, y1, x2, y2 = box_coords
|
||||
for x in range(width):
|
||||
for y in range(height):
|
||||
if x1 <= x <= x2 and y1 <= y <= y2:
|
||||
mask.putpixel((x, y), 255)
|
||||
else:
|
||||
mask.putpixel((x, y), 0)
|
||||
for x in range(x1, x2 + 1):
|
||||
for y in range(y1, y2 + 1):
|
||||
mask.putpixel((x, y), 255)
|
||||
return mask
|
||||
|
||||
|
||||
|
|
@ -185,12 +258,32 @@ def prepare_mask(
|
|||
|
||||
|
||||
def base64_to_pil(base64str: Optional[str]) -> Optional[Image.Image]:
|
||||
"""
|
||||
Converts a base64 string to a PIL Image object.
|
||||
|
||||
Parameters:
|
||||
base64str (Optional[str]): The base64 string to convert. This string may contain a data URL scheme
|
||||
(i.e., 'data:image/jpeg;base64,') or just be the raw base64 encoded data. If None, the function
|
||||
will return None.
|
||||
|
||||
Returns:
|
||||
Optional[Image.Image]: A PIL Image object created from the base64 string. If the input is None,
|
||||
the function returns None.
|
||||
|
||||
Raises:
|
||||
binascii.Error: If the base64 string is not properly formatted or encoded.
|
||||
PIL.UnidentifiedImageError: If the image format cannot be identified.
|
||||
"""
|
||||
|
||||
if base64str is None:
|
||||
return None
|
||||
if "base64," in base64str: # check if the base64 string has a data URL scheme
|
||||
|
||||
# Check if the base64 string has a data URL scheme
|
||||
if "base64," in base64str:
|
||||
base64_data = base64str.split("base64,")[-1]
|
||||
img_bytes = base64.b64decode(base64_data)
|
||||
else:
|
||||
# if no data URL scheme, just decode
|
||||
# If no data URL scheme, just decode
|
||||
img_bytes = base64.b64decode(base64str)
|
||||
|
||||
return Image.open(io.BytesIO(img_bytes))
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import glob
|
||||
import os
|
||||
from typing import List
|
||||
import modules.scripts as scripts
|
||||
from modules import scripts
|
||||
from scripts.faceswaplab_globals import EXTENSION_PATH
|
||||
|
|
@ -7,7 +8,7 @@ from modules.shared import opts
|
|||
from scripts.faceswaplab_utils.faceswaplab_logging import logger
|
||||
|
||||
|
||||
def get_models():
|
||||
def get_models() -> List[str]:
|
||||
"""
|
||||
Retrieve a list of swap model files.
|
||||
|
||||
|
|
@ -44,7 +45,7 @@ def get_current_model() -> str:
|
|||
return model
|
||||
|
||||
|
||||
def get_face_checkpoints():
|
||||
def get_face_checkpoints() -> List[str]:
|
||||
"""
|
||||
Retrieve a list of face checkpoint paths.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue