fix: nsfw_censor not working (#6)
* chore: import image from natali repo * fix: return censor png instead of original image * fix: basedir broken when calling outsidepull/15/head
parent
a68f2c208d
commit
b1127c09ed
Binary file not shown.
|
After Width: | Height: | Size: 38 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 29 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 45 KiB |
|
|
@ -0,0 +1 @@
|
||||||
|
**Images in this folder was copied from https://github.com/Sygil-Dev/nataili and licenced as AGPL-3.0 without any modification.**
|
||||||
|
|
@ -3,12 +3,14 @@ from typing import Optional
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules import script_callbacks, shared
|
from modules import scripts, script_callbacks, shared
|
||||||
|
|
||||||
from stable_horde import StableHorde, StableHordeConfig
|
from stable_horde import StableHorde, StableHordeConfig
|
||||||
|
|
||||||
|
basedir = scripts.basedir()
|
||||||
|
|
||||||
async def start_horde():
|
async def start_horde():
|
||||||
config = StableHordeConfig()
|
config = StableHordeConfig(basedir)
|
||||||
horde = StableHorde(config)
|
horde = StableHorde(config)
|
||||||
await horde.run()
|
await horde.run()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
|
from os import path
|
||||||
from random import randint
|
from random import randint
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
@ -10,7 +11,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
from modules import shared, call_queue, txt2img, img2img, processing, sd_samplers
|
from modules import shared, call_queue, txt2img, img2img, processing, sd_samplers, scripts
|
||||||
|
|
||||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
safety_feature_extractor = None
|
safety_feature_extractor = None
|
||||||
|
|
@ -18,8 +19,8 @@ safety_checker = None
|
||||||
|
|
||||||
|
|
||||||
class StableHordeConfig:
|
class StableHordeConfig:
|
||||||
def __init__(self):
|
def __init__(self, basedir: str):
|
||||||
pass
|
self.basedir = basedir
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def endpoint(self) -> str:
|
def endpoint(self) -> str:
|
||||||
|
|
@ -68,6 +69,8 @@ class StableHorde:
|
||||||
|
|
||||||
self.session = aiohttp.ClientSession(self.config.endpoint, headers=headers)
|
self.session = aiohttp.ClientSession(self.config.endpoint, headers=headers)
|
||||||
|
|
||||||
|
self.sfw_request_censor = Image.open(path.join(self.config.basedir, "assets", "nsfw_censor_sfw_request.png"))
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(shared.opts.stable_horde_interval)
|
await asyncio.sleep(shared.opts.stable_horde_interval)
|
||||||
|
|
@ -201,14 +204,16 @@ class StableHorde:
|
||||||
with call_queue.queue_lock:
|
with call_queue.queue_lock:
|
||||||
processed = processing.process_images(p)
|
processed = processing.process_images(p)
|
||||||
|
|
||||||
|
has_nsfw = False
|
||||||
|
|
||||||
if req["payload"].get("use_nsfw_censor"):
|
if req["payload"].get("use_nsfw_censor"):
|
||||||
x_image = np.array(processed.images[0])
|
x_image = np.array(processed.images[0])
|
||||||
x_checked_image, _ = self.check_safety(x_image)
|
image, has_nsfw = self.check_safety(x_image)
|
||||||
image = Image.fromarray(x_checked_image)
|
|
||||||
else:
|
else:
|
||||||
image = processed.images[0]
|
image = processed.images[0]
|
||||||
|
|
||||||
if "RealESRGAN_x4plus" in postprocessors:
|
if "RealESRGAN_x4plus" in postprocessors and not has_nsfw:
|
||||||
from modules.extras import run_extras
|
from modules.extras import run_extras
|
||||||
images, _info, _wtf = run_extras(
|
images, _info, _wtf = run_extras(
|
||||||
image=image, extras_mode=0, resize_mode=0,
|
image=image, extras_mode=0, resize_mode=0,
|
||||||
|
|
@ -273,9 +278,11 @@ class StableHorde:
|
||||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||||
|
|
||||||
safety_checker_input = safety_feature_extractor(x_image, return_tensors="pt")
|
safety_checker_input = safety_feature_extractor(x_image, return_tensors="pt")
|
||||||
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
||||||
|
|
||||||
return x_checked_image, has_nsfw_concept
|
if has_nsfw_concept:
|
||||||
|
return self.sfw_request_censor, has_nsfw_concept
|
||||||
|
return Image.fromarray(image), has_nsfw_concept
|
||||||
|
|
||||||
|
|
||||||
def handle_error(self, status: int, res: Dict[str, Any]):
|
def handle_error(self, status: int, res: Dict[str, Any]):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue