fix(scripts): move contextmanagers

pull/595/head
Dowon 2024-04-13 16:04:40 +09:00
parent 89ee330271
commit 7d7dfb76a5
4 changed files with 59 additions and 35 deletions

0
aaaaaa/__init__.py Normal file
View File

55
aaaaaa/helper.py Normal file
View File

@ -0,0 +1,55 @@
from __future__ import annotations
from contextlib import contextmanager
from copy import copy
from typing import TYPE_CHECKING
import torch
from modules import safe
from modules.shared import opts
if TYPE_CHECKING:
# 타입 체커가 빨간 줄을 긋지 않게 하는 편법
from types import SimpleNamespace
StableDiffusionProcessingTxt2Img = SimpleNamespace
StableDiffusionProcessingImg2Img = SimpleNamespace
else:
from modules.processing import (
StableDiffusionProcessingImg2Img,
StableDiffusionProcessingTxt2Img,
)
PT = StableDiffusionProcessingTxt2Img | StableDiffusionProcessingImg2Img
@contextmanager
def change_torch_load():
orig = torch.load
try:
torch.load = safe.unsafe_torch_load
yield
finally:
torch.load = orig
@contextmanager
def pause_total_tqdm():
orig = opts.data.get("multiple_tqdm", True)
try:
opts.data["multiple_tqdm"] = False
yield
finally:
opts.data["multiple_tqdm"] = orig
@contextmanager
def preseve_prompts(p: PT):
all_pt = copy(p.all_prompts)
all_ng = copy(p.all_negative_prompts)
try:
yield
finally:
p.all_prompts = all_pt
p.all_negative_prompts = all_ng

View File

@ -1 +1 @@
__version__ = "24.4.0" __version__ = "24.4.1-dev.0"

View File

@ -4,7 +4,7 @@ import platform
import re import re
import sys import sys
import traceback import traceback
from contextlib import contextmanager, suppress from contextlib import suppress
from copy import copy from copy import copy
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
@ -12,11 +12,11 @@ from textwrap import dedent
from typing import TYPE_CHECKING, Any, NamedTuple, cast from typing import TYPE_CHECKING, Any, NamedTuple, cast
import gradio as gr import gradio as gr
import torch
from PIL import Image, ImageChops from PIL import Image, ImageChops
from rich import print from rich import print
import modules import modules
from aaaaaa.helper import change_torch_load, pause_total_tqdm, preseve_prompts
from adetailer import ( from adetailer import (
AFTER_DETAILER, AFTER_DETAILER,
__version__, __version__,
@ -44,7 +44,7 @@ from controlnet_ext import (
controlnet_type, controlnet_type,
get_cn_models, get_cn_models,
) )
from modules import images, paths, safe, script_callbacks, scripts, shared from modules import images, paths, script_callbacks, scripts, shared
from modules.devices import NansException from modules.devices import NansException
from modules.processing import ( from modules.processing import (
Processed, Processed,
@ -86,37 +86,6 @@ print(
) )
@contextmanager
def change_torch_load():
orig = torch.load
try:
torch.load = safe.unsafe_torch_load
yield
finally:
torch.load = orig
@contextmanager
def pause_total_tqdm():
orig = opts.data.get("multiple_tqdm", True)
try:
opts.data["multiple_tqdm"] = False
yield
finally:
opts.data["multiple_tqdm"] = orig
@contextmanager
def preseve_prompts(p):
all_pt = copy(p.all_prompts)
all_ng = copy(p.all_negative_prompts)
try:
yield
finally:
p.all_prompts = all_pt
p.all_negative_prompts = all_ng
class AfterDetailerScript(scripts.Script): class AfterDetailerScript(scripts.Script):
def __init__(self): def __init__(self):
super().__init__() super().__init__()