fix: nsfw_censor not working (#6)

* chore: import image from natali repo

* fix: return censor png instead of original image

* fix: basedir broken when calling outside
pull/15/head
Maiko Sinkyaet Tan 2022-12-19 21:40:30 +08:00 committed by Maiko Tan
parent a68f2c208d
commit b1127c09ed
6 changed files with 20 additions and 10 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

1
assets/readme.md Normal file
View File

@ -0,0 +1 @@
**Images in this folder was copied from https://github.com/Sygil-Dev/nataili and licenced as AGPL-3.0 without any modification.**

View File

@ -3,12 +3,14 @@ from typing import Optional
from fastapi import FastAPI
import gradio as gr
from modules import script_callbacks, shared
from modules import scripts, script_callbacks, shared
from stable_horde import StableHorde, StableHordeConfig
basedir = scripts.basedir()
async def start_horde():
config = StableHordeConfig()
config = StableHordeConfig(basedir)
horde = StableHorde(config)
await horde.run()

View File

@ -1,6 +1,7 @@
import asyncio
import base64
import io
from os import path
from random import randint
from typing import Any, Dict, List, Optional
@ -10,7 +11,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
from PIL import Image
from transformers import AutoFeatureExtractor
from modules import shared, call_queue, txt2img, img2img, processing, sd_samplers
from modules import shared, call_queue, txt2img, img2img, processing, sd_samplers, scripts
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None
@ -18,8 +19,8 @@ safety_checker = None
class StableHordeConfig:
def __init__(self):
pass
def __init__(self, basedir: str):
self.basedir = basedir
@property
def endpoint(self) -> str:
@ -68,6 +69,8 @@ class StableHorde:
self.session = aiohttp.ClientSession(self.config.endpoint, headers=headers)
self.sfw_request_censor = Image.open(path.join(self.config.basedir, "assets", "nsfw_censor_sfw_request.png"))
async def run(self):
while True:
await asyncio.sleep(shared.opts.stable_horde_interval)
@ -201,14 +204,16 @@ class StableHorde:
with call_queue.queue_lock:
processed = processing.process_images(p)
has_nsfw = False
if req["payload"].get("use_nsfw_censor"):
x_image = np.array(processed.images[0])
x_checked_image, _ = self.check_safety(x_image)
image = Image.fromarray(x_checked_image)
image, has_nsfw = self.check_safety(x_image)
else:
image = processed.images[0]
if "RealESRGAN_x4plus" in postprocessors:
if "RealESRGAN_x4plus" in postprocessors and not has_nsfw:
from modules.extras import run_extras
images, _info, _wtf = run_extras(
image=image, extras_mode=0, resize_mode=0,
@ -273,9 +278,11 @@ class StableHorde:
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
safety_checker_input = safety_feature_extractor(x_image, return_tensors="pt")
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
return x_checked_image, has_nsfw_concept
if has_nsfw_concept:
return self.sfw_request_censor, has_nsfw_concept
return Image.fromarray(image), has_nsfw_concept
def handle_error(self, status: int, res: Dict[str, Any]):