chore: lint existing code (#47)

* chore: add flake8 and black for linting code

* chore: lint all files

* chore: add lint workflow
pull/13/head
Maiko Sinkyaet Tan 2023-01-27 20:03:29 +08:00 committed by GitHub
parent e790ac2e3c
commit 72e28ca8aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 251 additions and 98 deletions

3
.flake8 Normal file
View File

@ -0,0 +1,3 @@
[flake8]
max-line-length = 88
exclude = .git,__pycache__,old,build,dist

30
.github/workflows/lint.yml vendored Normal file
View File

@ -0,0 +1,30 @@
name: Lint
on:
push:
branches:
- master
pull_request:
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v2
with:
python-version: 3.10
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics
- name: Lint with black
run: |
black --check .

View File

@ -1,5 +1,5 @@
import launch
if not launch.is_installed("diffusers"):
launch.run_pip(f"install diffusers", "diffusers") # NSFW filter
launch.run_pip(f"install aiohttp", "aiohttp") # asynchroneous HTTP requests
launch.run_pip("install diffusers", "diffusers") # NSFW filter
launch.run_pip("install aiohttp", "aiohttp") # asynchroneous HTTP requests

3
requirements.txt Normal file
View File

@ -0,0 +1,3 @@
# Formatters and linters
black ~= 22.0
flake8

View File

@ -15,22 +15,52 @@ def on_app_started(demo: Optional[gr.Blocks], app: FastAPI):
horde = StableHorde(config)
import gradio.utils
gradio.utils.synchronize_async(horde.run)
def on_ui_settings():
section = ('stable-horde', 'Stable Horde')
shared.opts.add_option('stable_horde_enable', shared.OptionInfo(False, 'Enable', section=section))
shared.opts.add_option('stable_horde_endpoint', shared.OptionInfo('https://stablehorde.net/', 'Endpoint', section=section))
shared.opts.add_option('stable_horde_apikey', shared.OptionInfo('', 'API Key', section=section))
shared.opts.add_option('stable_horde_name', shared.OptionInfo('Stable Horde', 'Worker Name', section=section))
shared.opts.add_option('stable_horde_model', shared.OptionInfo('Anything Diffusion', 'Model', section=section))
shared.opts.add_option('stable_horde_nsfw', shared.OptionInfo(False, 'NSFW', section=section))
shared.opts.add_option('stable_horde_interval', shared.OptionInfo(10, 'Interval', section=section))
shared.opts.add_option('stable_horde_max_pixels', shared.OptionInfo(1024 * 1024, 'Max Pixels', section=section))
shared.opts.add_option('stable_horde_allow_img2img', shared.OptionInfo(True, 'Allow img2img', section=section))
shared.opts.add_option('stable_horde_allow_painting', shared.OptionInfo(True, 'Allow Painting', section=section))
shared.opts.add_option('stable_horde_allow_unsafe_ipaddr', shared.OptionInfo(True, 'Allow Unsafe IP Address', section=section))
section = ("stable-horde", "Stable Horde")
shared.opts.add_option(
"stable_horde_enable", shared.OptionInfo(False, "Enable", section=section)
)
shared.opts.add_option(
"stable_horde_endpoint",
shared.OptionInfo("https://stablehorde.net/", "Endpoint", section=section),
)
shared.opts.add_option(
"stable_horde_apikey", shared.OptionInfo("", "API Key", section=section)
)
shared.opts.add_option(
"stable_horde_name",
shared.OptionInfo("Stable Horde", "Worker Name", section=section),
)
shared.opts.add_option(
"stable_horde_model",
shared.OptionInfo("Anything Diffusion", "Model", section=section),
)
shared.opts.add_option(
"stable_horde_nsfw", shared.OptionInfo(False, "NSFW", section=section)
)
shared.opts.add_option(
"stable_horde_interval", shared.OptionInfo(10, "Interval", section=section)
)
shared.opts.add_option(
"stable_horde_max_pixels",
shared.OptionInfo(1024 * 1024, "Max Pixels", section=section),
)
shared.opts.add_option(
"stable_horde_allow_img2img",
shared.OptionInfo(True, "Allow img2img", section=section),
)
shared.opts.add_option(
"stable_horde_allow_painting",
shared.OptionInfo(True, "Allow Painting", section=section),
)
shared.opts.add_option(
"stable_horde_allow_unsafe_ipaddr",
shared.OptionInfo(True, "Allow Unsafe IP Address", section=section),
)
script_callbacks.on_app_started(on_app_started)

View File

@ -8,13 +8,25 @@ from typing import Any, Dict, List, Optional
import aiohttp
import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from PIL import Image
from transformers.models.auto.feature_extraction_auto import AutoFeatureExtractor
from modules import shared, call_queue, txt2img, img2img, processing, sd_models, sd_samplers, scripts
from modules import (
shared,
call_queue,
txt2img,
img2img,
processing,
sd_models,
sd_samplers,
)
stable_horde_supported_models_url = "https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db.json"
stable_horde_supported_models_url = (
"https://raw.githubusercontent.com/Sygil-Dev/nataili-model-reference/main/db.json"
)
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None
@ -29,7 +41,7 @@ class StableHordeConfig:
@property
def endpoint(self) -> str:
return shared.opts.stable_horde_endpoint
@property
def apikey(self) -> str:
return shared.opts.stable_horde_apikey
@ -49,11 +61,11 @@ class StableHordeConfig:
@property
def allow_img2img(self) -> bool:
return shared.opts.stable_horde_allow_img2img
@property
def allow_painting(self) -> bool:
return shared.opts.stable_horde_allow_painting
@property
def allow_unsafe_ipaddr(self) -> bool:
return shared.opts.stable_horde_allow_unsafe_ipaddr
@ -62,7 +74,31 @@ class StableHordeConfig:
class HordeJob:
retry_interval: int = 1
def __init__(self, session: aiohttp.ClientSession, id: str, model: str, prompt: str, negative_prompt: str, sampler: str, cfg_scale: float, seed: int, denoising_strength: float, n_iter: int, height: int, width: int, subseed: int, steps: int, karras: bool, tiling: bool, postprocessors: List[str], nsfw_censor: bool = False, source_image: Optional[Image.Image] = None, source_processing: Optional[str] = "img2img", source_mask: Optional[Image.Image] = None, r2_upload: Optional[str] = None):
def __init__(
self,
session: aiohttp.ClientSession,
id: str,
model: str,
prompt: str,
negative_prompt: str,
sampler: str,
cfg_scale: float,
seed: int,
denoising_strength: float,
n_iter: int,
height: int,
width: int,
subseed: int,
steps: int,
karras: bool,
tiling: bool,
postprocessors: List[str],
nsfw_censor: bool = False,
source_image: Optional[Image.Image] = None,
source_processing: Optional[str] = "img2img",
source_mask: Optional[Image.Image] = None,
r2_upload: Optional[str] = None,
):
self.id = id
self.model = model
self.prompt = prompt
@ -82,7 +118,9 @@ class HordeJob:
self.postprocessors = postprocessors
self.nsfw_censor = nsfw_censor
self.source_image = source_image
self.source_processing = source_processing # "img2img", "inpainting", "outpainting"
self.source_processing = (
source_processing # "img2img", "inpainting", "outpainting"
)
self.source_mask = source_mask
self.r2_upload = r2_upload
@ -115,7 +153,7 @@ class HordeJob:
attempts = 10
while attempts > 0:
try:
r = await session.post('/api/v2/generate/submit', json=post_data)
r = await session.post("/api/v2/generate/submit", json=post_data)
try:
res = await r.json()
@ -127,20 +165,30 @@ class HordeJob:
if r.ok:
return res.get("reward", None)
else:
print(f"Failed to submit job with status code {r.status}: {res.get('message')}")
print(
"Failed to submit job with status code"
+ f"{r.status}: {res.get('message')}"
)
return None
except Exception:
print(f"Error when decoding response, the server might be down.")
print("Error when decoding response, the server might be down.")
return None
except aiohttp.ClientConnectorError:
attempts -= 1
await asyncio.sleep(self.retry_interval)
continue
@classmethod
async def get(cls, session: aiohttp.ClientSession, config: StableHordeConfig, models: List[str]):
async def get(
cls,
session: aiohttp.ClientSession,
config: StableHordeConfig,
models: List[str],
):
name = "Stable Horde Worker Bridge for Stable Diffusion WebUI"
version = 10
repo = "https://github.com/sdwebui-w-horde/sd-webui-stable-horde-worker"
# https://stablehorde.net/api/
post_data = {
"name": config.name,
@ -150,7 +198,7 @@ class HordeJob:
"models": models,
# TODO: add support for bridge version 11 "tiling"
"bridge_version": 9,
"bridge_agent": "Stable Horde Worker Bridge for Stable Diffusion WebUI:10:https://github.com/sdwebui-w-horde/sd-webui-stable-horde-worker",
"bridge_agent": f"{name}:{version}:{repo}",
"threads": 1,
"max_pixels": config.max_pixels,
"allow_img2img": config.allow_img2img,
@ -158,77 +206,62 @@ class HordeJob:
"allow_unsafe_ipaddr": config.allow_unsafe_ipaddr,
}
r = await session.post('/api/v2/generate/pop', json=post_data)
r = await session.post("/api/v2/generate/pop", json=post_data)
req = await r.json()
if r.status != 200:
raise Exception(f"Failed to get job: {req.get('message')}")
if not req.get('id'):
if not req.get("id"):
return
payload = req.get('payload')
prompt = payload.get('prompt')
payload = req.get("payload")
prompt = payload.get("prompt")
if "###" in prompt:
prompt, negative = map(lambda x: x.strip(), prompt.split("###"))
else:
negative = ""
def to_image(base64str: Optional[str]) -> Optional[Image.Image]:
if not base64str:
return None
return Image.open(io.BytesIO(base64.b64decode(base64str)))
return cls(
session=session,
id=req['id'],
id=req["id"],
prompt=prompt,
negative_prompt=negative,
sampler=payload.get('sampler_name'),
cfg_scale=payload.get('cfg_scale', 5),
seed=int(payload.get('seed', randint(0, 2**32))),
denoising_strength=payload.get('denoising_strength', 0.75),
n_iter=payload.get('n_iter', 1),
height=payload['height'],
width=payload['width'],
subseed=payload.get('seed_variation', 1),
steps=payload.get('ddim_steps', 30),
karras=payload.get('karras', False),
tiling=payload.get('tiling', False),
postprocessors=payload.get('post_processing', []),
nsfw_censor=payload.get('use_nsfw_censor', False),
model=req['model'],
source_image=to_image(payload.get('source_image')),
source_processing=payload.get('source_processing'),
source_mask=to_image(payload.get('source_mask')),
r2_upload=payload.get('r2_upload'),
sampler=payload.get("sampler_name"),
cfg_scale=payload.get("cfg_scale", 5),
seed=int(payload.get("seed", randint(0, 2**32))),
denoising_strength=payload.get("denoising_strength", 0.75),
n_iter=payload.get("n_iter", 1),
height=payload["height"],
width=payload["width"],
subseed=payload.get("seed_variation", 1),
steps=payload.get("ddim_steps", 30),
karras=payload.get("karras", False),
tiling=payload.get("tiling", False),
postprocessors=payload.get("post_processing", []),
nsfw_censor=payload.get("use_nsfw_censor", False),
model=req["model"],
source_image=to_image(payload.get("source_image")),
source_processing=payload.get("source_processing"),
source_mask=to_image(payload.get("source_mask")),
r2_upload=payload.get("r2_upload"),
)
@property
def allow_painting(self) -> bool:
return shared.opts.stable_horde_allow_painting
@property
def allow_unsafe_ipaddr(self) -> bool:
return shared.opts.stable_horde_allow_unsafe_ipaddr
class StableHorde:
def __init__(self, config: StableHordeConfig):
self.config = config
headers = {
"apikey": self.config.apikey,
"Content-Type": "application/json",
}
self.session: Optional[aiohttp.ClientSession] = None
self.sfw_request_censor = Image.open(path.join(self.config.basedir, "assets", "nsfw_censor_sfw_request.png"))
self.sfw_request_censor = Image.open(
path.join(self.config.basedir, "assets", "nsfw_censor_sfw_request.png")
)
self.supported_models = []
@ -237,9 +270,9 @@ class StableHorde:
if not path.exists(filepath):
async with aiohttp.ClientSession() as session:
async with session.get(stable_horde_supported_models_url) as resp:
with open(filepath, 'wb') as f:
with open(filepath, "wb") as f:
f.write(await resp.read())
with open(filepath, 'r') as f:
with open(filepath, "r") as f:
supported_models: Dict[str, Any] = json.load(f)
self.supported_models = list(supported_models.values())
@ -247,6 +280,7 @@ class StableHorde:
def detect_current_model(self):
def get_md5sum(filepath):
import hashlib
with open(filepath, "rb") as f:
return hashlib.md5(f.read()).hexdigest()
@ -268,7 +302,6 @@ class StableHorde:
if len(self.config.models) == 0:
return f"Current model {model_checkpoint} not found on StableHorde"
async def run(self):
await self.get_supported_models()
@ -276,8 +309,8 @@ class StableHorde:
result = self.detect_current_model()
if result is not None:
# Wait 10 seconds before retrying to detect the current model
# if the current model is not listed in the Stable Horde supported models,
# we don't want to spam the server with requests
# if the current model is not listed in the Stable Horde supported
# models, we don't want to spam the server with requests
await asyncio.sleep(10)
continue
@ -285,13 +318,16 @@ class StableHorde:
if shared.opts.stable_horde_enable:
try:
req = await HordeJob.get(await self.get_session(), self.config, self.config.models)
req = await HordeJob.get(
await self.get_session(), self.config, self.config.models
)
if req is None:
continue
await self.handle_request(req)
except Exception as e:
except Exception:
import traceback
traceback.print_exc()
def patch_sampler_names(self):
@ -299,17 +335,53 @@ class StableHorde:
but are not included in the default sd_samplers module.
"""
from modules import sd_samplers
from modules.sd_samplers import KDiffusionSampler, SamplerData
if sd_samplers.samplers_map.get('euler a karras'):
if sd_samplers.samplers_map.get("euler a karras"):
# already patched
return
samplers = [
sd_samplers.SamplerData("Euler a Karras", lambda model, funcname="sample_euler_ancestral": sd_samplers.KDiffusionSampler(funcname, model), ['k_euler_a_ka'], {'scheduler': 'karras'}),
sd_samplers.SamplerData("Euler Karras", lambda model, funcname="sample_euler": sd_samplers.KDiffusionSampler(funcname, model), ['k_euler_ka'], {'scheduler': 'karras'}),
sd_samplers.SamplerData("Heun Karras", lambda model, funcname="sample_heun": sd_samplers.KDiffusionSampler(funcname, model), ['k_heun_ka'], {'scheduler': 'karras'}),
sd_samplers.SamplerData('DPM adaptive Karras', lambda model, funcname='sample_dpm_adaptive': sd_samplers.KDiffusionSampler(funcname, model), ['k_dpm_ad_ka'], {'scheduler': 'karras'}),
sd_samplers.SamplerData('DPM fast Karras', lambda model, funcname='sample_dpm_fast': sd_samplers.KDiffusionSampler(funcname, model), ['k_dpm_fast_ka'], {'scheduler': 'karras'}),
SamplerData(
"Euler a Karras",
lambda model, funcname="sample_euler_ancestral": KDiffusionSampler(
funcname, model
),
["k_euler_a_ka"],
{"scheduler": "karras"},
),
SamplerData(
"Euler Karras",
lambda model, funcname="sample_euler": KDiffusionSampler(
funcname, model
),
["k_euler_ka"],
{"scheduler": "karras"},
),
SamplerData(
"Heun Karras",
lambda model, funcname="sample_heun": KDiffusionSampler(
funcname, model
),
["k_heun_ka"],
{"scheduler": "karras"},
),
SamplerData(
"DPM adaptive Karras",
lambda model, funcname="sample_dpm_adaptive": KDiffusionSampler(
funcname, model
),
["k_dpm_ad_ka"],
{"scheduler": "karras"},
),
SamplerData(
"DPM fast Karras",
lambda model, funcname="sample_dpm_fast": KDiffusionSampler(
funcname, model
),
["k_dpm_fast_ka"],
{"scheduler": "karras"},
),
]
sd_samplers.samplers.extend(samplers)
sd_samplers.samplers_for_img2img.extend(samplers)
@ -324,13 +396,13 @@ class StableHorde:
print(f"Get popped generation request {job.id}")
sampler_name = job.sampler
if sampler_name == 'k_dpm_adaptive':
sampler_name = 'k_dpm_ad'
if sampler_name == "k_dpm_adaptive":
sampler_name = "k_dpm_ad"
if sampler_name not in sd_samplers.samplers_map:
print(f"ERROR: Unknown sampler {sampler_name}")
return
if job.karras:
sampler_name += '_ka'
sampler_name += "_ka"
sampler = sd_samplers.samplers_map.get(sampler_name, None)
if sampler is None:
@ -392,17 +464,27 @@ class StableHorde:
if "RealESRGAN_x4plus" in postprocessors and not has_nsfw:
from modules.postprocessing import run_extras
with call_queue.queue_lock:
images, _info, _wtf = run_extras(
image=image, extras_mode=0, resize_mode=0,
show_extras_results=True, upscaling_resize=2,
upscaling_resize_h=None, upscaling_resize_w=None,
upscaling_crop=False, upscale_first=False,
extras_upscaler_1="R-ESRGAN 4x+", # 8 - RealESRGAN_x4plus
image=image,
extras_mode=0,
resize_mode=0,
show_extras_results=True,
upscaling_resize=2,
upscaling_resize_h=None,
upscaling_resize_w=None,
upscaling_crop=False,
upscale_first=False,
extras_upscaler_1="R-ESRGAN 4x+", # 8 - RealESRGAN_x4plus
extras_upscaler_2=None,
extras_upscaler_2_visibility=0.0,
gfpgan_visibility=0.0, codeformer_visibility=0.0, codeformer_weight=0.0,
image_folder="", input_dir="", output_dir="",
gfpgan_visibility=0.0,
codeformer_visibility=0.0,
codeformer_weight=0.0,
image_folder="",
input_dir="",
output_dir="",
save_output=False,
)
@ -417,17 +499,22 @@ class StableHorde:
global safety_feature_extractor, safety_checker
if safety_feature_extractor is None:
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
safety_model_id
)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
safety_model_id
)
safety_checker_input = safety_feature_extractor(x_image, return_tensors="pt")
image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
image, has_nsfw_concept = safety_checker(
images=x_image, clip_input=safety_checker_input.pixel_values
)
if has_nsfw_concept:
return self.sfw_request_censor, has_nsfw_concept
return Image.fromarray(image), has_nsfw_concept
async def get_session(self) -> aiohttp.ClientSession:
if self.session is None:
headers = {