first!
commit c21cc99b73

@@ -0,0 +1 @@
An NSFW checker for [Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui). Replaces non-worksafe images with black squares. Install it from the UI.
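The black-square replacement described above is done by diffusers' `StableDiffusionSafetyChecker`, which zeroes out any image it flags. A minimal standalone sketch of the same check the extension performs, assuming the model weights can be downloaded and with `input.png` as a placeholder file name:

```python
import numpy as np
from PIL import Image
from transformers import AutoFeatureExtractor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

model_id = "CompVis/stable-diffusion-safety-checker"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
checker = StableDiffusionSafetyChecker.from_pretrained(model_id)

pil_image = Image.open("input.png").convert("RGB")               # placeholder input
np_images = np.array(pil_image, dtype=np.float32)[None] / 255.0  # NHWC floats in [0, 1]

clip_input = feature_extractor([pil_image], return_tensors="pt")
checked, has_nsfw = checker(images=np_images, clip_input=clip_input.pixel_values)
print(has_nsfw)  # [True] means the returned image was replaced with black
```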
@@ -0,0 +1,4 @@
```python
import launch

# Pull in the diffusers dependency at startup if it is not already installed.
if not launch.is_installed("diffusers"):
    launch.run_pip("install diffusers", "diffusers")
```
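The `launch` module here is Web UI's own bootstrap helper: `is_installed` checks whether a package is present in Web UI's Python environment, and `run_pip` invokes pip to install it, so an extension can pull in a missing dependency like `diffusers` the first time the UI starts.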
@@ -0,0 +1,56 @@
```python
import torch
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image

from modules import scripts, shared

safety_model_id = "CompVis/stable-diffusion-safety-checker"

# Loaded lazily on first use so the model is only downloaded when needed.
safety_feature_extractor = None
safety_checker = None


def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]

    return pil_images


# check and replace nsfw content
def check_safety(x_image):
    global safety_feature_extractor, safety_checker

    if safety_feature_extractor is None:
        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)

    return x_checked_image, has_nsfw_concept


def censor_batch(x):
    # NCHW tensor -> NHWC numpy for the safety checker, then back again.
    x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
    x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
    x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)

    return x


class NsfwCheckScript(scripts.Script):
    def title(self):
        return "NSFW check"

    def show(self, is_img2img):
        # Run on every generation; the script adds no UI of its own.
        return scripts.AlwaysVisible

    def postprocess_batch(self, p, *args, **kwargs):
        images = kwargs['images']
        # Copy in place so Web UI's reference to the batch sees the censored images.
        images[:] = censor_batch(images)[:]
```
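The slice assignment `images[:] = censor_batch(images)[:]` is the key design choice in `postprocess_batch`: it copies the censored batch into the tensor object Web UI already holds, rather than rebinding a local name, so the change is visible to the caller. A small self-contained illustration of that pattern:

```python
import torch

a = torch.zeros(2, 3)
b = a                        # second reference to the same tensor, like Web UI's
a[:] = torch.ones(2, 3)[:]   # in-place slice copy, as in postprocess_batch
print(b)                     # all ones: the other reference sees the update
```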