parent
d1fe158903
commit
e4c638c7b8
|
|
@ -18,6 +18,16 @@ def encode(*args):
|
|||
return out
|
||||
|
||||
|
||||
extras_dict = {
|
||||
"yolo": [],
|
||||
'torch.nn.modules': ['Conv', 'Conv2d', 'BatchNorm2d', "SiLU", "MaxPool2d", "Upsample", "ModuleList"],
|
||||
"models.common": ["C3", "Bottleneck", "SPPF", "Concat", "Conv"],
|
||||
"numpy.core.multiarray": ["_reconstruct"],
|
||||
'torch': ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage',
|
||||
'BFloat16Storage']
|
||||
}
|
||||
|
||||
|
||||
class RestrictedUnpickler(pickle.Unpickler):
|
||||
extra_handler = None
|
||||
|
||||
|
|
@ -52,24 +62,17 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|||
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
|
||||
import pytorch_lightning.callbacks.model_checkpoint
|
||||
return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
|
||||
if "yolo" in module:
|
||||
return super().find_class(module, name)
|
||||
if module == "models.common" and name == "Conv":
|
||||
return super().find_class(module, name)
|
||||
if 'torch.nn.modules' in module and name in ['Conv', 'Conv2d', 'BatchNorm2d', "SiLU", "MaxPool2d", "Upsample",
|
||||
"ModuleList"]:
|
||||
return super().find_class(module, name)
|
||||
if "models.common" in module and name in ["C3", "Bottleneck", "SPPF", "Concat"]:
|
||||
return super().find_class(module, name)
|
||||
if module == "__builtin__" and name == 'set':
|
||||
return set
|
||||
|
||||
# Forbid everything else.
|
||||
for key in extras_dict:
|
||||
if module in key and name in extras_dict[module] or len(extras_dict[module]) == 0:
|
||||
return super().find_class(name, module)
|
||||
raise Exception(f"global '{module}/{name}' is forbidden")
|
||||
|
||||
|
||||
# Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'
|
||||
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$")
|
||||
extra_zip_names = []
|
||||
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$")
|
||||
|
||||
|
||||
|
|
@ -77,7 +80,8 @@ def check_zip_filenames(filename, names):
|
|||
for name in names:
|
||||
if allowed_zip_names_re.match(name):
|
||||
continue
|
||||
|
||||
if name in extra_zip_names:
|
||||
continue
|
||||
raise Exception(f"bad file inside {filename}: {name}")
|
||||
|
||||
|
||||
|
|
@ -87,7 +91,7 @@ def check_pt(filename, extra_handler):
|
|||
# new pytorch format is a zip file
|
||||
with zipfile.ZipFile(filename) as z:
|
||||
check_zip_filenames(filename, z.namelist())
|
||||
|
||||
|
||||
# find filename of data.pkl in zip file: '<directory name>/data.pkl'
|
||||
data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]
|
||||
if len(data_pkl_filenames) == 0:
|
||||
|
|
@ -145,7 +149,9 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
|||
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
|
||||
print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
|
||||
print(
|
||||
f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n",
|
||||
file=sys.stderr)
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
|
|
|
|||
|
|
@ -1,26 +1,20 @@
|
|||
import math
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from PIL import Image, ImageOps
|
||||
from functools import reduce
|
||||
from PIL import Image, ImageOps, features
|
||||
|
||||
import modules.codeformer_model
|
||||
import modules.gfpgan_model
|
||||
import reallysafe
|
||||
from clipcrop import CropClip
|
||||
from extensions.sd_dreambooth_extension.dreambooth.utils import list_features, is_image
|
||||
from extensions.sd_smartprocess.clipinterrogator import ClipInterrogator
|
||||
from extensions.sd_smartprocess.interrogator import WaifuDiffusionInterrogator, BooruInterrogator
|
||||
from modules import shared, images, safe
|
||||
from modules.shared import cmd_opts
|
||||
|
||||
if cmd_opts.deepdanbooru:
|
||||
import modules.deepbooru as deepbooru
|
||||
|
||||
|
||||
def printi(message):
|
||||
|
|
@ -28,6 +22,31 @@ def printi(message):
|
|||
print(message)
|
||||
|
||||
|
||||
def list_features():
|
||||
# Create buffer for pilinfo() to write into rather than stdout
|
||||
buffer = StringIO()
|
||||
features.pilinfo(out=buffer)
|
||||
pil_features = []
|
||||
# Parse and analyse lines
|
||||
for line in buffer.getvalue().splitlines():
|
||||
if "Extensions:" in line:
|
||||
ext_list = line.split(": ")[1]
|
||||
extensions = ext_list.split(", ")
|
||||
for extension in extensions:
|
||||
if extension not in pil_features:
|
||||
pil_features.append(extension)
|
||||
return pil_features
|
||||
|
||||
|
||||
def is_image(path: Path, feats=None):
|
||||
if feats is None:
|
||||
feats = []
|
||||
if not len(feats):
|
||||
feats = list_features()
|
||||
is_img = path.is_file() and path.suffix.lower() in feats
|
||||
return is_img
|
||||
|
||||
|
||||
def preprocess(rename,
|
||||
src,
|
||||
dst,
|
||||
|
|
|
|||
Loading…
Reference in New Issue