add api for face building, add tests
parent
4533750c49
commit
02d88bac91
|
|
@ -187,7 +187,7 @@ class FaceSwapRequest(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
units: List[FaceSwapUnit]
|
units: List[FaceSwapUnit]
|
||||||
postprocessing: Optional[PostProcessingOptions]
|
postprocessing: Optional[PostProcessingOptions] = None
|
||||||
|
|
||||||
|
|
||||||
class FaceSwapResponse(BaseModel):
|
class FaceSwapResponse(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
|
from typing import List
|
||||||
import requests
|
import requests
|
||||||
from api_utils import (
|
from api_utils import (
|
||||||
FaceSwapUnit,
|
FaceSwapUnit,
|
||||||
InswappperOptions,
|
InswappperOptions,
|
||||||
|
base64_to_safetensors,
|
||||||
pil_to_base64,
|
pil_to_base64,
|
||||||
PostProcessingOptions,
|
PostProcessingOptions,
|
||||||
InpaintingWhen,
|
InpaintingWhen,
|
||||||
|
|
@ -98,12 +100,30 @@ for img in response.pil_images:
|
||||||
img.show()
|
img.show()
|
||||||
|
|
||||||
|
|
||||||
|
#############################
|
||||||
|
# Build checkpoint
|
||||||
|
|
||||||
|
source_images: List[str] = [
|
||||||
|
pil_to_base64("../references/man.png"),
|
||||||
|
pil_to_base64("../references/woman.png"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = requests.post(
|
||||||
|
url=f"{address}/faceswaplab/build",
|
||||||
|
json=source_images,
|
||||||
|
headers={"Content-Type": "application/json; charset=utf-8"},
|
||||||
|
)
|
||||||
|
|
||||||
|
base64_to_safetensors(result.json(), output_path="test.safetensors")
|
||||||
|
|
||||||
#############################
|
#############################
|
||||||
# FaceSwap with local safetensors
|
# FaceSwap with local safetensors
|
||||||
|
|
||||||
# First face unit :
|
# First face unit :
|
||||||
unit1 = FaceSwapUnit(
|
unit1 = FaceSwapUnit(
|
||||||
source_face=safetensors_to_base64("test.safetensors"),
|
source_face=safetensors_to_base64(
|
||||||
|
"test.safetensors"
|
||||||
|
), # convert the checkpoint to base64
|
||||||
faces_index=(0,), # Replace first face
|
faces_index=(0,), # Replace first face
|
||||||
swapping_options=InswappperOptions(
|
swapping_options=InswappperOptions(
|
||||||
face_restorer_name="CodeFormer",
|
face_restorer_name="CodeFormer",
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -1,3 +1,4 @@
|
||||||
|
import tempfile
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
@ -17,6 +18,9 @@ from scripts.faceswaplab_postprocessing.postprocessing_options import (
|
||||||
PostProcessingOptions,
|
PostProcessingOptions,
|
||||||
)
|
)
|
||||||
from client_api import api_utils
|
from client_api import api_utils
|
||||||
|
from scripts.faceswaplab_utils.face_checkpoints_utils import (
|
||||||
|
build_face_checkpoint_and_save,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def encode_to_base64(image: Union[str, Image.Image, np.ndarray]) -> str: # type: ignore
|
def encode_to_base64(image: Union[str, Image.Image, np.ndarray]) -> str: # type: ignore
|
||||||
|
|
@ -135,3 +139,23 @@ def faceswaplab_api(_: gr.Blocks, app: FastAPI) -> None:
|
||||||
result_images = [encode_to_base64(img) for img in faces]
|
result_images = [encode_to_base64(img) for img in faces]
|
||||||
response = api_utils.FaceSwapExtractResponse(images=result_images)
|
response = api_utils.FaceSwapExtractResponse(images=result_images)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@app.post(
|
||||||
|
"/faceswaplab/build",
|
||||||
|
tags=["faceswaplab"],
|
||||||
|
description="Build a face checkpoint using base64 images, return base64 satetensors",
|
||||||
|
)
|
||||||
|
async def build(base64_images: List[str]) -> Optional[str]:
|
||||||
|
if len(base64_images) > 0:
|
||||||
|
pil_images = [base64_to_pil(img) for img in base64_images]
|
||||||
|
with tempfile.NamedTemporaryFile(
|
||||||
|
delete=True, suffix=".safetensors"
|
||||||
|
) as temp_file:
|
||||||
|
build_face_checkpoint_and_save(
|
||||||
|
images=pil_images,
|
||||||
|
name="api_ckpt",
|
||||||
|
overwrite=True,
|
||||||
|
path=temp_file.name,
|
||||||
|
)
|
||||||
|
return api_utils.safetensors_to_base64(temp_file.name)
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -468,12 +468,12 @@ def get_or_default(l: List[Any], index: int, default: Any) -> Any:
|
||||||
return l[index] if index < len(l) else default
|
return l[index] if index < len(l) else default
|
||||||
|
|
||||||
|
|
||||||
def get_faces_from_img_files(files: List[str]) -> List[Optional[CV2ImgU8]]:
|
def get_faces_from_img_files(images: List[PILImage]) -> List[Optional[CV2ImgU8]]:
|
||||||
"""
|
"""
|
||||||
Extracts faces from a list of image files.
|
Extracts faces from a list of image files.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
files (list): A list of file objects representing image files.
|
images (list): A list of PILImage objects representing image files.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: A list of detected faces.
|
list: A list of detected faces.
|
||||||
|
|
@ -482,9 +482,8 @@ def get_faces_from_img_files(files: List[str]) -> List[Optional[CV2ImgU8]]:
|
||||||
|
|
||||||
faces = []
|
faces = []
|
||||||
|
|
||||||
if len(files) > 0:
|
if len(images) > 0:
|
||||||
for file in files:
|
for img in images:
|
||||||
img = Image.open(file) # Open the image file
|
|
||||||
face = get_or_default(
|
face = get_or_default(
|
||||||
get_faces(pil_to_cv2(img)), 0, None
|
get_faces(pil_to_cv2(img)), 0, None
|
||||||
) # Extract faces from the image
|
) # Extract faces from the image
|
||||||
|
|
|
||||||
|
|
@ -153,9 +153,9 @@ def build_face_checkpoint_and_save(
|
||||||
if not batch_files:
|
if not batch_files:
|
||||||
logger.error("No face found")
|
logger.error("No face found")
|
||||||
return None
|
return None
|
||||||
filenames = [x.name for x in batch_files]
|
images = [Image.open(file.name) for file in batch_files]
|
||||||
preview_image = face_checkpoints_utils.build_face_checkpoint_and_save(
|
preview_image = face_checkpoints_utils.build_face_checkpoint_and_save(
|
||||||
filenames, name, overwrite=overwrite
|
images, name, overwrite=overwrite
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to build checkpoint %s", e)
|
logger.error("Failed to build checkpoint %s", e)
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ def sanitize_name(name: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
def build_face_checkpoint_and_save(
|
def build_face_checkpoint_and_save(
|
||||||
batch_files: List[str], name: str, overwrite: bool = False
|
images: List[PILImage], name: str, overwrite: bool = False, path: str = None
|
||||||
) -> PILImage:
|
) -> PILImage:
|
||||||
"""
|
"""
|
||||||
Builds a face checkpoint using the provided image files, performs face swapping,
|
Builds a face checkpoint using the provided image files, performs face swapping,
|
||||||
|
|
@ -55,9 +55,9 @@ def build_face_checkpoint_and_save(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
name = sanitize_name(name)
|
name = sanitize_name(name)
|
||||||
batch_files = batch_files or []
|
images = images or []
|
||||||
logger.info("Build %s %s", name, [x for x in batch_files])
|
logger.info("Build %s with %s images", name, len(images))
|
||||||
faces = swapper.get_faces_from_img_files(batch_files)
|
faces = swapper.get_faces_from_img_files(images)
|
||||||
blended_face = swapper.blend_faces(faces)
|
blended_face = swapper.blend_faces(faces)
|
||||||
preview_path = os.path.join(
|
preview_path = os.path.join(
|
||||||
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
|
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
|
||||||
|
|
@ -95,14 +95,17 @@ def build_face_checkpoint_and_save(
|
||||||
)
|
)
|
||||||
preview_image = result.image
|
preview_image = result.image
|
||||||
|
|
||||||
file_path = os.path.join(get_checkpoint_path(), f"{name}.safetensors")
|
if path:
|
||||||
if not overwrite:
|
file_path = path
|
||||||
file_number = 1
|
else:
|
||||||
while os.path.exists(file_path):
|
file_path = os.path.join(get_checkpoint_path(), f"{name}.safetensors")
|
||||||
file_path = os.path.join(
|
if not overwrite:
|
||||||
get_checkpoint_path(), f"{name}_{file_number}.safetensors"
|
file_number = 1
|
||||||
)
|
while os.path.exists(file_path):
|
||||||
file_number += 1
|
file_path = os.path.join(
|
||||||
|
get_checkpoint_path(), f"{name}_{file_number}.safetensors"
|
||||||
|
)
|
||||||
|
file_number += 1
|
||||||
save_face(filename=file_path, face=blended_face)
|
save_face(filename=file_path, face=blended_face)
|
||||||
preview_image.save(file_path + ".png")
|
preview_image.save(file_path + ".png")
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -2,22 +2,28 @@ from typing import List
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import safetensors
|
||||||
|
|
||||||
sys.path.append(".")
|
sys.path.append(".")
|
||||||
|
|
||||||
|
import requests
|
||||||
from client_api.api_utils import (
|
from client_api.api_utils import (
|
||||||
FaceSwapUnit,
|
FaceSwapUnit,
|
||||||
FaceSwapResponse,
|
InswappperOptions,
|
||||||
PostProcessingOptions,
|
|
||||||
FaceSwapRequest,
|
|
||||||
base64_to_pil,
|
|
||||||
pil_to_base64,
|
pil_to_base64,
|
||||||
|
PostProcessingOptions,
|
||||||
InpaintingWhen,
|
InpaintingWhen,
|
||||||
FaceSwapCompareRequest,
|
InpaintingOptions,
|
||||||
|
FaceSwapRequest,
|
||||||
|
FaceSwapResponse,
|
||||||
FaceSwapExtractRequest,
|
FaceSwapExtractRequest,
|
||||||
|
FaceSwapCompareRequest,
|
||||||
FaceSwapExtractResponse,
|
FaceSwapExtractResponse,
|
||||||
compare_faces,
|
compare_faces,
|
||||||
InpaintingOptions,
|
base64_to_pil,
|
||||||
|
base64_to_safetensors,
|
||||||
|
safetensors_to_base64,
|
||||||
)
|
)
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
@ -37,6 +43,13 @@ def face_swap_request() -> FaceSwapRequest:
|
||||||
source_img=pil_to_base64("references/woman.png"), # The face you want to use
|
source_img=pil_to_base64("references/woman.png"), # The face you want to use
|
||||||
same_gender=True,
|
same_gender=True,
|
||||||
faces_index=(0,), # Replace first woman since same gender is on
|
faces_index=(0,), # Replace first woman since same gender is on
|
||||||
|
swapping_options=InswappperOptions(
|
||||||
|
face_restorer_name="CodeFormer",
|
||||||
|
upscaler_name="LDSR",
|
||||||
|
improved_mask=True,
|
||||||
|
sharpen=True,
|
||||||
|
color_corrections=True,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Post-processing config
|
# Post-processing config
|
||||||
|
|
@ -179,3 +192,86 @@ def test_faceswap_inpainting(face_swap_request: FaceSwapRequest) -> None:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "images" in data
|
assert "images" in data
|
||||||
assert "infos" in data
|
assert "infos" in data
|
||||||
|
|
||||||
|
|
||||||
|
def test_faceswap_checkpoint_building() -> None:
|
||||||
|
source_images: List[str] = [
|
||||||
|
pil_to_base64("references/man.png"),
|
||||||
|
pil_to_base64("references/woman.png"),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
url=f"{base_url}/faceswaplab/build",
|
||||||
|
json=source_images,
|
||||||
|
headers={"Content-Type": "application/json; charset=utf-8"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(delete=True) as temp_file:
|
||||||
|
base64_to_safetensors(response.json(), output_path=temp_file.name)
|
||||||
|
with safetensors.safe_open(temp_file.name, framework="pt") as f:
|
||||||
|
assert "age" in f.keys()
|
||||||
|
assert "gender" in f.keys()
|
||||||
|
assert "embedding" in f.keys()
|
||||||
|
|
||||||
|
|
||||||
|
def test_faceswap_checkpoint_building_and_using() -> None:
|
||||||
|
source_images: List[str] = [
|
||||||
|
pil_to_base64("references/man.png"),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
url=f"{base_url}/faceswaplab/build",
|
||||||
|
json=source_images,
|
||||||
|
headers={"Content-Type": "application/json; charset=utf-8"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(delete=True) as temp_file:
|
||||||
|
base64_to_safetensors(response.json(), output_path=temp_file.name)
|
||||||
|
with safetensors.safe_open(temp_file.name, framework="pt") as f:
|
||||||
|
assert "age" in f.keys()
|
||||||
|
assert "gender" in f.keys()
|
||||||
|
assert "embedding" in f.keys()
|
||||||
|
|
||||||
|
# First face unit :
|
||||||
|
unit1 = FaceSwapUnit(
|
||||||
|
source_face=safetensors_to_base64(
|
||||||
|
temp_file.name
|
||||||
|
), # convert the checkpoint to base64
|
||||||
|
faces_index=(0,), # Replace first face
|
||||||
|
swapping_options=InswappperOptions(
|
||||||
|
face_restorer_name="CodeFormer",
|
||||||
|
upscaler_name="LDSR",
|
||||||
|
improved_mask=True,
|
||||||
|
sharpen=True,
|
||||||
|
color_corrections=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare the request
|
||||||
|
request = FaceSwapRequest(
|
||||||
|
image=pil_to_base64("tests/test_image.png"), units=[unit1]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Face Swap
|
||||||
|
response = requests.post(
|
||||||
|
url=f"{base_url}/faceswaplab/swap_face",
|
||||||
|
data=request.json(),
|
||||||
|
headers={"Content-Type": "application/json; charset=utf-8"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
fsr = FaceSwapResponse.parse_obj(response.json())
|
||||||
|
data = response.json()
|
||||||
|
assert "images" in data
|
||||||
|
assert "infos" in data
|
||||||
|
|
||||||
|
# First face is the man
|
||||||
|
assert (
|
||||||
|
compare_faces(
|
||||||
|
fsr.pil_images[0], Image.open("references/man.png"), base_url=base_url
|
||||||
|
)
|
||||||
|
> 0.5
|
||||||
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue