From c21cc99b731ff295039b0531e73968d1bf384570 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 10 Dec 2022 14:33:46 +0300
Subject: [PATCH] first!

---
 README.md         |  1 +
 install.py        |  4 ++++
 scripts/censor.py | 60 ++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 65 insertions(+)
 create mode 100644 README.md
 create mode 100644 install.py
 create mode 100644 scripts/censor.py

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..edb991c
--- /dev/null
+++ b/README.md
@@ -0,0 +1 @@
+A NSFW checker for [Web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui). Replaces non-worksafe images with black squares. Install it from UI.
diff --git a/install.py b/install.py
new file mode 100644
index 0000000..d2271bb
--- /dev/null
+++ b/install.py
@@ -0,0 +1,4 @@
+import launch
+
+if not launch.is_installed("diffusers"):
+    launch.run_pip("install diffusers", "diffusers")
diff --git a/scripts/censor.py b/scripts/censor.py
new file mode 100644
index 0000000..e27dea3
--- /dev/null
+++ b/scripts/censor.py
@@ -0,0 +1,60 @@
+import torch
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from transformers import AutoFeatureExtractor
+from PIL import Image
+
+from modules import scripts, shared
+
+safety_model_id = "CompVis/stable-diffusion-safety-checker"
+# Lazily-initialized singletons; loaded on the first call to check_safety().
+safety_feature_extractor = None
+safety_checker = None
+
+
+def numpy_to_pil(images):
+    """
+    Convert a numpy image or a batch of images to a PIL image.
+    """
+    if images.ndim == 3:
+        images = images[None, ...]
+    images = (images * 255).round().astype("uint8")
+    pil_images = [Image.fromarray(image) for image in images]
+
+    return pil_images
+
+
+# check and replace nsfw content
+def check_safety(x_image):
+    """Run the CompVis safety checker on a NHWC float batch; returns (checked images, nsfw flags)."""
+    global safety_feature_extractor, safety_checker
+
+    if safety_feature_extractor is None:
+        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+
+    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
+    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+
+    return x_checked_image, has_nsfw_concept
+
+
+def censor_batch(x):
+    """Censor a NCHW image tensor batch, replacing non-worksafe images with black squares."""
+    x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
+    x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
+    x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+
+    return x
+
+
+class NsfwCheckScript(scripts.Script):
+    def title(self):
+        return "NSFW check"
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    def postprocess_batch(self, p, *args, **kwargs):
+        """Censor the generated batch in place before images are saved or shown."""
+        images = kwargs['images']
+        images[:] = censor_batch(images)[:]