Fix cross-extension import, update reallysafe.

And some code cleanup.
pull/12/head
d8ahazard 2022-12-15 12:03:50 -06:00
parent d1fe158903
commit e4c638c7b8
2 changed files with 47 additions and 22 deletions

View File

@ -18,6 +18,16 @@ def encode(*args):
return out
extras_dict = {
"yolo": [],
'torch.nn.modules': ['Conv', 'Conv2d', 'BatchNorm2d', "SiLU", "MaxPool2d", "Upsample", "ModuleList"],
"models.common": ["C3", "Bottleneck", "SPPF", "Concat", "Conv"],
"numpy.core.multiarray": ["_reconstruct"],
'torch': ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage',
'BFloat16Storage']
}
class RestrictedUnpickler(pickle.Unpickler):
extra_handler = None
@ -52,24 +62,17 @@ class RestrictedUnpickler(pickle.Unpickler):
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
import pytorch_lightning.callbacks.model_checkpoint
return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
if "yolo" in module:
return super().find_class(module, name)
if module == "models.common" and name == "Conv":
return super().find_class(module, name)
if 'torch.nn.modules' in module and name in ['Conv', 'Conv2d', 'BatchNorm2d', "SiLU", "MaxPool2d", "Upsample",
"ModuleList"]:
return super().find_class(module, name)
if "models.common" in module and name in ["C3", "Bottleneck", "SPPF", "Concat"]:
return super().find_class(module, name)
if module == "__builtin__" and name == 'set':
return set
# Forbid everything else.
for key in extras_dict:
if module in key and name in extras_dict[module] or len(extras_dict[module]) == 0:
return super().find_class(name, module)
raise Exception(f"global '{module}/{name}' is forbidden")
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
extra_zip_names = []
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
@ -77,7 +80,8 @@ def check_zip_filenames(filename, names):
for name in names:
if allowed_zip_names_re.match(name):
continue
if name in extra_zip_names:
continue
raise Exception(f"bad file inside {filename}: {name}")
@ -87,7 +91,7 @@ def check_pt(filename, extra_handler):
# new pytorch format is a zip file
with zipfile.ZipFile(filename) as z:
check_zip_filenames(filename, z.namelist())
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
if len(data_pkl_filenames) == 0:
@ -145,7 +149,9 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
print(
f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
file=sys.stderr)
return None
except Exception:

View File

@ -1,26 +1,20 @@
import math
import os
import sys
import traceback
from io import StringIO
from pathlib import Path
import numpy as np
import tqdm
from PIL import Image, ImageOps
from functools import reduce
from PIL import Image, ImageOps, features
import modules.codeformer_model
import modules.gfpgan_model
import reallysafe
from clipcrop import CropClip
from extensions.sd_dreambooth_extension.dreambooth.utils import list_features, is_image
from extensions.sd_smartprocess.clipinterrogator import ClipInterrogator
from extensions.sd_smartprocess.interrogator import WaifuDiffusionInterrogator, BooruInterrogator
from modules import shared, images, safe
from modules.shared import cmd_opts
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
def printi(message):
@ -28,6 +22,31 @@ def printi(message):
print(message)
def list_features():
# Create buffer for pilinfo() to write into rather than stdout
buffer = StringIO()
features.pilinfo(out=buffer)
pil_features = []
# Parse and analyse lines
for line in buffer.getvalue().splitlines():
if "Extensions:" in line:
ext_list = line.split(": ")[1]
extensions = ext_list.split(", ")
for extension in extensions:
if extension not in pil_features:
pil_features.append(extension)
return pil_features
def is_image(path: Path, feats=None):
if feats is None:
feats = []
if not len(feats):
feats = list_features()
is_img = path.is_file() and path.suffix.lower() in feats
return is_img
def preprocess(rename,
src,
dst,