chore: lint existing code (#47)
* chore: add flake8 and black for linting code * chore: lint all files * chore: add lint workflowpull/13/head
parent
e790ac2e3c
commit
72e28ca8aa
|
|
@ -0,0 +1,3 @@
|
|||
[flake8]
|
||||
max-line-length = 88
|
||||
exclude = .git,__pycache__,old,build,dist
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
name: Lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.10
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
flake8 . --count --show-source --statistics
|
||||
|
||||
- name: Lint with black
|
||||
run: |
|
||||
black --check .
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
import launch
|
||||
|
||||
if not launch.is_installed("diffusers"):
|
||||
launch.run_pip(f"install diffusers", "diffusers") # NSFW filter
|
||||
launch.run_pip(f"install aiohttp", "aiohttp") # asynchroneous HTTP requests
|
||||
launch.run_pip("install diffusers", "diffusers") # NSFW filter
|
||||
launch.run_pip("install aiohttp", "aiohttp") # asynchroneous HTTP requests
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
# Formatters and linters
|
||||
black ~= 22.0
|
||||
flake8
|
||||
|
|
@ -15,22 +15,52 @@ def on_app_started(demo: Optional[gr.Blocks], app: FastAPI):
|
|||
horde = StableHorde(config)
|
||||
|
||||
import gradio.utils
|
||||
|
||||
gradio.utils.synchronize_async(horde.run)
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
section = ('stable-horde', 'Stable Horde')
|
||||
shared.opts.add_option('stable_horde_enable', shared.OptionInfo(False, 'Enable', section=section))
|
||||
shared.opts.add_option('stable_horde_endpoint', shared.OptionInfo('https://stablehorde.net/', 'Endpoint', section=section))
|
||||
shared.opts.add_option('stable_horde_apikey', shared.OptionInfo('', 'API Key', section=section))
|
||||
shared.opts.add_option('stable_horde_name', shared.OptionInfo('Stable Horde', 'Worker Name', section=section))
|
||||
shared.opts.add_option('stable_horde_model', shared.OptionInfo('Anything Diffusion', 'Model', section=section))
|
||||
shared.opts.add_option('stable_horde_nsfw', shared.OptionInfo(False, 'NSFW', section=section))
|
||||
shared.opts.add_option('stable_horde_interval', shared.OptionInfo(10, 'Interval', section=section))
|
||||
shared.opts.add_option('stable_horde_max_pixels', shared.OptionInfo(1024 * 1024, 'Max Pixels', section=section))
|
||||
shared.opts.add_option('stable_horde_allow_img2img', shared.OptionInfo(True, 'Allow img2img', section=section))
|
||||
shared.opts.add_option('stable_horde_allow_painting', shared.OptionInfo(True, 'Allow Painting', section=section))
|
||||
shared.opts.add_option('stable_horde_allow_unsafe_ipaddr', shared.OptionInfo(True, 'Allow Unsafe IP Address', section=section))
|
||||
section = ("stable-horde", "Stable Horde")
|
||||
shared.opts.add_option(
|
||||
"stable_horde_enable", shared.OptionInfo(False, "Enable", section=section)
|
||||
)
|
||||
shared.opts.add_option(
|
||||
"stable_horde_endpoint",
|
||||
shared.OptionInfo("https://stablehorde.net/", "Endpoint", section=section),
|
||||
)
|
||||
shared.opts.add_option(
|
||||
"stable_horde_apikey", shared.OptionInfo("", "API Key", section=section)
|
||||
)
|
||||
shared.opts.add_option(
|
||||
"stable_horde_name",
|
||||
shared.OptionInfo("Stable Horde", "Worker Name", section=section),
|
||||
)
|
||||
shared.opts.add_option(
|
||||
"stable_horde_model",
|
||||
shared.OptionInfo("Anything Diffusion", "Model", section=section),
|
||||
)
|
||||
shared.opts.add_option(
|
||||
"stable_horde_nsfw", shared.OptionInfo(False, "NSFW", section=section)
|
||||
)
|
||||
shared.opts.add_option(
|
||||
"stable_horde_interval", shared.OptionInfo(10, "Interval", section=section)
|
||||
)
|
||||
shared.opts.add_option(
|
||||
"stable_horde_max_pixels",
|
||||
shared.OptionInfo(1024 * 1024, "Max Pixels", section=section),
|
||||
)
|
||||
shared.opts.add_option(
|
||||
"stable_horde_allow_img2img",
|
||||
shared.OptionInfo(True, "Allow img2img", section=section),
|
||||
)
|
||||
shared.opts.add_option(
|
||||
"stable_horde_allow_painting",
|
||||
shared.OptionInfo(True, "Allow Painting", section=section),
|
||||
)
|
||||
shared.opts.add_option(
|
||||
"stable_horde_allow_unsafe_ipaddr",
|
||||
shared.OptionInfo(True, "Allow Unsafe IP Address", section=section),
|
||||
)
|
||||
|
||||
|
||||
script_callbacks.on_app_started(on_app_started)
|
||||
|
|
|
|||
255
stable_horde.py
255
stable_horde.py
|
|
@ -8,13 +8,25 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from PIL import Image
|
||||
from transformers.models.auto.feature_extraction_auto import AutoFeatureExtractor
|
||||
|
||||
from modules import shared, call_queue, txt2img, img2img, processing, sd_models, sd_samplers, scripts
|
||||
from modules import (
|
||||
shared,
|
||||
call_queue,
|
||||
txt2img,
|
||||
img2img,
|
||||
processing,
|
||||
sd_models,
|
||||
sd_samplers,
|
||||
)
|
||||
|
||||
stable_horde_supported_models_url = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db.json"
|
||||
stable_horde_supported_models_url = (
|
||||
"https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db.json"
|
||||
)
|
||||
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_feature_extractor = None
|
||||
|
|
@ -29,7 +41,7 @@ class StableHordeConfig:
|
|||
@property
|
||||
def endpoint(self) -> str:
|
||||
return shared.opts.stable_horde_endpoint
|
||||
|
||||
|
||||
@property
|
||||
def apikey(self) -> str:
|
||||
return shared.opts.stable_horde_apikey
|
||||
|
|
@ -49,11 +61,11 @@ class StableHordeConfig:
|
|||
@property
|
||||
def allow_img2img(self) -> bool:
|
||||
return shared.opts.stable_horde_allow_img2img
|
||||
|
||||
|
||||
@property
|
||||
def allow_painting(self) -> bool:
|
||||
return shared.opts.stable_horde_allow_painting
|
||||
|
||||
|
||||
@property
|
||||
def allow_unsafe_ipaddr(self) -> bool:
|
||||
return shared.opts.stable_horde_allow_unsafe_ipaddr
|
||||
|
|
@ -62,7 +74,31 @@ class StableHordeConfig:
|
|||
class HordeJob:
|
||||
retry_interval: int = 1
|
||||
|
||||
def __init__(self, session: aiohttp.ClientSession, id: str, model: str, prompt: str, negative_prompt: str, sampler: str, cfg_scale: float, seed: int, denoising_strength: float, n_iter: int, height: int, width: int, subseed: int, steps: int, karras: bool, tiling: bool, postprocessors: List[str], nsfw_censor: bool = False, source_image: Optional[Image.Image] = None, source_processing: Optional[str] = "img2img", source_mask: Optional[Image.Image] = None, r2_upload: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
session: aiohttp.ClientSession,
|
||||
id: str,
|
||||
model: str,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
sampler: str,
|
||||
cfg_scale: float,
|
||||
seed: int,
|
||||
denoising_strength: float,
|
||||
n_iter: int,
|
||||
height: int,
|
||||
width: int,
|
||||
subseed: int,
|
||||
steps: int,
|
||||
karras: bool,
|
||||
tiling: bool,
|
||||
postprocessors: List[str],
|
||||
nsfw_censor: bool = False,
|
||||
source_image: Optional[Image.Image] = None,
|
||||
source_processing: Optional[str] = "img2img",
|
||||
source_mask: Optional[Image.Image] = None,
|
||||
r2_upload: Optional[str] = None,
|
||||
):
|
||||
self.id = id
|
||||
self.model = model
|
||||
self.prompt = prompt
|
||||
|
|
@ -82,7 +118,9 @@ class HordeJob:
|
|||
self.postprocessors = postprocessors
|
||||
self.nsfw_censor = nsfw_censor
|
||||
self.source_image = source_image
|
||||
self.source_processing = source_processing # "img2img", "inpainting", "outpainting"
|
||||
self.source_processing = (
|
||||
source_processing # "img2img", "inpainting", "outpainting"
|
||||
)
|
||||
self.source_mask = source_mask
|
||||
self.r2_upload = r2_upload
|
||||
|
||||
|
|
@ -115,7 +153,7 @@ class HordeJob:
|
|||
attempts = 10
|
||||
while attempts > 0:
|
||||
try:
|
||||
r = await session.post('/api/v2/generate/submit', json=post_data)
|
||||
r = await session.post("/api/v2/generate/submit", json=post_data)
|
||||
|
||||
try:
|
||||
res = await r.json()
|
||||
|
|
@ -127,20 +165,30 @@ class HordeJob:
|
|||
if r.ok:
|
||||
return res.get("reward", None)
|
||||
else:
|
||||
print(f"Failed to submit job with status code {r.status}: {res.get('message')}")
|
||||
print(
|
||||
"Failed to submit job with status code"
|
||||
+ f"{r.status}: {res.get('message')}"
|
||||
)
|
||||
return None
|
||||
except Exception:
|
||||
print(f"Error when decoding response, the server might be down.")
|
||||
print("Error when decoding response, the server might be down.")
|
||||
return None
|
||||
|
||||
|
||||
except aiohttp.ClientConnectorError:
|
||||
attempts -= 1
|
||||
await asyncio.sleep(self.retry_interval)
|
||||
continue
|
||||
|
||||
|
||||
@classmethod
|
||||
async def get(cls, session: aiohttp.ClientSession, config: StableHordeConfig, models: List[str]):
|
||||
async def get(
|
||||
cls,
|
||||
session: aiohttp.ClientSession,
|
||||
config: StableHordeConfig,
|
||||
models: List[str],
|
||||
):
|
||||
name = "Stable Horde Worker Bridge for Stable Diffusion WebUI"
|
||||
version = 10
|
||||
repo = "https://github.com/sdwebui-w-horde/sd-webui-stable-horde-worker"
|
||||
# https://stablehorde.net/api/
|
||||
post_data = {
|
||||
"name": config.name,
|
||||
|
|
@ -150,7 +198,7 @@ class HordeJob:
|
|||
"models": models,
|
||||
# TODO: add support for bridge version 11 "tiling"
|
||||
"bridge_version": 9,
|
||||
"bridge_agent": "Stable Horde Worker Bridge for Stable Diffusion WebUI:10:https://github.com/sdwebui-w-horde/sd-webui-stable-horde-worker",
|
||||
"bridge_agent": f"{name}:{version}:{repo}",
|
||||
"threads": 1,
|
||||
"max_pixels": config.max_pixels,
|
||||
"allow_img2img": config.allow_img2img,
|
||||
|
|
@ -158,77 +206,62 @@ class HordeJob:
|
|||
"allow_unsafe_ipaddr": config.allow_unsafe_ipaddr,
|
||||
}
|
||||
|
||||
r = await session.post('/api/v2/generate/pop', json=post_data)
|
||||
r = await session.post("/api/v2/generate/pop", json=post_data)
|
||||
|
||||
req = await r.json()
|
||||
|
||||
if r.status != 200:
|
||||
raise Exception(f"Failed to get job: {req.get('message')}")
|
||||
|
||||
if not req.get('id'):
|
||||
if not req.get("id"):
|
||||
return
|
||||
|
||||
payload = req.get('payload')
|
||||
prompt = payload.get('prompt')
|
||||
|
||||
payload = req.get("payload")
|
||||
prompt = payload.get("prompt")
|
||||
if "###" in prompt:
|
||||
prompt, negative = map(lambda x: x.strip(), prompt.split("###"))
|
||||
else:
|
||||
negative = ""
|
||||
|
||||
|
||||
def to_image(base64str: Optional[str]) -> Optional[Image.Image]:
|
||||
if not base64str:
|
||||
return None
|
||||
return Image.open(io.BytesIO(base64.b64decode(base64str)))
|
||||
|
||||
|
||||
return cls(
|
||||
session=session,
|
||||
id=req['id'],
|
||||
id=req["id"],
|
||||
prompt=prompt,
|
||||
negative_prompt=negative,
|
||||
sampler=payload.get('sampler_name'),
|
||||
cfg_scale=payload.get('cfg_scale', 5),
|
||||
seed=int(payload.get('seed', randint(0, 2**32))),
|
||||
denoising_strength=payload.get('denoising_strength', 0.75),
|
||||
n_iter=payload.get('n_iter', 1),
|
||||
height=payload['height'],
|
||||
width=payload['width'],
|
||||
subseed=payload.get('seed_variation', 1),
|
||||
steps=payload.get('ddim_steps', 30),
|
||||
karras=payload.get('karras', False),
|
||||
tiling=payload.get('tiling', False),
|
||||
postprocessors=payload.get('post_processing', []),
|
||||
nsfw_censor=payload.get('use_nsfw_censor', False),
|
||||
model=req['model'],
|
||||
source_image=to_image(payload.get('source_image')),
|
||||
source_processing=payload.get('source_processing'),
|
||||
source_mask=to_image(payload.get('source_mask')),
|
||||
r2_upload=payload.get('r2_upload'),
|
||||
sampler=payload.get("sampler_name"),
|
||||
cfg_scale=payload.get("cfg_scale", 5),
|
||||
seed=int(payload.get("seed", randint(0, 2**32))),
|
||||
denoising_strength=payload.get("denoising_strength", 0.75),
|
||||
n_iter=payload.get("n_iter", 1),
|
||||
height=payload["height"],
|
||||
width=payload["width"],
|
||||
subseed=payload.get("seed_variation", 1),
|
||||
steps=payload.get("ddim_steps", 30),
|
||||
karras=payload.get("karras", False),
|
||||
tiling=payload.get("tiling", False),
|
||||
postprocessors=payload.get("post_processing", []),
|
||||
nsfw_censor=payload.get("use_nsfw_censor", False),
|
||||
model=req["model"],
|
||||
source_image=to_image(payload.get("source_image")),
|
||||
source_processing=payload.get("source_processing"),
|
||||
source_mask=to_image(payload.get("source_mask")),
|
||||
r2_upload=payload.get("r2_upload"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@property
|
||||
def allow_painting(self) -> bool:
|
||||
return shared.opts.stable_horde_allow_painting
|
||||
|
||||
@property
|
||||
def allow_unsafe_ipaddr(self) -> bool:
|
||||
return shared.opts.stable_horde_allow_unsafe_ipaddr
|
||||
|
||||
|
||||
class StableHorde:
|
||||
def __init__(self, config: StableHordeConfig):
|
||||
self.config = config
|
||||
headers = {
|
||||
"apikey": self.config.apikey,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
self.session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
self.sfw_request_censor = Image.open(path.join(self.config.basedir, "assets", "nsfw_censor_sfw_request.png"))
|
||||
self.sfw_request_censor = Image.open(
|
||||
path.join(self.config.basedir, "assets", "nsfw_censor_sfw_request.png")
|
||||
)
|
||||
|
||||
self.supported_models = []
|
||||
|
||||
|
|
@ -237,9 +270,9 @@ class StableHorde:
|
|||
if not path.exists(filepath):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(stable_horde_supported_models_url) as resp:
|
||||
with open(filepath, 'wb') as f:
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
with open(filepath, 'r') as f:
|
||||
with open(filepath, "r") as f:
|
||||
supported_models: Dict[str, Any] = json.load(f)
|
||||
|
||||
self.supported_models = list(supported_models.values())
|
||||
|
|
@ -247,6 +280,7 @@ class StableHorde:
|
|||
def detect_current_model(self):
|
||||
def get_md5sum(filepath):
|
||||
import hashlib
|
||||
|
||||
with open(filepath, "rb") as f:
|
||||
return hashlib.md5(f.read()).hexdigest()
|
||||
|
||||
|
|
@ -268,7 +302,6 @@ class StableHorde:
|
|||
if len(self.config.models) == 0:
|
||||
return f"Current model {model_checkpoint} not found on StableHorde"
|
||||
|
||||
|
||||
async def run(self):
|
||||
await self.get_supported_models()
|
||||
|
||||
|
|
@ -276,8 +309,8 @@ class StableHorde:
|
|||
result = self.detect_current_model()
|
||||
if result is not None:
|
||||
# Wait 10 seconds before retrying to detect the current model
|
||||
# if the current model is not listed in the Stable Horde supported models,
|
||||
# we don't want to spam the server with requests
|
||||
# if the current model is not listed in the Stable Horde supported
|
||||
# models, we don't want to spam the server with requests
|
||||
await asyncio.sleep(10)
|
||||
continue
|
||||
|
||||
|
|
@ -285,13 +318,16 @@ class StableHorde:
|
|||
|
||||
if shared.opts.stable_horde_enable:
|
||||
try:
|
||||
req = await HordeJob.get(await self.get_session(), self.config, self.config.models)
|
||||
req = await HordeJob.get(
|
||||
await self.get_session(), self.config, self.config.models
|
||||
)
|
||||
if req is None:
|
||||
continue
|
||||
|
||||
await self.handle_request(req)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
def patch_sampler_names(self):
|
||||
|
|
@ -299,17 +335,53 @@ class StableHorde:
|
|||
but are not included in the default sd_samplers module.
|
||||
"""
|
||||
from modules import sd_samplers
|
||||
from modules.sd_samplers import KDiffusionSampler, SamplerData
|
||||
|
||||
if sd_samplers.samplers_map.get('euler a karras'):
|
||||
if sd_samplers.samplers_map.get("euler a karras"):
|
||||
# already patched
|
||||
return
|
||||
|
||||
samplers = [
|
||||
sd_samplers.SamplerData("Euler a Karras", lambda model, funcname="sample_euler_ancestral": sd_samplers.KDiffusionSampler(funcname, model), ['k_euler_a_ka'], {'scheduler': 'karras'}),
|
||||
sd_samplers.SamplerData("Euler Karras", lambda model, funcname="sample_euler": sd_samplers.KDiffusionSampler(funcname, model), ['k_euler_ka'], {'scheduler': 'karras'}),
|
||||
sd_samplers.SamplerData("Heun Karras", lambda model, funcname="sample_heun": sd_samplers.KDiffusionSampler(funcname, model), ['k_heun_ka'], {'scheduler': 'karras'}),
|
||||
sd_samplers.SamplerData('DPM adaptive Karras', lambda model, funcname='sample_dpm_adaptive': sd_samplers.KDiffusionSampler(funcname, model), ['k_dpm_ad_ka'], {'scheduler': 'karras'}),
|
||||
sd_samplers.SamplerData('DPM fast Karras', lambda model, funcname='sample_dpm_fast': sd_samplers.KDiffusionSampler(funcname, model), ['k_dpm_fast_ka'], {'scheduler': 'karras'}),
|
||||
SamplerData(
|
||||
"Euler a Karras",
|
||||
lambda model, funcname="sample_euler_ancestral": KDiffusionSampler(
|
||||
funcname, model
|
||||
),
|
||||
["k_euler_a_ka"],
|
||||
{"scheduler": "karras"},
|
||||
),
|
||||
SamplerData(
|
||||
"Euler Karras",
|
||||
lambda model, funcname="sample_euler": KDiffusionSampler(
|
||||
funcname, model
|
||||
),
|
||||
["k_euler_ka"],
|
||||
{"scheduler": "karras"},
|
||||
),
|
||||
SamplerData(
|
||||
"Heun Karras",
|
||||
lambda model, funcname="sample_heun": KDiffusionSampler(
|
||||
funcname, model
|
||||
),
|
||||
["k_heun_ka"],
|
||||
{"scheduler": "karras"},
|
||||
),
|
||||
SamplerData(
|
||||
"DPM adaptive Karras",
|
||||
lambda model, funcname="sample_dpm_adaptive": KDiffusionSampler(
|
||||
funcname, model
|
||||
),
|
||||
["k_dpm_ad_ka"],
|
||||
{"scheduler": "karras"},
|
||||
),
|
||||
SamplerData(
|
||||
"DPM fast Karras",
|
||||
lambda model, funcname="sample_dpm_fast": KDiffusionSampler(
|
||||
funcname, model
|
||||
),
|
||||
["k_dpm_fast_ka"],
|
||||
{"scheduler": "karras"},
|
||||
),
|
||||
]
|
||||
sd_samplers.samplers.extend(samplers)
|
||||
sd_samplers.samplers_for_img2img.extend(samplers)
|
||||
|
|
@ -324,13 +396,13 @@ class StableHorde:
|
|||
|
||||
print(f"Get popped generation request {job.id}")
|
||||
sampler_name = job.sampler
|
||||
if sampler_name == 'k_dpm_adaptive':
|
||||
sampler_name = 'k_dpm_ad'
|
||||
if sampler_name == "k_dpm_adaptive":
|
||||
sampler_name = "k_dpm_ad"
|
||||
if sampler_name not in sd_samplers.samplers_map:
|
||||
print(f"ERROR: Unknown sampler {sampler_name}")
|
||||
return
|
||||
if job.karras:
|
||||
sampler_name += '_ka'
|
||||
sampler_name += "_ka"
|
||||
|
||||
sampler = sd_samplers.samplers_map.get(sampler_name, None)
|
||||
if sampler is None:
|
||||
|
|
@ -392,17 +464,27 @@ class StableHorde:
|
|||
|
||||
if "RealESRGAN_x4plus" in postprocessors and not has_nsfw:
|
||||
from modules.postprocessing import run_extras
|
||||
|
||||
with call_queue.queue_lock:
|
||||
images, _info, _wtf = run_extras(
|
||||
image=image, extras_mode=0, resize_mode=0,
|
||||
show_extras_results=True, upscaling_resize=2,
|
||||
upscaling_resize_h=None, upscaling_resize_w=None,
|
||||
upscaling_crop=False, upscale_first=False,
|
||||
extras_upscaler_1="R-ESRGAN 4x+", # 8 - RealESRGAN_x4plus
|
||||
image=image,
|
||||
extras_mode=0,
|
||||
resize_mode=0,
|
||||
show_extras_results=True,
|
||||
upscaling_resize=2,
|
||||
upscaling_resize_h=None,
|
||||
upscaling_resize_w=None,
|
||||
upscaling_crop=False,
|
||||
upscale_first=False,
|
||||
extras_upscaler_1="R-ESRGAN 4x+", # 8 - RealESRGAN_x4plus
|
||||
extras_upscaler_2=None,
|
||||
extras_upscaler_2_visibility=0.0,
|
||||
gfpgan_visibility=0.0, codeformer_visibility=0.0, codeformer_weight=0.0,
|
||||
image_folder="", input_dir="", output_dir="",
|
||||
gfpgan_visibility=0.0,
|
||||
codeformer_visibility=0.0,
|
||||
codeformer_weight=0.0,
|
||||
image_folder="",
|
||||
input_dir="",
|
||||
output_dir="",
|
||||
save_output=False,
|
||||
)
|
||||
|
||||
|
|
@ -417,17 +499,22 @@ class StableHorde:
|
|||
global safety_feature_extractor, safety_checker
|
||||
|
||||
if safety_feature_extractor is None:
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
safety_model_id
|
||||
)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
safety_model_id
|
||||
)
|
||||
|
||||
safety_checker_input = safety_feature_extractor(x_image, return_tensors="pt")
|
||||
image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||
image, has_nsfw_concept = safety_checker(
|
||||
images=x_image, clip_input=safety_checker_input.pixel_values
|
||||
)
|
||||
|
||||
if has_nsfw_concept:
|
||||
return self.sfw_request_censor, has_nsfw_concept
|
||||
return Image.fromarray(image), has_nsfw_concept
|
||||
|
||||
|
||||
async def get_session(self) -> aiohttp.ClientSession:
|
||||
if self.session is None:
|
||||
headers = {
|
||||
|
|
|
|||
Loading…
Reference in New Issue