diff --git a/assets/nsfw_censor_censorlist.png b/assets/nsfw_censor_censorlist.png new file mode 100644 index 0000000..c4f2b1e Binary files /dev/null and b/assets/nsfw_censor_censorlist.png differ diff --git a/assets/nsfw_censor_sfw_request.png b/assets/nsfw_censor_sfw_request.png new file mode 100644 index 0000000..444e0f2 Binary files /dev/null and b/assets/nsfw_censor_sfw_request.png differ diff --git a/assets/nsfw_censor_sfw_worker.png b/assets/nsfw_censor_sfw_worker.png new file mode 100644 index 0000000..d36858e Binary files /dev/null and b/assets/nsfw_censor_sfw_worker.png differ diff --git a/assets/readme.md b/assets/readme.md new file mode 100644 index 0000000..04f0fb6 --- /dev/null +++ b/assets/readme.md @@ -0,0 +1 @@ +**Images in this folder was copied from https://github.com/Sygil-Dev/nataili and licenced as AGPL-3.0 without any modification.** diff --git a/scripts/script.py b/scripts/script.py index 66add32..7ab64b1 100644 --- a/scripts/script.py +++ b/scripts/script.py @@ -3,12 +3,14 @@ from typing import Optional from fastapi import FastAPI import gradio as gr -from modules import script_callbacks, shared +from modules import scripts, script_callbacks, shared from stable_horde import StableHorde, StableHordeConfig +basedir = scripts.basedir() + async def start_horde(): - config = StableHordeConfig() + config = StableHordeConfig(basedir) horde = StableHorde(config) await horde.run() diff --git a/stable_horde.py b/stable_horde.py index e62ea85..d161179 100644 --- a/stable_horde.py +++ b/stable_horde.py @@ -1,6 +1,7 @@ import asyncio import base64 import io +from os import path from random import randint from typing import Any, Dict, List, Optional @@ -10,7 +11,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS from PIL import Image from transformers import AutoFeatureExtractor -from modules import shared, call_queue, txt2img, img2img, processing, sd_samplers +from modules import shared, call_queue, txt2img, img2img, processing, sd_samplers, scripts safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_feature_extractor = None @@ -18,8 +19,8 @@ safety_checker = None class StableHordeConfig: - def __init__(self): - pass + def __init__(self, basedir: str): + self.basedir = basedir @property def endpoint(self) -> str: @@ -68,6 +69,8 @@ class StableHorde: self.session = aiohttp.ClientSession(self.config.endpoint, headers=headers) + self.sfw_request_censor = Image.open(path.join(self.config.basedir, "assets", "nsfw_censor_sfw_request.png")) + async def run(self): while True: await asyncio.sleep(shared.opts.stable_horde_interval) @@ -201,14 +204,16 @@ class StableHorde: with call_queue.queue_lock: processed = processing.process_images(p) + has_nsfw = False + if req["payload"].get("use_nsfw_censor"): x_image = np.array(processed.images[0]) - x_checked_image, _ = self.check_safety(x_image) - image = Image.fromarray(x_checked_image) + image, has_nsfw = self.check_safety(x_image) + else: image = processed.images[0] - if "RealESRGAN_x4plus" in postprocessors: + if "RealESRGAN_x4plus" in postprocessors and not has_nsfw: from modules.extras import run_extras images, _info, _wtf = run_extras( image=image, extras_mode=0, resize_mode=0, @@ -273,9 +278,11 @@ class StableHorde: safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) safety_checker_input = safety_feature_extractor(x_image, return_tensors="pt") - x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) + image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) - return x_checked_image, has_nsfw_concept + if has_nsfw_concept: + return self.sfw_request_censor, has_nsfw_concept + return Image.fromarray(image), has_nsfw_concept def handle_error(self, status: int, res: Dict[str, Any]):