mirror of https://github.com/Bing-su/adetailer.git
refactor: refactor some functions
parent
e05104a220
commit
f12f66c298
|
|
@ -28,3 +28,7 @@ def get_i(p) -> int:
|
||||||
bs = p.batch_size
|
bs = p.batch_size
|
||||||
i = p.batch_index
|
i = p.batch_index
|
||||||
return it * bs + i
|
return it * bs + i
|
||||||
|
|
||||||
|
|
||||||
|
def is_skip_img2img(p) -> bool:
|
||||||
|
return getattr(p, "_ad_skip_img2img", False)
|
||||||
|
|
|
||||||
|
|
@ -200,6 +200,12 @@ class ADetailerArgs(BaseModel, extra=Extra.forbid):
|
||||||
|
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
def is_mediapipe(self) -> bool:
|
||||||
|
return self.ad_model.lower().startswith("mediapipe")
|
||||||
|
|
||||||
|
def need_skip(self) -> bool:
|
||||||
|
return self.ad_model == "None"
|
||||||
|
|
||||||
|
|
||||||
_all_args = [
|
_all_args = [
|
||||||
("ad_model", "ADetailer model"),
|
("ad_model", "ADetailer model"),
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ from aaaaaa.p_method import (
|
||||||
get_i,
|
get_i,
|
||||||
is_img2img_inpaint,
|
is_img2img_inpaint,
|
||||||
is_inpaint_only_masked,
|
is_inpaint_only_masked,
|
||||||
|
is_skip_img2img,
|
||||||
need_call_postprocess,
|
need_call_postprocess,
|
||||||
need_call_process,
|
need_call_process,
|
||||||
)
|
)
|
||||||
|
|
@ -625,7 +626,7 @@ class AfterDetailerScript(scripts.Script):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_i2i_init_image(p, pp):
|
def get_i2i_init_image(p, pp):
|
||||||
if getattr(p, "_ad_skip_img2img", False):
|
if is_skip_img2img(p):
|
||||||
return p.init_images[0]
|
return p.init_images[0]
|
||||||
return pp.image
|
return pp.image
|
||||||
|
|
||||||
|
|
@ -649,7 +650,7 @@ class AfterDetailerScript(scripts.Script):
|
||||||
mask = ImageChops.invert(mask)
|
mask = ImageChops.invert(mask)
|
||||||
mask = create_binary_mask(mask)
|
mask = create_binary_mask(mask)
|
||||||
|
|
||||||
if getattr(p, "_ad_skip_img2img", False):
|
if is_skip_img2img(p):
|
||||||
if hasattr(p, "init_images") and p.init_images:
|
if hasattr(p, "init_images") and p.init_images:
|
||||||
width, height = p.init_images[0].size
|
width, height = p.init_images[0].size
|
||||||
else:
|
else:
|
||||||
|
|
@ -712,7 +713,7 @@ class AfterDetailerScript(scripts.Script):
|
||||||
seed, subseed = self.get_seed(p)
|
seed, subseed = self.get_seed(p)
|
||||||
ad_prompts, ad_negatives = self.get_prompt(p, args)
|
ad_prompts, ad_negatives = self.get_prompt(p, args)
|
||||||
|
|
||||||
is_mediapipe = args.ad_model.lower().startswith("mediapipe")
|
is_mediapipe = args.is_mediapipe()
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if is_mediapipe:
|
if is_mediapipe:
|
||||||
|
|
@ -800,11 +801,11 @@ class AfterDetailerScript(scripts.Script):
|
||||||
is_processed = False
|
is_processed = False
|
||||||
with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control():
|
with CNHijackRestore(), pause_total_tqdm(), cn_allow_script_control():
|
||||||
for n, args in enumerate(arg_list):
|
for n, args in enumerate(arg_list):
|
||||||
if args.ad_model == "None":
|
if args.need_skip():
|
||||||
continue
|
continue
|
||||||
is_processed |= self._postprocess_image_inner(p, pp, args, n=n)
|
is_processed |= self._postprocess_image_inner(p, pp, args, n=n)
|
||||||
|
|
||||||
if is_processed and not getattr(p, "_ad_skip_img2img", False):
|
if is_processed and not is_skip_img2img(p):
|
||||||
self.save_image(
|
self.save_image(
|
||||||
p, init_image, condition="ad_save_images_before", suffix="-ad-before"
|
p, init_image, condition="ad_save_images_before", suffix="-ad-before"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from adetailer.args import ALL_ARGS, ADetailerArgs
|
from adetailer.args import ALL_ARGS, ADetailerArgs
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -12,3 +14,21 @@ def test_all_args() -> None:
|
||||||
if attr == "is_api":
|
if attr == "is_api":
|
||||||
continue
|
continue
|
||||||
assert attr in ALL_ARGS.attrs, attr
|
assert attr in ALL_ARGS.attrs, attr
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("ad_model", "expect"),
|
||||||
|
[("mediapipe_face_full", True), ("face_yolov8n.pt", False)],
|
||||||
|
)
|
||||||
|
def test_is_mediapipe(ad_model: str, expect: bool) -> None:
|
||||||
|
args = ADetailerArgs(ad_model=ad_model)
|
||||||
|
assert args.is_mediapipe() is expect
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("ad_model", "expect"),
|
||||||
|
[("mediapipe_face_full", False), ("face_yolov8n.pt", False), ("None", True)],
|
||||||
|
)
|
||||||
|
def test_need_skip(ad_model: str, expect: bool) -> None:
|
||||||
|
args = ADetailerArgs(ad_model=ad_model)
|
||||||
|
assert args.need_skip() is expect
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue