fix: nsfw_censor not working (#6)
* chore: import image from natali repo * fix: return censor png instead of original image * fix: basedir broken when calling outsidepull/15/head
parent
a68f2c208d
commit
b1127c09ed
Binary file not shown.
|
After Width: | Height: | Size: 38 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 29 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 45 KiB |
|
|
@ -0,0 +1 @@
|
|||
**Images in this folder was copied from https://github.com/Sygil-Dev/nataili and licenced as AGPL-3.0 without any modification.**
|
||||
|
|
@ -3,12 +3,14 @@ from typing import Optional
|
|||
from fastapi import FastAPI
|
||||
import gradio as gr
|
||||
|
||||
from modules import script_callbacks, shared
|
||||
from modules import scripts, script_callbacks, shared
|
||||
|
||||
from stable_horde import StableHorde, StableHordeConfig
|
||||
|
||||
basedir = scripts.basedir()
|
||||
|
||||
async def start_horde():
|
||||
config = StableHordeConfig()
|
||||
config = StableHordeConfig(basedir)
|
||||
horde = StableHorde(config)
|
||||
await horde.run()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
from os import path
|
||||
from random import randint
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
|
@ -10,7 +11,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
|
|||
from PIL import Image
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
from modules import shared, call_queue, txt2img, img2img, processing, sd_samplers
|
||||
from modules import shared, call_queue, txt2img, img2img, processing, sd_samplers, scripts
|
||||
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_feature_extractor = None
|
||||
|
|
@ -18,8 +19,8 @@ safety_checker = None
|
|||
|
||||
|
||||
class StableHordeConfig:
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self, basedir: str):
|
||||
self.basedir = basedir
|
||||
|
||||
@property
|
||||
def endpoint(self) -> str:
|
||||
|
|
@ -68,6 +69,8 @@ class StableHorde:
|
|||
|
||||
self.session = aiohttp.ClientSession(self.config.endpoint, headers=headers)
|
||||
|
||||
self.sfw_request_censor = Image.open(path.join(self.config.basedir, "assets", "nsfw_censor_sfw_request.png"))
|
||||
|
||||
async def run(self):
|
||||
while True:
|
||||
await asyncio.sleep(shared.opts.stable_horde_interval)
|
||||
|
|
@ -201,14 +204,16 @@ class StableHorde:
|
|||
with call_queue.queue_lock:
|
||||
processed = processing.process_images(p)
|
||||
|
||||
has_nsfw = False
|
||||
|
||||
if req["payload"].get("use_nsfw_censor"):
|
||||
x_image = np.array(processed.images[0])
|
||||
x_checked_image, _ = self.check_safety(x_image)
|
||||
image = Image.fromarray(x_checked_image)
|
||||
image, has_nsfw = self.check_safety(x_image)
|
||||
|
||||
else:
|
||||
image = processed.images[0]
|
||||
|
||||
if "RealESRGAN_x4plus" in postprocessors:
|
||||
if "RealESRGAN_x4plus" in postprocessors and not has_nsfw:
|
||||
from modules.extras import run_extras
|
||||
images, _info, _wtf = run_extras(
|
||||
image=image, extras_mode=0, resize_mode=0,
|
||||
|
|
@ -273,9 +278,11 @@ class StableHorde:
|
|||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
|
||||
safety_checker_input = safety_feature_extractor(x_image, return_tensors="pt")
|
||||
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||
image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||
|
||||
return x_checked_image, has_nsfw_concept
|
||||
if has_nsfw_concept:
|
||||
return self.sfw_request_censor, has_nsfw_concept
|
||||
return Image.fromarray(image), has_nsfw_concept
|
||||
|
||||
|
||||
def handle_error(self, status: int, res: Dict[str, Any]):
|
||||
|
|
|
|||
Loading…
Reference in New Issue