add tests
parent
ddc403db1f
commit
15e9366eb6
2
check.sh
2
check.sh
|
|
@ -1,4 +1,4 @@
|
|||
#!/bin/bash
|
||||
autoflake --in-place --remove-unused-variables -r --remove-all-unused-imports .
|
||||
mypy --install-types
|
||||
pre-commit run --all-files
|
||||
pre-commit run --all-files
|
||||
|
|
|
|||
|
|
@ -12,13 +12,13 @@ address = "http://127.0.0.1:7860"
|
|||
|
||||
# First face unit :
|
||||
unit1 = FaceSwapUnit(
|
||||
source_img=pil_to_base64("../../references/man.png"), # The face you want to use
|
||||
source_img=pil_to_base64("../references/man.png"), # The face you want to use
|
||||
faces_index=(0,), # Replace first face
|
||||
)
|
||||
|
||||
# Second face unit :
|
||||
unit2 = FaceSwapUnit(
|
||||
source_img=pil_to_base64("../../references/woman.png"), # The face you want to use
|
||||
source_img=pil_to_base64("../references/woman.png"), # The face you want to use
|
||||
same_gender=True,
|
||||
faces_index=(0,), # Replace first woman since same gender is on
|
||||
)
|
||||
|
|
@ -48,5 +48,6 @@ result = requests.post(
|
|||
)
|
||||
response = FaceSwapResponse.parse_obj(result.json())
|
||||
|
||||
for img, info in zip(response.pil_images, response.infos):
|
||||
img.show(title=info)
|
||||
print(response.json())
|
||||
for img in response.pil_images:
|
||||
img.show()
|
||||
|
Before Width: | Height: | Size: 99 KiB After Width: | Height: | Size: 99 KiB |
|
|
@ -13,9 +13,6 @@ from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSetting
|
|||
from scripts.faceswaplab_utils.imgutils import (
|
||||
base64_to_pil,
|
||||
)
|
||||
from scripts.faceswaplab_utils.models_utils import get_current_model
|
||||
from modules.shared import opts
|
||||
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
|
||||
from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||
PostProcessingOptions,
|
||||
)
|
||||
|
|
@ -135,22 +132,18 @@ def faceswaplab_api(_: gr.Blocks, app: FastAPI) -> None:
|
|||
units: List[FaceSwapUnitSettings] = []
|
||||
src_image: Optional[Image.Image] = base64_to_pil(request.image)
|
||||
response = FaceSwapResponse(images=[], infos=[])
|
||||
if request.postprocessing:
|
||||
pp_options = get_postprocessing_options(request.postprocessing)
|
||||
|
||||
if src_image is not None:
|
||||
if request.postprocessing:
|
||||
pp_options = get_postprocessing_options(request.postprocessing)
|
||||
units = get_faceswap_units_settings(request.units)
|
||||
|
||||
swapped_images = swapper.process_images_units(
|
||||
get_current_model(),
|
||||
images=[(src_image, None)],
|
||||
units=units,
|
||||
upscaled_swapper=opts.data.get("faceswaplab_upscaled_swapper", False),
|
||||
swapped_images = swapper.batch_process(
|
||||
[src_image], None, units=units, postprocess_options=pp_options
|
||||
)
|
||||
for img, info in swapped_images:
|
||||
if pp_options:
|
||||
img = enhance_image(img, pp_options)
|
||||
response.images.append(encode_to_base64(img))
|
||||
response.infos.append(info)
|
||||
|
||||
for img in swapped_images:
|
||||
response.images.append(encode_to_base64(img))
|
||||
|
||||
response.infos = [] # Not used atm
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ REFERENCE_PATH = os.path.join(
|
|||
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
|
||||
)
|
||||
|
||||
VERSION_FLAG: str = "v1.1.0"
|
||||
VERSION_FLAG: str = "v1.1.1"
|
||||
EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab")
|
||||
|
||||
# The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW.
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import copy
|
|||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Set, Tuple, Optional
|
||||
import tempfile
|
||||
|
||||
import cv2
|
||||
import insightface
|
||||
|
|
@ -21,6 +22,12 @@ from scripts import faceswaplab_globals
|
|||
from modules.shared import opts
|
||||
from functools import lru_cache
|
||||
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
|
||||
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
|
||||
from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||
PostProcessingOptions,
|
||||
)
|
||||
from scripts.faceswaplab_utils.models_utils import get_current_model
|
||||
|
||||
|
||||
providers = ["CPUExecutionProvider"]
|
||||
|
||||
|
|
@ -78,6 +85,53 @@ def compare_faces(img1: Image.Image, img2: Image.Image) -> float:
|
|||
return -1
|
||||
|
||||
|
||||
def batch_process(
|
||||
src_images: List[Image.Image],
|
||||
save_path: Optional[str],
|
||||
units: List[FaceSwapUnitSettings],
|
||||
postprocess_options: PostProcessingOptions,
|
||||
) -> Optional[List[Image.Image]]:
|
||||
try:
|
||||
if save_path:
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
units = [u for u in units if u.enable]
|
||||
if src_images is not None and len(units) > 0:
|
||||
result_images = []
|
||||
for src_image in src_images:
|
||||
current_images = []
|
||||
swapped_images = process_images_units(
|
||||
get_current_model(),
|
||||
images=[(src_image, None)],
|
||||
units=units,
|
||||
upscaled_swapper=opts.data.get(
|
||||
"faceswaplab_upscaled_swapper", False
|
||||
),
|
||||
)
|
||||
if len(swapped_images) > 0:
|
||||
current_images += [img for img, _ in swapped_images]
|
||||
|
||||
logger.info("%s images generated", len(current_images))
|
||||
for i, img in enumerate(current_images):
|
||||
current_images[i] = enhance_image(img, postprocess_options)
|
||||
|
||||
if save_path:
|
||||
for img in current_images:
|
||||
path = tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=".png", dir=save_path
|
||||
).name
|
||||
img.save(path)
|
||||
|
||||
result_images += current_images
|
||||
return result_images
|
||||
except Exception as e:
|
||||
logger.error("Batch Process error : %s", e)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
class FaceModelException(Exception):
|
||||
"""Exception raised when an error is encountered in the face model."""
|
||||
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@ from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
|
|||
from dataclasses import fields
|
||||
from typing import Any, Dict, List, Optional
|
||||
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
|
||||
from scripts.faceswaplab_utils.models_utils import get_current_model
|
||||
import re
|
||||
|
||||
|
||||
|
|
@ -291,9 +290,6 @@ def batch_process(
|
|||
files: List[gr.File], save_path: str, *components: List[gr.components.Component]
|
||||
) -> Optional[List[Image.Image]]:
|
||||
try:
|
||||
if save_path is not None:
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
|
||||
units_count = opts.data.get("faceswaplab_units_count", 3)
|
||||
units: List[FaceSwapUnitSettings] = []
|
||||
|
||||
|
|
@ -312,36 +308,15 @@ def batch_process(
|
|||
*components[shift : shift + len(fields(PostProcessingOptions))] # type: ignore
|
||||
)
|
||||
logger.debug("%s", pformat(postprocess_options))
|
||||
|
||||
units = [u for u in units if u.enable]
|
||||
if files is not None and len(units) > 0:
|
||||
images = []
|
||||
for file in files:
|
||||
current_images = []
|
||||
src_image = Image.open(file.name)
|
||||
swapped_images = swapper.process_images_units(
|
||||
get_current_model(),
|
||||
images=[(src_image, None)],
|
||||
units=units,
|
||||
upscaled_swapper=opts.data.get(
|
||||
"faceswaplab_upscaled_swapper", False
|
||||
),
|
||||
)
|
||||
if len(swapped_images) > 0:
|
||||
current_images += [img for img, _ in swapped_images]
|
||||
|
||||
logger.info("%s images generated", len(current_images))
|
||||
for i, img in enumerate(current_images):
|
||||
current_images[i] = enhance_image(img, postprocess_options)
|
||||
|
||||
for img in current_images:
|
||||
path = tempfile.NamedTemporaryFile(
|
||||
delete=False, suffix=".png", dir=save_path
|
||||
).name
|
||||
img.save(path)
|
||||
|
||||
images += current_images
|
||||
return images
|
||||
images = [
|
||||
Image.open(file.name) for file in files
|
||||
] # potentially greedy but Image.open is supposed to be lazy
|
||||
return swapper.batch_process(
|
||||
images,
|
||||
save_path=save_path,
|
||||
units=units,
|
||||
postprocess_options=postprocess_options,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Batch Process error : %s", e)
|
||||
import traceback
|
||||
|
|
|
|||
|
|
@ -0,0 +1,83 @@
|
|||
from typing import List
|
||||
import pytest
|
||||
import requests
|
||||
import sys
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from client_api.client_utils import (
|
||||
FaceSwapUnit,
|
||||
FaceSwapResponse,
|
||||
PostProcessingOptions,
|
||||
FaceSwapRequest,
|
||||
base64_to_pil,
|
||||
pil_to_base64,
|
||||
InpaintingWhen,
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
base_url = "http://127.0.0.1:7860"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def face_swap_request() -> FaceSwapRequest:
|
||||
# First face unit
|
||||
unit1 = FaceSwapUnit(
|
||||
source_img=pil_to_base64("references/man.png"), # The face you want to use
|
||||
faces_index=(0,), # Replace first face
|
||||
)
|
||||
|
||||
# Second face unit
|
||||
unit2 = FaceSwapUnit(
|
||||
source_img=pil_to_base64("references/woman.png"), # The face you want to use
|
||||
same_gender=True,
|
||||
faces_index=(0,), # Replace first woman since same gender is on
|
||||
)
|
||||
|
||||
# Post-processing config
|
||||
pp = PostProcessingOptions(
|
||||
face_restorer_name="CodeFormer",
|
||||
codeformer_weight=0.5,
|
||||
restorer_visibility=1,
|
||||
upscaler_name="Lanczos",
|
||||
scale=4,
|
||||
inpainting_steps=30,
|
||||
inpainting_denoising_strengh=0.1,
|
||||
inpainting_when=InpaintingWhen.BEFORE_RESTORE_FACE,
|
||||
)
|
||||
|
||||
# Prepare the request
|
||||
request = FaceSwapRequest(
|
||||
image=pil_to_base64("tests/test_image.png"),
|
||||
units=[unit1, unit2],
|
||||
postprocessing=pp,
|
||||
)
|
||||
|
||||
return request
|
||||
|
||||
|
||||
def test_version() -> None:
|
||||
response = requests.get(f"{base_url}/faceswaplab/version")
|
||||
assert response.status_code == 200
|
||||
assert "version" in response.json()
|
||||
|
||||
|
||||
def test_faceswap(face_swap_request: FaceSwapRequest) -> None:
|
||||
response = requests.post(
|
||||
f"{base_url}/faceswaplab/swap_face",
|
||||
data=face_swap_request.json(),
|
||||
headers={"Content-Type": "application/json; charset=utf-8"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "images" in data
|
||||
assert "infos" in data
|
||||
|
||||
res = FaceSwapResponse.parse_obj(response.json())
|
||||
images: List[Image.Image] = res.pil_images
|
||||
assert len(images) == 1
|
||||
image = images[0]
|
||||
orig_image = base64_to_pil(face_swap_request.image)
|
||||
assert image.width == orig_image.width * face_swap_request.postprocessing.scale
|
||||
assert image.height == orig_image.height * face_swap_request.postprocessing.scale
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 99 KiB |
Loading…
Reference in New Issue