mirror of https://github.com/Bing-su/adetailer.git
feat: change unsafe pickling
parent
589412052d
commit
06100063b3
|
|
@ -1,15 +1,17 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from typing import TYPE_CHECKING, Any, Union
|
from typing import TYPE_CHECKING, Any, Union
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
from modules import safe
|
from modules import safe
|
||||||
from modules.shared import opts
|
from modules.shared import cmd_opts, opts
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
# 타입 체커가 빨간 줄을 긋지 않게 하는 편법
|
# 타입 체커가 빨간 줄을 긋지 않게 하는 편법
|
||||||
|
|
@ -36,6 +38,15 @@ def change_torch_load():
|
||||||
torch.load = orig
|
torch.load = orig
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def disable_safe_unpickle():
|
||||||
|
with (
|
||||||
|
patch.dict(os.environ, {"TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD": "1"}, clear=False),
|
||||||
|
patch.object(cmd_opts, "disable_safe_unpickle", True),
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def pause_total_tqdm():
|
def pause_total_tqdm():
|
||||||
orig = opts.data.get("multiple_tqdm", True)
|
orig = opts.data.get("multiple_tqdm", True)
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,8 @@ import modules
|
||||||
from aaaaaa.conditional import create_binary_mask, schedulers
|
from aaaaaa.conditional import create_binary_mask, schedulers
|
||||||
from aaaaaa.helper import (
|
from aaaaaa.helper import (
|
||||||
PPImage,
|
PPImage,
|
||||||
change_torch_load,
|
|
||||||
copy_extra_params,
|
copy_extra_params,
|
||||||
|
disable_safe_unpickle,
|
||||||
pause_total_tqdm,
|
pause_total_tqdm,
|
||||||
preserve_prompts,
|
preserve_prompts,
|
||||||
)
|
)
|
||||||
|
|
@ -825,8 +825,8 @@ class AfterDetailerScript(scripts.Script):
|
||||||
pred = mediapipe_predict(args.ad_model, pp.image, args.ad_confidence)
|
pred = mediapipe_predict(args.ad_model, pp.image, args.ad_confidence)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
with change_torch_load():
|
|
||||||
ad_model = self.get_ad_model(args.ad_model)
|
ad_model = self.get_ad_model(args.ad_model)
|
||||||
|
with disable_safe_unpickle():
|
||||||
pred = ultralytics_predict(
|
pred = ultralytics_predict(
|
||||||
ad_model,
|
ad_model,
|
||||||
image=pp.image,
|
image=pp.image,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue