refactor: refactor some functions

pull/601/head^2
Dowon 2024-05-15 22:13:06 +09:00
parent e05104a220
commit f12f66c298
4 changed files with 36 additions and 5 deletions

View File

@ -28,3 +28,7 @@ def get_i(p) -> int:
bs = p.batch_size bs = p.batch_size
i = p.batch_index i = p.batch_index
return it * bs + i return it * bs + i
def is_skip_img2img(p) -> bool:
return getattr(p, "_ad_skip_img2img", False)

View File

@ -200,6 +200,12 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
return p return p
def is_mediapipe(self) -> bool:
return self.ad_model.lower().startswith("mediapipe")
def need_skip(self) -> bool:
return self.ad_model == "None"
_all_args = [ _all_args = [
("ad_model", "ADetailer model"), ("ad_model", "ADetailer model"),

View File

@ -26,6 +26,7 @@ from aaaaaa.p_method import (
get_i, get_i,
is_img2img_inpaint, is_img2img_inpaint,
is_inpaint_only_masked, is_inpaint_only_masked,
is_skip_img2img,
need_call_postprocess, need_call_postprocess,
need_call_process, need_call_process,
) )
@ -625,7 +626,7 @@ class AfterDetailerScript(scripts.Script):
@staticmethod @staticmethod
def get_i2i_init_image(p, pp): def get_i2i_init_image(p, pp):
if getattr(p, "_ad_skip_img2img", False): if is_skip_img2img(p):
return p.init_images[0] return p.init_images[0]
return pp.image return pp.image
@ -649,7 +650,7 @@ class AfterDetailerScript(scripts.Script):
mask = ImageChops.invert(mask) mask = ImageChops.invert(mask)
mask = create_binary_mask(mask) mask = create_binary_mask(mask)
if getattr(p, "_ad_skip_img2img", False): if is_skip_img2img(p):
if hasattr(p, "init_images") and p.init_images: if hasattr(p, "init_images") and p.init_images:
width, height = p.init_images[0].size width, height = p.init_images[0].size
else: else:
@ -712,7 +713,7 @@ class AfterDetailerScript(scripts.Script):
seed, subseed = self.get_seed(p) seed, subseed = self.get_seed(p)
ad_prompts, ad_negatives = self.get_prompt(p, args) ad_prompts, ad_negatives = self.get_prompt(p, args)
is_mediapipe = args.ad_model.lower().startswith("mediapipe") is_mediapipe = args.is_mediapipe()
kwargs = {} kwargs = {}
if is_mediapipe: if is_mediapipe:
@ -800,11 +801,11 @@ class AfterDetailerScript(scripts.Script):
is_processed = False is_processed = False
with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control(): with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control():
for n, args in enumerate(arg_list): for n, args in enumerate(arg_list):
if args.ad_model == "None": if args.need_skip():
continue continue
is_processed |= self._postprocess_image_inner(p, pp, args, n=n) is_processed |= self._postprocess_image_inner(p, pp, args, n=n)
if is_processed and not getattr(p, "_ad_skip_img2img", False): if is_processed and not is_skip_img2img(p):
self.save_image( self.save_image(
p, init_image, condition="ad_save_images_before", suffix="-ad-before" p, init_image, condition="ad_save_images_before", suffix="-ad-before"
) )

View File

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import pytest
from adetailer.args import ALL_ARGS, ADetailerArgs from adetailer.args import ALL_ARGS, ADetailerArgs
@ -12,3 +14,21 @@ def test_all_args() -> None:
if attr == "is_api": if attr == "is_api":
continue continue
assert attr in ALL_ARGS.attrs, attr assert attr in ALL_ARGS.attrs, attr
@pytest.mark.parametrize(
("ad_model", "expect"),
[("mediapipe_face_full", True), ("face_yolov8n.pt", False)],
)
def test_is_mediapipe(ad_model: str, expect: bool) -> None:
args = ADetailerArgs(ad_model=ad_model)
assert args.is_mediapipe() is expect
@pytest.mark.parametrize(
("ad_model", "expect"),
[("mediapipe_face_full", False), ("face_yolov8n.pt", False), ("None", True)],
)
def test_need_skip(ad_model: str, expect: bool) -> None:
args = ADetailerArgs(ad_model=ad_model)
assert args.need_skip() is expect