fix: nsfw_censor not working (#6)

* chore: import image from natali repo

* fix: return censor png instead of original image

* fix: basedir broken when calling outside
pull/15/head
Maiko Sinkyaet Tan 2022-12-19 21:40:30 +08:00 committed by Maiko Tan
parent a68f2c208d
commit b1127c09ed
6 changed files with 20 additions and 10 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

1
assets/readme.md Normal file
View File

@ -0,0 +1 @@
**Images in this folder was copied from https://github.com/Sygil-Dev/nataili and licenced as AGPL-3.0 without any modification.**

View File

@ -3,12 +3,14 @@ from typing import Optional
from fastapi import FastAPI from fastapi import FastAPI
import gradio as gr import gradio as gr
from modules import script_callbacks, shared from modules import scripts, script_callbacks, shared
from stable_horde import StableHorde, StableHordeConfig from stable_horde import StableHorde, StableHordeConfig
basedir = scripts.basedir()
async def start_horde(): async def start_horde():
config = StableHordeConfig() config = StableHordeConfig(basedir)
horde = StableHorde(config) horde = StableHorde(config)
await horde.run() await horde.run()

View File

@ -1,6 +1,7 @@
import asyncio import asyncio
import base64 import base64
import io import io
from os import path
from random import randint from random import randint
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -10,7 +11,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
from PIL import Image from PIL import Image
from transformers import AutoFeatureExtractor from transformers import AutoFeatureExtractor
from modules import shared, call_queue, txt2img, img2img, processing, sd_samplers from modules import shared, call_queue, txt2img, img2img, processing, sd_samplers, scripts
safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None safety_feature_extractor = None
@ -18,8 +19,8 @@ safety_checker = None
class StableHordeConfig: class StableHordeConfig:
def __init__(self): def __init__(self, basedir: str):
pass self.basedir = basedir
@property @property
def endpoint(self) -> str: def endpoint(self) -> str:
@ -68,6 +69,8 @@ class StableHorde:
self.session = aiohttp.ClientSession(self.config.endpoint, headers=headers) self.session = aiohttp.ClientSession(self.config.endpoint, headers=headers)
self.sfw_request_censor = Image.open(path.join(self.config.basedir, "assets", "nsfw_censor_sfw_request.png"))
async def run(self): async def run(self):
while True: while True:
await asyncio.sleep(shared.opts.stable_horde_interval) await asyncio.sleep(shared.opts.stable_horde_interval)
@ -201,14 +204,16 @@ class StableHorde:
with call_queue.queue_lock: with call_queue.queue_lock:
processed = processing.process_images(p) processed = processing.process_images(p)
has_nsfw = False
if req["payload"].get("use_nsfw_censor"): if req["payload"].get("use_nsfw_censor"):
x_image = np.array(processed.images[0]) x_image = np.array(processed.images[0])
x_checked_image, _ = self.check_safety(x_image) image, has_nsfw = self.check_safety(x_image)
image = Image.fromarray(x_checked_image)
else: else:
image = processed.images[0] image = processed.images[0]
if "RealESRGAN_x4plus" in postprocessors: if "RealESRGAN_x4plus" in postprocessors and not has_nsfw:
from modules.extras import run_extras from modules.extras import run_extras
images, _info, _wtf = run_extras( images, _info, _wtf = run_extras(
image=image, extras_mode=0, resize_mode=0, image=image, extras_mode=0, resize_mode=0,
@ -273,9 +278,11 @@ class StableHorde:
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
safety_checker_input = safety_feature_extractor(x_image, return_tensors="pt") safety_checker_input = safety_feature_extractor(x_image, return_tensors="pt")
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values) image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
return x_checked_image, has_nsfw_concept if has_nsfw_concept:
return self.sfw_request_censor, has_nsfw_concept
return Image.fromarray(image), has_nsfw_concept
def handle_error(self, status: int, res: Dict[str, Any]): def handle_error(self, status: int, res: Dict[str, Any]):