Compare commits
8 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
d6c7dcef02 | |
|
|
b69c821b97 | |
|
|
620ebe1333 | |
|
|
74de5b7a07 | |
|
|
652286e46d | |
|
|
da5053bc2e | |
|
|
c716c83db0 | |
|
|
94cc5a2a6b |
|
|
@ -3,6 +3,8 @@ import re, sys
|
|||
from typing import List, Set, Optional
|
||||
from enum import Enum
|
||||
|
||||
from PIL import ImageFile
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
|
@ -11,7 +13,7 @@ from modules.textual_inversion.dataset import re_numbers_at_start
|
|||
|
||||
from scripts.singleton import Singleton
|
||||
|
||||
from scripts import logger
|
||||
from scripts import logger, utilities
|
||||
from scripts.paths import paths
|
||||
|
||||
from . import (
|
||||
|
|
@ -21,14 +23,17 @@ from . import (
|
|||
taggers_builtin
|
||||
)
|
||||
from .custom_scripts import CustomScripts
|
||||
from .interrogator_names import BLIP2_CAPTIONING_NAMES, WD_TAGGERS, WD_TAGGERS_TIMM
|
||||
from scripts.tokenizer import clip_tokenizer
|
||||
from scripts.tagger import Tagger
|
||||
|
||||
re_tags = re.compile(r"^([\s\S]+?)( \[\d+\])?$")
|
||||
re_newlines = re.compile(r"[\r\n]+")
|
||||
|
||||
def convert_rgb(data:Image.Image):
|
||||
return data.convert("RGB")
|
||||
def get_square_rgb(data:Image.Image):
|
||||
data_rgb = utilities.get_rgb_image(data)
|
||||
size = max(data.size)
|
||||
return utilities.resize_and_fill(data_rgb, (size, size))
|
||||
|
||||
class DatasetTagEditor(Singleton):
|
||||
class SortBy(Enum):
|
||||
|
|
@ -68,39 +73,26 @@ class DatasetTagEditor(Singleton):
|
|||
custom_taggers:list[Tagger] = custom_tagger_scripts.load_derived_classes(Tagger)
|
||||
logger.write(f"Custom taggers loaded: {[tagger().name() for tagger in custom_taggers]}")
|
||||
|
||||
self.BLIP2_CAPTIONING_NAMES = [
|
||||
"blip2-opt-2.7b",
|
||||
"blip2-opt-2.7b-coco",
|
||||
"blip2-opt-6.7b",
|
||||
"blip2-opt-6.7b-coco",
|
||||
"blip2-flan-t5-xl",
|
||||
"blip2-flan-t5-xl-coco",
|
||||
"blip2-flan-t5-xxl",
|
||||
]
|
||||
|
||||
self.WD_TAGGERS = {
|
||||
"wd-v1-4-vit-tagger" : 0.35,
|
||||
"wd-v1-4-convnext-tagger" : 0.35,
|
||||
"wd-v1-4-vit-tagger-v2" : 0.3537,
|
||||
"wd-v1-4-convnext-tagger-v2" : 0.3685,
|
||||
"wd-v1-4-convnextv2-tagger-v2" : 0.371,
|
||||
"wd-v1-4-swinv2-tagger-v2" : 0.3771,
|
||||
"wd-v1-4-moat-tagger-v2" : 0.3771,
|
||||
"wd-vit-tagger-v3" : 0.2614,
|
||||
"wd-convnext-tagger-v3" : 0.2682,
|
||||
"wd-swinv2-tagger-v3" : 0.2653,
|
||||
}
|
||||
# {tagger name : default tagger threshold}
|
||||
# v1: idk if it's okay v2: P=R thresholds on each repo https://huggingface.co/SmilingWolf
|
||||
|
||||
def read_wd_batchsize(name:str):
|
||||
if "vit" in name:
|
||||
return shared.opts.dataset_editor_batch_size_vit
|
||||
elif "convnext" in name:
|
||||
return shared.opts.dataset_editor_batch_size_convnext
|
||||
elif "swinv2" in name:
|
||||
return shared.opts.dataset_editor_batch_size_swinv2
|
||||
|
||||
self.INTERROGATORS = (
|
||||
[taggers_builtin.BLIP()]
|
||||
+ [taggers_builtin.BLIP2(name) for name in self.BLIP2_CAPTIONING_NAMES]
|
||||
+ [taggers_builtin.BLIP2(name) for name in BLIP2_CAPTIONING_NAMES]
|
||||
+ [taggers_builtin.GITLarge()]
|
||||
+ [taggers_builtin.DeepDanbooru()]
|
||||
+ [
|
||||
taggers_builtin.WaifuDiffusion(name, threshold)
|
||||
for name, threshold in self.WD_TAGGERS.items()
|
||||
for name, threshold in WD_TAGGERS.items()
|
||||
]
|
||||
+ [
|
||||
taggers_builtin.WaifuDiffusionTimm(name, threshold, int(read_wd_batchsize(name)))
|
||||
for name, threshold in WD_TAGGERS_TIMM.items()
|
||||
]
|
||||
+ [taggers_builtin.Z3D_E621()]
|
||||
+ [cls_tagger() for cls_tagger in custom_taggers]
|
||||
|
|
@ -109,7 +101,7 @@ class DatasetTagEditor(Singleton):
|
|||
|
||||
def interrogate_image(self, path: str, interrogator_name: str, threshold_booru, threshold_wd, threshold_z3d):
|
||||
try:
|
||||
img = Image.open(path).convert("RGB")
|
||||
img = get_square_rgb(Image.open(path))
|
||||
except:
|
||||
return ""
|
||||
else:
|
||||
|
|
@ -118,7 +110,7 @@ class DatasetTagEditor(Singleton):
|
|||
if isinstance(it, taggers_builtin.DeepDanbooru):
|
||||
with it as tg:
|
||||
res = tg.predict(img, threshold_booru)
|
||||
elif isinstance(it, taggers_builtin.WaifuDiffusion):
|
||||
elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm):
|
||||
with it as tg:
|
||||
res = tg.predict(img, threshold_wd)
|
||||
elif isinstance(it, taggers_builtin.Z3D_E621):
|
||||
|
|
@ -704,19 +696,27 @@ class DatasetTagEditor(Singleton):
|
|||
continue
|
||||
try:
|
||||
img = Image.open(img_path)
|
||||
if (max_res > 0):
|
||||
img_res = int(max_res), int(max_res)
|
||||
img.thumbnail(img_res)
|
||||
except:
|
||||
continue
|
||||
else:
|
||||
abs_path = str(img_path.absolute())
|
||||
if not use_temp_dir and max_res <= 0:
|
||||
img.already_saved_as = abs_path
|
||||
images[abs_path] = img
|
||||
|
||||
imgpaths.append(abs_path)
|
||||
imgpaths.append(abs_path)
|
||||
return imgpaths, images
|
||||
|
||||
def load_thumbnails(images_raw: dict[str, Image.Image]):
|
||||
images = {}
|
||||
if max_res > 0:
|
||||
for img_path, img in images_raw.items():
|
||||
img_res = int(max_res), int(max_res)
|
||||
images[img_path] = img.copy()
|
||||
images[img_path].thumbnail(img_res)
|
||||
else:
|
||||
for img_path, img in images_raw.items():
|
||||
if not use_temp_dir:
|
||||
img.already_saved_as = img_path
|
||||
images[img_path] = img
|
||||
return images
|
||||
|
||||
def load_captions(imgpaths: list[str]):
|
||||
taglists = []
|
||||
|
|
@ -752,7 +752,7 @@ class DatasetTagEditor(Singleton):
|
|||
if it.name() in interrogator_names:
|
||||
if isinstance(it, taggers_builtin.DeepDanbooru):
|
||||
tagger_thresholds.append((it, threshold_booru))
|
||||
elif isinstance(it, taggers_builtin.WaifuDiffusion):
|
||||
elif isinstance(it, taggers_builtin.WaifuDiffusion) or isinstance(it, taggers_builtin.WaifuDiffusionTimm):
|
||||
tagger_thresholds.append((it, threshold_waifu))
|
||||
elif isinstance(it, taggers_builtin.Z3D_E621):
|
||||
tagger_thresholds.append((it, threshold_z3d))
|
||||
|
|
@ -760,13 +760,15 @@ class DatasetTagEditor(Singleton):
|
|||
tagger_thresholds.append((it, None))
|
||||
|
||||
if kohya_json_path:
|
||||
imgpaths, self.images, taglists = kohya_metadata.read(
|
||||
imgpaths, images_raw, taglists = kohya_metadata.read(
|
||||
img_dir, kohya_json_path, use_temp_dir
|
||||
)
|
||||
else:
|
||||
imgpaths, self.images = load_images(filepaths)
|
||||
imgpaths, images_raw = load_images(filepaths)
|
||||
taglists = load_captions(imgpaths)
|
||||
|
||||
self.images = load_thumbnails(images_raw)
|
||||
|
||||
interrogate_tags = {img_path : [] for img_path in imgpaths}
|
||||
|
||||
img_to_interrogate = [
|
||||
|
|
@ -783,11 +785,11 @@ class DatasetTagEditor(Singleton):
|
|||
|
||||
def gen_data(paths:list[str], images:dict[str, Image.Image]):
|
||||
for img_path in paths:
|
||||
yield images.get(img_path)
|
||||
yield images[img_path]
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
result = list(executor.map(convert_rgb, gen_data(img_to_interrogate, self.images)))
|
||||
result = list(executor.map(get_square_rgb, gen_data(img_to_interrogate, images_raw)))
|
||||
logger.write("Preprocess completed")
|
||||
|
||||
for tg, th in tqdm(tagger_thresholds):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
|
||||
BLIP2_CAPTIONING_NAMES = [
|
||||
"blip2-opt-2.7b",
|
||||
"blip2-opt-2.7b-coco",
|
||||
"blip2-opt-6.7b",
|
||||
"blip2-opt-6.7b-coco",
|
||||
"blip2-flan-t5-xl",
|
||||
"blip2-flan-t5-xl-coco",
|
||||
"blip2-flan-t5-xxl",
|
||||
]
|
||||
|
||||
|
||||
# {tagger name : default tagger threshold}
|
||||
# v1: idk if it's okay v2, v3: P=R thresholds on each repo https://huggingface.co/SmilingWolf
|
||||
WD_TAGGERS = {
|
||||
"wd-v1-4-vit-tagger" : 0.35,
|
||||
"wd-v1-4-convnext-tagger" : 0.35,
|
||||
"wd-v1-4-vit-tagger-v2" : 0.3537,
|
||||
"wd-v1-4-convnext-tagger-v2" : 0.3685,
|
||||
"wd-v1-4-convnextv2-tagger-v2" : 0.371,
|
||||
"wd-v1-4-moat-tagger-v2" : 0.3771
|
||||
}
|
||||
WD_TAGGERS_TIMM = {
|
||||
"wd-v1-4-swinv2-tagger-v2" : 0.3771,
|
||||
"wd-vit-tagger-v3" : 0.2614,
|
||||
"wd-convnext-tagger-v3" : 0.2682,
|
||||
"wd-swinv2-tagger-v3" : 0.2653,
|
||||
}
|
||||
|
|
@ -1,7 +1,8 @@
|
|||
from .blip2_captioning import BLIP2Captioning
|
||||
from .git_large_captioning import GITLargeCaptioning
|
||||
from .waifu_diffusion_tagger import WaifuDiffusionTagger
|
||||
from .waifu_diffusion_tagger_timm import WaifuDiffusionTaggerTimm
|
||||
|
||||
__all__ = [
|
||||
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger'
|
||||
"BLIP2Captioning", 'GITLargeCaptioning', 'WaifuDiffusionTagger', 'WaifuDiffusionTaggerTimm'
|
||||
]
|
||||
|
|
@ -0,0 +1,104 @@
|
|||
from PIL import Image
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
import torchvision.transforms as tf
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from modules import shared, devices
|
||||
import launch
|
||||
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
def __init__(self, images:list[Image.Image], transforms:tf.Compose=None):
|
||||
self.images = images
|
||||
self.transforms = transforms
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, i):
|
||||
img = self.images[i]
|
||||
if self.transforms is not None:
|
||||
img = self.transforms(img)
|
||||
return img
|
||||
|
||||
|
||||
|
||||
class WaifuDiffusionTaggerTimm:
|
||||
# some codes are brought from https://github.com/neggles/wdv3-timm and modified
|
||||
|
||||
def __init__(self, model_repo, label_filename="selected_tags.csv"):
|
||||
self.LABEL_FILENAME = label_filename
|
||||
self.MODEL_REPO = model_repo
|
||||
self.model = None
|
||||
self.transform = None
|
||||
self.labels = []
|
||||
|
||||
def load(self):
|
||||
import huggingface_hub
|
||||
|
||||
if not launch.is_installed("timm"):
|
||||
launch.run_pip(
|
||||
"install -U timm",
|
||||
"requirements for dataset-tag-editor [timm]",
|
||||
)
|
||||
import timm
|
||||
from timm.data import create_transform, resolve_data_config
|
||||
|
||||
if not self.model:
|
||||
self.model: torch.nn.Module = timm.create_model(
|
||||
"hf-hub:" + self.MODEL_REPO
|
||||
).eval()
|
||||
state_dict = timm.models.load_state_dict_from_hf(self.MODEL_REPO)
|
||||
self.model.load_state_dict(state_dict)
|
||||
self.model.to(devices.device)
|
||||
self.transform = create_transform(
|
||||
**resolve_data_config(self.model.pretrained_cfg, model=self.model)
|
||||
)
|
||||
|
||||
path_label = huggingface_hub.hf_hub_download(
|
||||
self.MODEL_REPO, self.LABEL_FILENAME
|
||||
)
|
||||
import pandas as pd
|
||||
|
||||
self.labels = pd.read_csv(path_label)["name"].tolist()
|
||||
|
||||
def unload(self):
|
||||
if not shared.opts.interrogate_keep_models_in_memory:
|
||||
self.model = None
|
||||
devices.torch_gc()
|
||||
|
||||
def apply(self, image: Image.Image):
|
||||
if not self.model:
|
||||
return []
|
||||
|
||||
image_t: torch.Tensor = self.transform(image).unsqueeze(0)
|
||||
image_t = image_t[:, [2, 1, 0]]
|
||||
image_t = image_t.to(devices.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
features = self.model.forward(image_t)
|
||||
probs = F.sigmoid(features).detach().cpu().numpy()
|
||||
|
||||
labels: list[Tuple[str, float]] = list(zip(self.labels, probs[0].astype(float)))
|
||||
|
||||
return labels
|
||||
|
||||
|
||||
def apply_multi(self, images: list[Image.Image], batch_size: int):
|
||||
if not self.model:
|
||||
return []
|
||||
|
||||
dataset = ImageDataset(images, self.transform)
|
||||
dataloader = DataLoader(dataset, batch_size=batch_size)
|
||||
|
||||
with torch.inference_mode():
|
||||
for batch in tqdm(dataloader):
|
||||
batch = batch[:, [2, 1, 0]].to(devices.device)
|
||||
features = self.model.forward(batch)
|
||||
probs = F.sigmoid(features).detach().cpu().numpy()
|
||||
labels: list[Tuple[str, float]] = [list(zip(self.labels, probs[i].astype(float))) for i in range(probs.shape[0])]
|
||||
yield labels
|
||||
|
|
@ -8,7 +8,7 @@ from modules import devices, shared
|
|||
from modules import deepbooru as db
|
||||
|
||||
from scripts.tagger import Tagger, get_replaced_tag
|
||||
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger
|
||||
from .interrogators import BLIP2Captioning, GITLargeCaptioning, WaifuDiffusionTagger, WaifuDiffusionTaggerTimm
|
||||
|
||||
|
||||
class BLIP(Tagger):
|
||||
|
|
@ -142,6 +142,28 @@ class WaifuDiffusion(Tagger):
|
|||
return self.repo_name
|
||||
|
||||
|
||||
class WaifuDiffusionTimm(WaifuDiffusion):
|
||||
def __init__(self, repo_name, threshold, batch_size=4):
|
||||
super().__init__(repo_name, threshold)
|
||||
self.tagger_inst = WaifuDiffusionTaggerTimm("SmilingWolf/" + repo_name)
|
||||
self.batch_size = batch_size
|
||||
|
||||
def predict_pipe(self, data: list[Image.Image], threshold: Optional[float] = None):
|
||||
for labels_list in self.tagger_inst.apply_multi(data, batch_size=self.batch_size):
|
||||
for labels in labels_list:
|
||||
if not shared.opts.dataset_editor_use_rating:
|
||||
labels = labels[4:]
|
||||
|
||||
if threshold is not None:
|
||||
if threshold < 0:
|
||||
threshold = self.threshold
|
||||
tags = [get_replaced_tag(tag) for tag, value in labels if value > threshold]
|
||||
else:
|
||||
tags = [get_replaced_tag(tag) for tag, _ in labels]
|
||||
|
||||
yield tags
|
||||
|
||||
|
||||
class Z3D_E621(Tagger):
|
||||
def __init__(self):
|
||||
self.tagger_inst = WaifuDiffusionTagger("toynya/Z3D-E621-Convnext", label_filename="tags-selected.csv")
|
||||
|
|
|
|||
|
|
@ -413,10 +413,6 @@ def on_ui_tabs():
|
|||
ui.move_or_delete_files.set_callbacks(
|
||||
o_update_filter_and_gallery,
|
||||
ui.dataset_gallery,
|
||||
ui.filter_by_tags,
|
||||
ui.batch_edit_captions,
|
||||
ui.filter_by_selection,
|
||||
ui.edit_caption_of_selected_image,
|
||||
get_filters,
|
||||
update_filter_and_gallery,
|
||||
)
|
||||
|
|
@ -472,6 +468,33 @@ def on_ui_settings():
|
|||
),
|
||||
)
|
||||
|
||||
shared.opts.add_option(
|
||||
"dataset_editor_batch_size_vit",
|
||||
shared.OptionInfo(
|
||||
4,
|
||||
"Inference batch size for ViT taggers",
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
|
||||
shared.opts.add_option(
|
||||
"dataset_editor_batch_size_convnext",
|
||||
shared.OptionInfo(
|
||||
4,
|
||||
"Inference batch size for ConvNeXt taggers",
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
|
||||
shared.opts.add_option(
|
||||
"dataset_editor_batch_size_swinv2",
|
||||
shared.OptionInfo(
|
||||
4,
|
||||
"Inference batch size for SwinTransformerV2 taggers",
|
||||
section=section,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
script_callbacks.on_ui_tabs(on_ui_tabs)
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class FilterByTagsUI(UIBase):
|
|||
common_callback = lambda : \
|
||||
update_gallery() + \
|
||||
batch_edit_captions.get_common_tags(get_filters, self) + \
|
||||
[move_or_delete_files.get_current_move_or_delete_target_num()] + \
|
||||
[move_or_delete_files.update_current_move_or_delete_target_num()] + \
|
||||
[batch_edit_captions.tag_select_ui_remove.cbg_tags_update()]
|
||||
|
||||
common_callback_output = \
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ class MoveOrDeleteFilesUI(UIBase):
|
|||
def __init__(self):
|
||||
self.target_data = 'Selected One'
|
||||
self.current_target_txt = ''
|
||||
self.update_func = None
|
||||
|
||||
def create_ui(self, cfg_file_move_delete):
|
||||
gr.HTML(value='<b>Note: </b>Moved or deleted images will be unloaded.')
|
||||
|
|
@ -27,11 +28,15 @@ class MoveOrDeleteFilesUI(UIBase):
|
|||
gr.HTML(value='<b>Note: </b>DELETE cannot be undone. The files will be deleted completely.')
|
||||
self.btn_move_or_delete_delete_files = gr.Button(value='DELETE File(s)', variant='primary')
|
||||
|
||||
def get_current_move_or_delete_target_num(self):
|
||||
return self.current_target_txt
|
||||
def update_current_move_or_delete_target_num(self):
|
||||
if self.update_func:
|
||||
return self.update_func(self.target_data)
|
||||
else:
|
||||
return self.current_target_txt
|
||||
|
||||
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, filter_by_tags:FilterByTagsUI, batch_edit_captions:BatchEditCaptionsUI, filter_by_selection:FilterBySelectionUI, edit_caption_of_selected_image:EditCaptionOfSelectedImageUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
|
||||
def _get_current_move_or_delete_target_num():
|
||||
def set_callbacks(self, o_update_filter_and_gallery:List[gr.components.Component], dataset_gallery:DatasetGalleryUI, get_filters:Callable[[], List[dte_module.filters.Filter]], update_filter_and_gallery:Callable[[], List]):
|
||||
def _get_current_move_or_delete_target_num(text: str):
|
||||
self.target_data = text
|
||||
if self.target_data == 'Selected One':
|
||||
self.current_target_txt = f'Target dataset num: {1 if dataset_gallery.selected_index != -1 else 0}'
|
||||
elif self.target_data == 'All Displayed Ones':
|
||||
|
|
@ -41,25 +46,17 @@ class MoveOrDeleteFilesUI(UIBase):
|
|||
self.current_target_txt = f'Target dataset num: 0'
|
||||
return self.current_target_txt
|
||||
|
||||
self.update_func = _get_current_move_or_delete_target_num
|
||||
|
||||
update_args = {
|
||||
'fn' : _get_current_move_or_delete_target_num,
|
||||
'inputs' : None,
|
||||
'fn': self.update_func,
|
||||
'inputs': [self.rb_move_or_delete_target_data],
|
||||
'outputs' : [self.ta_move_or_delete_target_dataset_num]
|
||||
}
|
||||
|
||||
batch_edit_captions.btn_apply_edit_tags.click(lambda:None).then(**update_args)
|
||||
|
||||
batch_edit_captions.btn_apply_sr_tags.click(lambda:None).then(**update_args)
|
||||
|
||||
filter_by_selection.btn_apply_image_selection_filter.click(lambda:None).then(**update_args)
|
||||
|
||||
filter_by_tags.btn_clear_tag_filters.click(lambda:None).then(**update_args)
|
||||
|
||||
filter_by_tags.btn_clear_all_filters.click(lambda:None).then(**update_args)
|
||||
|
||||
edit_caption_of_selected_image.btn_apply_changes_selected_image.click(lambda:None).then(**update_args)
|
||||
|
||||
self.rb_move_or_delete_target_data.change(**update_args)
|
||||
dataset_gallery.cbg_hidden_dataset_filter.change(lambda:None).then(**update_args)
|
||||
dataset_gallery.nb_hidden_image_index.change(lambda:None).then(**update_args)
|
||||
|
||||
def move_files(
|
||||
target_data: str,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
from typing import Tuple
|
||||
import math
|
||||
|
||||
from PIL import Image
|
||||
|
||||
if not hasattr(Image, 'Resampling'): # Pillow<9.0
|
||||
Image.Resampling = Image
|
||||
|
||||
|
||||
def resize(image: Image.Image, size: Tuple[int, int]):
|
||||
return image.resize(size, resample=Image.Resampling.LANCZOS)
|
||||
|
||||
|
||||
def get_rgb_image(image:Image.Image):
|
||||
if image.mode not in ["RGB", "RGBA"]:
|
||||
image = image.convert("RGBA") if "transparency" in image.info else image.convert("RGB")
|
||||
if image.mode == "RGBA":
|
||||
white = Image.new("RGBA", image.size, (255, 255, 255, 255))
|
||||
white.alpha_composite(image)
|
||||
image = white.convert("RGB")
|
||||
return image
|
||||
|
||||
|
||||
def resize_and_fill(image: Image.Image, size: Tuple[int, int], repeat_edge = True, fill_rgb:tuple[int,int,int] = (255, 255, 255)):
|
||||
width, height = size
|
||||
scale_w, scale_h = width / image.width, height / image.height
|
||||
resized_w, resized_h = width, height
|
||||
if scale_w < scale_h:
|
||||
resized_h = image.height * resized_w // image.width
|
||||
elif scale_h < scale_w:
|
||||
resized_w = image.width * resized_h // image.height
|
||||
|
||||
resized = resize(image, (resized_w, resized_h))
|
||||
if resized_w == width and resized_h == height:
|
||||
return resized
|
||||
|
||||
if repeat_edge:
|
||||
fill_l = math.floor((width - resized_w) / 2)
|
||||
fill_r = width - resized_w - fill_l
|
||||
fill_t = math.floor((height - resized_h) / 2)
|
||||
fill_b = height - resized_h - fill_t
|
||||
result = Image.new("RGB", (width, height))
|
||||
result.paste(resized, (fill_l, fill_t))
|
||||
if fill_t > 0:
|
||||
result.paste(resized.resize((width, fill_t), box=(0, 0, width, 0)), (0, 0))
|
||||
if fill_b > 0:
|
||||
result.paste(
|
||||
resized.resize(
|
||||
(width, fill_b), box=(0, resized.height, width, resized.height)
|
||||
),
|
||||
(0, resized.height + fill_t),
|
||||
)
|
||||
if fill_l > 0:
|
||||
result.paste(resized.resize((fill_l, height), box=(0, 0, 0, height)), (0, 0))
|
||||
if fill_r > 0:
|
||||
result.paste(
|
||||
resized.resize(
|
||||
(fill_r, height), box=(resized.width, 0, resized.width, height)
|
||||
),
|
||||
(resized.width + fill_l, 0),
|
||||
)
|
||||
return result
|
||||
else:
|
||||
result = Image.new("RGB", size, fill_rgb)
|
||||
result.paste(resized, box=((width - resized_w) // 2, (height - resized_h) // 2))
|
||||
return result.convert("RGB")
|
||||
|
||||
|
||||
|
|
@ -1 +1 @@
|
|||
0.3.3
|
||||
0.3.4
|
||||
|
|
|
|||
Loading…
Reference in New Issue