diff --git a/reallysafe.py b/reallysafe.py index 95c94cc..57628dd 100644 --- a/reallysafe.py +++ b/reallysafe.py @@ -18,6 +18,16 @@ def encode(*args): return out +extras_dict = { + "yolo": [], + 'torch.nn.modules': ['Conv', 'Conv2d', 'BatchNorm2d', "SiLU", "MaxPool2d", "Upsample", "ModuleList"], + "models.common": ["C3", "Bottleneck", "SPPF", "Concat", "Conv"], + "numpy.core.multiarray": ["_reconstruct"], + 'torch': ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', + 'BFloat16Storage'] +} + + class RestrictedUnpickler(pickle.Unpickler): extra_handler = None @@ -52,24 +62,17 @@ class RestrictedUnpickler(pickle.Unpickler): if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': import pytorch_lightning.callbacks.model_checkpoint return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint - if "yolo" in module: - return super().find_class(module, name) - if module == "models.common" and name == "Conv": - return super().find_class(module, name) - if 'torch.nn.modules' in module and name in ['Conv', 'Conv2d', 'BatchNorm2d', "SiLU", "MaxPool2d", "Upsample", - "ModuleList"]: - return super().find_class(module, name) - if "models.common" in module and name in ["C3", "Bottleneck", "SPPF", "Concat"]: - return super().find_class(module, name) if module == "__builtin__" and name == 'set': return set - - # Forbid everything else. + for key in extras_dict: + if module in key and name in extras_dict[module] or len(extras_dict[module]) == 0: + return super().find_class(name, module) raise Exception(f"global '{module}/{name}' is forbidden") # Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/' allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$") +extra_zip_names = [] data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$") @@ -77,7 +80,8 @@ def check_zip_filenames(filename, names): for name in names: if allowed_zip_names_re.match(name): continue - + if name in extra_zip_names: + continue raise Exception(f"bad file inside {filename}: {name}") @@ -87,7 +91,7 @@ def check_pt(filename, extra_handler): # new pytorch format is a zip file with zipfile.ZipFile(filename) as z: check_zip_filenames(filename, z.namelist()) - + # find filename of data.pkl in zip file: '/data.pkl' data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)] if len(data_pkl_filenames) == 0: @@ -145,7 +149,9 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): print(f"Error verifying pickled file from {filename}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) - print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) + print( + f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", + file=sys.stderr) return None except Exception: diff --git a/smartprocess.py b/smartprocess.py index 96aa607..b11eb9c 100644 --- a/smartprocess.py +++ b/smartprocess.py @@ -1,26 +1,20 @@ -import math import os import sys import traceback +from io import StringIO from pathlib import Path import numpy as np import tqdm -from PIL import Image, ImageOps -from functools import reduce +from PIL import Image, ImageOps, features import modules.codeformer_model import modules.gfpgan_model import reallysafe from clipcrop import CropClip -from extensions.sd_dreambooth_extension.dreambooth.utils import list_features, is_image from extensions.sd_smartprocess.clipinterrogator import ClipInterrogator from extensions.sd_smartprocess.interrogator import WaifuDiffusionInterrogator, BooruInterrogator from modules import shared, images, safe -from modules.shared import cmd_opts - -if cmd_opts.deepdanbooru: - import modules.deepbooru as deepbooru def printi(message): @@ -28,6 +22,31 @@ def printi(message): print(message) +def list_features(): + # Create buffer for pilinfo() to write into rather than stdout + buffer = StringIO() + features.pilinfo(out=buffer) + pil_features = [] + # Parse and analyse lines + for line in buffer.getvalue().splitlines(): + if "Extensions:" in line: + ext_list = line.split(": ")[1] + extensions = ext_list.split(", ") + for extension in extensions: + if extension not in pil_features: + pil_features.append(extension) + return pil_features + + +def is_image(path: Path, feats=None): + if feats is None: + feats = [] + if not len(feats): + feats = list_features() + is_img = path.is_file() and path.suffix.lower() in feats + return is_img + + def preprocess(rename, src, dst,