v8.3.0

@@ -3,6 +3,37 @@ All notable changes to this project will be documented in this file.

For more details on new features, please check the [Manual](./MANUAL.md).
<details open><summary>8.3.0 - 21 April 2023</summary>

### Added

- New shortcode `[color_correct]`: provides the same automatic color grading features as Bodysnatcher, but in the form of a standalone block
- `[color_correct]`: supports the `source` argument, a string that processes the initial image with `[txt2mask]` and uses the resulting masked image as a source for color correction, as opposed to the entire image
- `[txt2mask]`: implemented [CLIP Surgery](https://github.com/xmed-lab/CLIP_Surgery) as a new method type ("clip_surgery") which optionally supports Segment Anything (dev comment: this is better than `clipseg` at certain tasks but worse at others - `clipseg` is still default for the time being)
- `[txt2mask]`: new argument `stamp` that pastes a temporary PNG onto the init image before running mask processing, useful for redacting a portion of the image, for example
- `[txt2mask]`: supports `stamp_method` to choose sizing and positioning logic
- `[txt2mask]`: supports `stamp_x` and `stamp_y` for precise positioning of the stamp
- `[txt2mask]`: supports `stamp_blur` radius to engage an optional gaussian filter
- `[txt2mask]`: 10 basic stamps are included by default
- `[zoom_enhance]`: now supports `mask_method`
- `[template]`: any kwargs in the Wizard template block will be passed to the constructed `[file]` block
- `[file]`: experimental new argument `_bypass_if` that skips file processing if the value returns true (intended to be used with Wizard templates)
- `[get sd_model]` should now work as expected
- Bodysnatcher: new option `background_mode` that inverts the mask and disables the zoom_enhance step
- Bodysnatcher: new setting `stamp`
### Changed

- `[zoom_enhance]`: the `color_correct_method` default value is now `none`
- `[zoom_enhance]`: fix for adaptive CFG scaling
- `[zoom_enhance]`: minor tweaks to the adaptive scaling algorithm
- `[zoom_enhance]`: speculative fix for an issue with batch processing, which may also resolve an infinite loop that could occur with Bodysnatcher
- `[txt2mask]`: the "sam" `method` has been renamed to "grounded_sam"
- `[txt2mask]`: fixed a crash related to switching back and forth between `method` types
- Moved legacy shortcodes into their own `legacy` folder
- Fixed a crash related to empty shortcode arguments
- Updated the manual

</details>
<details><summary>8.2.0 - 18 April 2023</summary>

### Added

docs/MANUAL.md (126 changed lines)
@@ -289,7 +289,7 @@ Pressing **"Generate Shortcode"** will assemble a ready-to-use block of code tha

Alternatively, you can enable `Auto-include this in prompt` which will add the shortcode to your prompts behind the scenes. This essentially lets you use Unprompted shortcodes as if they were standalone scripts. You can enable/disable this setting on a per-shortcode basis.

The Wizard includes two distinct modes: Shortcodes and Templates.

</details>
@@ -323,17 +323,17 @@ There are a few reserved argument names that will modify the Wizard's behavior:

</details>

<details><summary>Templates Mode</summary>

This mode presents you with a list of txt files inside your `Unprompted/templates` directory that begin with a `[template]` block.

By including this block in your file, Unprompted will parse the file for its `[set x _new]` statements and adapt those into a custom Wizard UI.

The `_new` argument means *"only set this variable if it doesn't already exist,"* which are generally the variables we want to show in a UI.

The `[template]` block supports the optional `name` argument, which is a friendly name for your template shown in the Templates dropdown menu.

The content of `[template]` is a description of your script to be rendered with [Markdown](https://www.markdownguide.org/basic-syntax/), which means you can include rich content like pictures or links. It will show up at the top of your UI.

The `[set]` block supports `_ui` which determines the type of UI element to render your variable as. Defaults to `textbox`. Here are the possible types:
@@ -343,6 +343,10 @@ The `[set]` block supports `_ui` which determines the type of UI element to rend

- `dropdown`: A dropdown menu that is populated by the `_choices` argument, constructed as a delimited list.
- `slider`: Limits selection to a range of numbers. You must also specify `_minimum`, `_maximum` and `_step` (step size, normally 1) for this element to work properly.

The `[set]` block supports `_info` which is descriptive text that will appear near the UI element.

Supports the `[wizard_ui_accordion]` shortcode which will group the inner `[set]` blocks into a collapsible UI element.

</details>

## ⚙️ Shortcodes
@@ -639,7 +643,7 @@ RESULT: She said

<details><summary>[do until(str)]</summary>

Do-until style loop. The content is processed, then the `until` expression is evaluated - if it's false, the content is processed again. Repeat until `until` is true.

```
[sets my_var=0]
```
@@ -1028,6 +1032,8 @@ Returns a slice of the content as determined by the keyword arguments.

`end` is the last position of the slice. Defaults to 0.

Alternatively, you can pass a string into `start` or `end` and it will find the index of that string within the `content`.

`step` is the skip interval. Defaults to 1 (in other words, a continuous substring).

`unit` is either `characters` or `words` and refers to the unit of the aforementioned arguments. Defaults to `characters`.
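The interaction of these arguments can be sketched in Python. This is a minimal illustration of the documented behavior, not the actual Unprompted implementation - in particular, treating `end=0` as "through the end of the content" and the exact string-lookup rule are assumptions based on the text above:

```python
def slice_content(content, start=0, end=0, step=1, unit="characters"):
    # Split the content into the requested units.
    items = list(content) if unit == "characters" else content.split()

    # A string passed into start/end is resolved to its index within the content.
    if isinstance(start, str):
        start = content.index(start) if unit == "characters" else items.index(start)
    if isinstance(end, str):
        end = content.index(end) if unit == "characters" else items.index(end)

    # end is the last position of the slice; 0 is read as "through the end".
    stop = None if end == 0 else end + 1
    result = items[start:stop:step]  # step is the skip interval
    return "".join(result) if unit == "characters" else " ".join(result)
```

For example, `slice_content("hello world", start=6)` would return `world`.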
@@ -1129,38 +1135,6 @@ Note that variables are automatically deleted at the end of each run - you do **

This section describes all of the included shortcodes which are specifically designed for use with the A1111 WebUI.

<details><summary>[controlnet]</summary>

Enables support for [ControlNet](https://github.com/lllyasviel/ControlNet) models in img2img mode. ControlNet is a neural network structure to control diffusion models by adding extra conditions.

**NOTE:** This is a "wrapper" implementation of the original ControlNet code. For a more robust solution, you can check out [the dedicated ControlNet extension by Mikubill](https://github.com/Mikubill/sd-webui-controlnet).

You need a bare minimum of 8 GB of VRAM to use this shortcode, although 12 GB is recommended.

Supports the `model` argument, which is the name of a ControlNet checkpoint in your `models/Stable-diffusion` directory (do not include the file extension). You can download ControlNet checkpoints from [the official HuggingFace page](https://huggingface.co/lllyasviel/ControlNet/tree/main/models).

For each model, you also need a copy of the [cldm_v15.yaml](https://github.com/lllyasviel/ControlNet/tree/main/models) config file. Rename it to match the name of the ControlNet model, e.g. `control_sd15_normal.yaml`.

For each model, you also need the associated [annotator files available here](https://huggingface.co/lllyasviel/ControlNet/tree/main/annotator/ckpts). Place these into your `extensions/unprompted/lib_unprompted/stable_diffusion/controlnet/annotator/ckpts` folder.

If you run into any errors, please triple-check your filepaths before opening a bug report.

You can use ControlNet with custom SD 1.5 models [by merging checkpoints as described here](https://github.com/lllyasviel/ControlNet/issues/4#issuecomment-1426877944).

Please be aware that the last part of your model's filename indicates which type of ControlNet model it is. The following ControlNet model types are supported: `openpose`, `scribble`, `mlsd`, `depth`, `normal`, `hed`, `canny`, `seg`

ControlNet models should **not** be loaded manually from your WebUI dropdown.

Supports the `save_memory` argument to minimize VRAM requirements.

Supports the `detect_resolution` argument which is the size of the detected map. Defaults to 512. Some models may perform better at 384. Lowering this value to 256 may help with VRAM requirements.

Supports the `eta` argument.

Supports the following model-specific arguments: `value_threshold`, `distance_threshold`, `bg_threshold`, `low_threshold`, `high_threshold`

</details>

<details><summary>[file2mask]</summary>

Allows you to modify or replace your img2img mask with arbitrary files.
@@ -1391,6 +1365,8 @@ This is a helper shortcode that should be used if multiple init images, multiple

A port of [the script](https://github.com/ThereforeGames/txt2mask) by the same name, `[txt2mask]` allows you to create a region for inpainting based only on the text content (as opposed to the brush tool). This shortcode only works in the img2img tab of the A1111 WebUI.

Supports the `method` argument which determines the technology to use for masking. Defaults to `clipseg`. Can be changed to `sam` which will utilize [Segment Anything](https://segment-anything.com/) instead.

Supports the `mode` argument which determines how the text mask will behave alongside a brush mask:

- `add` will overlay the two masks. This is the default value.
- `discard` will ignore the brush mask entirely.
@@ -1424,6 +1400,8 @@ Supports the optional `show` positional argument which will append the final mas

Supports the optional `legacy_weights` positional argument which will utilize the original CLIPseg weights. By default, `[txt2mask]` will use the [refined weights](https://github.com/timojl/clipseg#new-fine-grained-weights).

Supports the `unload_model` argument, which will unload the clipseg model after processing. On my RTX 3090, this adds about 3 seconds to inference time. Defaults to `False`, and should only be enabled on devices with low memory.

The content and `negative_mask` both support the vertical pipe delimiter (`|`) which allows you to specify multiple subjects for masking.
@@ -1434,7 +1412,7 @@ The content and `negative_mask` both support the vertical pipe delimiter (`|`) w

<details><summary>[zoom_enhance]</summary>

Upscales a selected portion of an image via `[img2img]` and `[txt2mask]`, then pastes it seamlessly back onto the original.

Greatly improves low-resolution details like faces and hands. It is significantly faster than Hires Fix and more flexible than the "Restore Faces" option.
@@ -1446,11 +1424,23 @@ Supports the `replacement` keyword argument which is the prompt that will be use

Supports the `negative_replacement` keyword argument, which is the negative prompt that will be used on the mask region via `[img2img]`. Defaults to an empty string.

Both `replacement` and `negative_replacement` support multiple, delimited search terms via `Unprompted.config.syntax.delimiter`.

Supports `mask_sort_method` which is used when multiple, non-contiguous masks are detected. Defaults to `left-to-right`. Options include: `left-to-right`, `right-to-left`, `top-to-bottom`, `bottom-to-top`, `big-to-small`, `small-to-big`, `unsorted`.
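The sort options map naturally onto key functions over each mask's bounding box. A sketch under the assumption that masks are represented as `(x, y, width, height)` boxes - the actual data structure inside `[zoom_enhance]` may differ:

```python
def sort_masks(boxes, method="left-to-right"):
    # boxes: list of (x, y, width, height) bounding boxes, one per mask region.
    keys = {
        "left-to-right": lambda b: b[0],
        "right-to-left": lambda b: -b[0],
        "top-to-bottom": lambda b: b[1],
        "bottom-to-top": lambda b: -b[1],
        "big-to-small": lambda b: -(b[2] * b[3]),
        "small-to-big": lambda b: b[2] * b[3],
    }
    if method == "unsorted":
        return list(boxes)  # keep detection order
    return sorted(boxes, key=keys[method])
```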
Supports the `mode` keyword argument, which determines how the shortcode will interact with a pre-existing image mask. Defaults to `subtract`, which will remove your masked pixels from the shortcode's calculations. Options include: `add`, `subtract`, `discard`.

Supports the `bypass_adaptive_hires` positional argument. By default, the shortcode will scale up some inference settings such as CFG scale and sharpness depending on the resolution of the init image. Include this argument to disable that behavior.

Supports the `hires_size_max` keyword argument which is a hard limit on the size of the upscaled image, in order to avoid OOM errors. Defaults to 1024.

Supports the `blur_size` keyword argument, which corresponds to the radius of the gaussian blur that will be applied to the mask of the upscaled image - this helps it blend seamlessly back into your original image. Defaults to `0.03`. Note: this is a float that is a percentage of the total canvas size; 0.03 means 3% of the total canvas.
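Because `blur_size` is a fraction rather than a pixel count, the effective radius grows with the canvas. A sketch of the conversion - whether the reference dimension is the width, height, or their maximum is an assumption:

```python
def blur_radius_px(blur_size, width, height):
    # 0.03 -> a gaussian radius equal to 3% of the larger canvas dimension.
    return round(blur_size * max(width, height))
```

At the default `0.03`, a 1024x768 canvas would get a radius of about 31 pixels.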
Supports the `sharpen_amount` argument, which is a float that determines the strength of the unsharp filter that is applied in post-processing.

Supports the `denoising_max` keyword argument. The `[zoom_enhance]` shortcode is equipped with **dynamic denoising strength** based on a simple idea: the smaller the mask region, the higher the denoising strength we should apply. This argument lets you set the upper limit of that feature.
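The idea behind dynamic denoising strength can be sketched as a simple interpolation. The actual curve used by `[zoom_enhance]` is not documented here, so both the linear falloff and the lower bound below are assumptions:

```python
def dynamic_denoising(mask_area_fraction, denoising_max=0.65, denoising_min=0.2):
    # mask_area_fraction: masked pixels / total pixels, in [0, 1].
    # Smaller masks receive a strength closer to denoising_max; the strength
    # falls off linearly toward denoising_min as the mask grows.
    strength = denoising_max - (denoising_max - denoising_min) * mask_area_fraction
    return max(denoising_min, min(denoising_max, strength))
```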
Supports the `mask_size_max` keyword argument. Defaults to `0.5`. If a mask region is determined to be greater than this value, it will not be processed by `[zoom_enhance]`. The reason is that large objects generally do not benefit from upscaling.

Supports the `min_area` keyword argument. Defaults to `50`. If the pixel area of a mask is smaller than this, it may be a false-positive mask selection or at least not worth upscaling.
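Taken together, `mask_size_max` and `min_area` act as a band-pass filter on mask size. A sketch assuming the former is measured as a fraction of the canvas and the latter in raw pixels, per the descriptions above:

```python
def should_enhance(mask_pixels, canvas_pixels, mask_size_max=0.5, min_area=50):
    # Reject tiny masks, which are likely false-positive selections.
    if mask_pixels < min_area:
        return False
    # Reject oversized masks; large objects rarely benefit from upscaling.
    if mask_pixels / canvas_pixels > mask_size_max:
        return False
    return True
```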
@@ -1460,6 +1450,16 @@ Supports the `upscale_width` and `upscale_height` arguments. Default to `512`. T

Supports the `include_original` positional argument. This will append the original, "non-zoom-enhanced" image to your output window. Useful for before-after comparisons.

Supports the `upscale_method` and `downscale_method` arguments which determine the algorithms for image rescaling. Upscale defaults to `Nearest Neighbor`. Downscale defaults to `Lanczos`. Options include: `Nearest Neighbor`, `Box`, `Bilinear`, `Hamming`, `Bicubic`, `Lanczos`.

Supports the `color_correction_method` argument which will attempt to match the color grading of the upscaled image to the original. Defaults to `none`. Options include: `none`, `mvgd`, `mkl`, `hm-mvgd-hm`, `hm-mkl-hm`.

Supports the `color_correct_strength` argument which is an integer that determines how many times to run the color correction algorithm. Defaults to 1.

Supports the `color_correct_timing` argument which determines when to run the color correction algorithm. Defaults to `pre`, which will run color correction before upscaling. Options include `pre` and `post`.
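Here is how `color_correct_strength` and `color_correct_timing` might slot into the region-processing pipeline, using stand-in functions for the real image operations. The structure is an illustration, not the shortcode's actual code:

```python
def process_region(region, correct_colors, upscale, method="none",
                   strength=1, timing="pre"):
    # correct_colors(img, method) and upscale(img) are stand-ins for the
    # real color-matching and img2img upscaling steps.
    def run_correction(img):
        for _ in range(strength):  # color_correct_strength = number of passes
            img = correct_colors(img, method)
        return img

    if method != "none" and timing == "pre":
        region = run_correction(region)  # match colors before upscaling
    region = upscale(region)
    if method != "none" and timing == "post":
        region = run_correction(region)  # match colors after upscaling
    return region
```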
Supports the `debug` positional argument, which will output a series of images to your WebUI folder over the course of processing.

This shortcode is compatible with batch count and batch size.
@@ -1509,4 +1509,50 @@ The `cleanup` function runs at the end of the processing chain. You can free any

For more details, please examine the code of the stock shortcodes.

</details>

<details><summary>Legacy Shortcodes</summary>

Legacy shortcodes are those which are no longer officially supported. Please be aware that they may not work as expected and could be removed from future versions of Unprompted without warning.

<details><summary>[controlnet]</summary>

**Reason for legacy status:** The popular [ControlNet extension by Mikubill](https://github.com/Mikubill/sd-webui-controlnet) was released less than 24 hours after this shortcode and is much more robust. ControlNet is a complicated, time-consuming feature to support and I cannot justify further development when the alternative software is already so good.

Enables support for [ControlNet](https://github.com/lllyasviel/ControlNet) models in img2img mode. ControlNet is a neural network structure to control diffusion models by adding extra conditions.

You need a bare minimum of 8 GB of VRAM to use this shortcode, although 12 GB is recommended.

Supports the `model` argument, which is the name of a ControlNet checkpoint in your `models/Stable-diffusion` directory (do not include the file extension). You can download ControlNet checkpoints from [the official HuggingFace page](https://huggingface.co/lllyasviel/ControlNet/tree/main/models).

For each model, you also need a copy of the [cldm_v15.yaml](https://github.com/lllyasviel/ControlNet/tree/main/models) config file. Rename it to match the name of the ControlNet model, e.g. `control_sd15_normal.yaml`.

For each model, you also need the associated [annotator files available here](https://huggingface.co/lllyasviel/ControlNet/tree/main/annotator/ckpts). Place these into your `extensions/unprompted/lib_unprompted/stable_diffusion/controlnet/annotator/ckpts` folder.

If you run into any errors, please triple-check your filepaths before opening a bug report.

You can use ControlNet with custom SD 1.5 models [by merging checkpoints as described here](https://github.com/lllyasviel/ControlNet/issues/4#issuecomment-1426877944).

Please be aware that the last part of your model's filename indicates which type of ControlNet model it is. The following ControlNet model types are supported: `openpose`, `scribble`, `mlsd`, `depth`, `normal`, `hed`, `canny`, `seg`
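Deriving the model type from the filename suffix can be sketched like this; the exact parsing inside the shortcode is an assumption:

```python
CONTROL_TYPES = {"openpose", "scribble", "mlsd", "depth", "normal", "hed", "canny", "seg"}

def controlnet_type(model_name):
    # e.g. "control_sd15_normal" -> "normal"
    suffix = model_name.rsplit("_", 1)[-1].lower()
    if suffix not in CONTROL_TYPES:
        raise ValueError(f"unrecognized ControlNet model type: {suffix}")
    return suffix
```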
ControlNet models should **not** be loaded manually from your WebUI dropdown.

Supports the `save_memory` argument to minimize VRAM requirements.

Supports the `detect_resolution` argument which is the size of the detected map. Defaults to 512. Some models may perform better at 384. Lowering this value to 256 may help with VRAM requirements.

Supports the `eta` argument.

Supports the following model-specific arguments: `value_threshold`, `distance_threshold`, `bg_threshold`, `low_threshold`, `high_threshold`

</details>

<details><summary>[pix2pix_zero]</summary>

**Reason for legacy status:** the pix2pix-zero method [was not what I originally thought](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/7711#discussioncomment-4952579) which was sort of a buzzkill. As it stands, I believe ControlNet is better suited to most tasks, but I can't make any definitive claims about that - pix2pix-zero went under the radar and does merit further testing.

If you wish to use this shortcode, you will need to modify the hardcoded path to a diffusers model on line 33.

</details>

</details>
@@ -0,0 +1 @@
from .clip import *
@@ -0,0 +1,74 @@
from torch import nn
from .clip_model import CLIP
from .clip_surgery_model import CLIPSurgery


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(name: str, state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    if 'CS-' in name:
        model = CLIPSurgery(
            embed_dim,
            image_resolution, vision_layers, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
        )
    else:
        model = CLIP(
            embed_dim,
            image_resolution, vision_layers, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
        )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    #convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
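The ViT branch of `build_model` above infers the model geometry purely from checkpoint tensor shapes. A torch-free illustration using a stand-in shape table, with values chosen to mimic a hypothetical ViT-B/32 checkpoint at 224px (the shapes are assumptions for demonstration):

```python
# Stand-in for {k: state_dict[k].shape} of a hypothetical ViT-B/32 checkpoint.
shapes = {
    "visual.conv1.weight": (768, 3, 32, 32),   # (width, channels, patch, patch)
    "visual.positional_embedding": (50, 768),  # 1 class token + a 7x7 patch grid
}

vision_width = shapes["visual.conv1.weight"][0]
vision_patch_size = shapes["visual.conv1.weight"][-1]
# Subtract the class token, then take the square root to recover the grid side.
grid_size = round((shapes["visual.positional_embedding"][0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size  # 32 * 7 = 224
```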
@ -0,0 +1,344 @@
|
|||
import hashlib
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import Union, List
|
||||
from pkg_resources import packaging
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
from .build_model import build_model
|
||||
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
||||
|
||||
try:
|
||||
from torchvision.transforms import InterpolationMode
|
||||
BICUBIC = InterpolationMode.BICUBIC
|
||||
except ImportError:
|
||||
BICUBIC = Image.BICUBIC
|
||||
|
||||
|
||||
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
|
||||
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
||||
|
||||
|
||||
__all__ = ["available_models", "load", "tokenize", "encode_text_with_prompt_ensemble",
|
||||
"get_similarity_map", "clip_feature_surgery", "similarity_map_to_points"]
|
||||
_tokenizer = _Tokenizer()
|
||||
|
||||
_MODELS = {
|
||||
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
||||
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
||||
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
||||
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
||||
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
|
||||
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
||||
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
||||
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
||||
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
||||
"CS-RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
||||
"CS-RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
||||
"CS-RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
||||
"CS-RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
||||
"CS-RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
|
||||
"CS-ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
||||
"CS-ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
||||
"CS-ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
||||
"CS-ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str):
|
||||
os.makedirs(root, exist_ok=True)
|
||||
filename = os.path.basename(url)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, filename)
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
||||
return download_target
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
||||
|
||||
return download_target
|
||||
|
||||
|
||||
def _convert_image_to_rgb(image):
|
||||
return image.convert("RGB")
|
||||
|
||||
|
||||
def _transform(n_px):
|
||||
return Compose([
|
||||
Resize((n_px, n_px), interpolation=BICUBIC),
|
||||
#CenterCrop(n_px), # rm center crop to explain whole image
|
||||
_convert_image_to_rgb,
|
||||
ToTensor(),
|
||||
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
])
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available CLIP models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
||||
"""Load a CLIP model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
||||
|
||||
device : Union[str, torch.device]
|
||||
The device to put the loaded model
|
||||
|
||||
jit : bool
|
||||
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
||||
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/clip"
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : torch.nn.Module
|
||||
The CLIP model
|
||||
|
||||
preprocess : Callable[[PIL.Image], torch.Tensor]
|
||||
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
||||
"""
|
||||
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            if jit:
                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
                jit = False
            state_dict = torch.load(opened_file, map_location="cpu")

    if not jit:
        model = build_model(name, state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()
        return model, _transform(model.visual.input_resolution)

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result


def encode_text_with_prompt_ensemble(model, texts, device, prompt_templates=None):

    # using default prompt templates for ImageNet
    if prompt_templates is None:
        prompt_templates = ['a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.', 'a photo of the hard to see {}.', 'a low resolution photo of the {}.', 'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.', 'a photo of a hard to see {}.', 'a bright photo of a {}.', 'a photo of a clean {}.', 'a photo of a dirty {}.', 'a dark photo of the {}.', 'a drawing of a {}.', 'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a painting of the {}.', 'a painting of a {}.', 'a pixelated photo of the {}.', 'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a rendering of the {}.', 'a {} in a video game.', 'a photo of one {}.', 'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.', 'a doodle of the {}.', 'a origami {}.', 'a low resolution photo of a {}.', 'the toy {}.', 'a rendition of the {}.', 'a photo of the clean {}.', 'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.', 'a photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.', 'art of a {}.', 'a sketch of the {}.', 'a embroidered {}.', 'a pixelated photo of a {}.', 'itap of the {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.', 'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.', 'art of the {}.', 'a drawing of the {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'the plushie {}.', 'a dark photo of a {}.', 'itap of a {}.', 'graffiti of the {}.', 'a toy {}.', 'itap of my {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']

    text_features = []
    for t in texts:
        prompted_t = [template.format(t) for template in prompt_templates]
        prompted_t = tokenize(prompted_t).to(device)
        class_embeddings = model.encode_text(prompted_t)
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        class_embedding = class_embeddings.mean(dim=0)
        class_embedding /= class_embedding.norm()
        text_features.append(class_embedding)
    text_features = torch.stack(text_features, dim=1).to(device).t()

    return text_features


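The averaging step in `encode_text_with_prompt_ensemble()` can be checked in isolation with NumPy: each per-template embedding is L2-normalized, the templates are averaged, and the mean is renormalized to unit length. The toy embedding values below are hypothetical:

```python
import numpy as np

# two toy "template" embeddings for a single class (hypothetical values)
class_embeddings = np.array([[3.0, 4.0], [0.0, 1.0]])
# normalize each template embedding to unit length
class_embeddings /= np.linalg.norm(class_embeddings, axis=-1, keepdims=True)
# average over templates, then renormalize, as the function above does
class_embedding = class_embeddings.mean(axis=0)
class_embedding /= np.linalg.norm(class_embedding)
```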
def get_similarity_map(sm, shape):

    # min-max norm
    sm = (sm - sm.min(1, keepdim=True)[0]) / (sm.max(1, keepdim=True)[0] - sm.min(1, keepdim=True)[0])

    # reshape
    side = int(sm.shape[1] ** 0.5)  # square output
    sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2)

    # interpolate
    sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear')
    sm = sm.permute(0, 2, 3, 1)

    return sm


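The per-map min-max normalization used by `get_similarity_map()` can be sketched with NumPy; the patch scores below are hypothetical:

```python
import numpy as np

# one flattened similarity map with four patch scores (hypothetical values)
sm = np.array([[2.0, 4.0, 6.0, 10.0]])
# min-max normalize each row into [0, 1], as get_similarity_map() does per map
sm_norm = (sm - sm.min(axis=1, keepdims=True)) / (sm.max(axis=1, keepdims=True) - sm.min(axis=1, keepdims=True))
print(sm_norm)  # → [[0.   0.25 0.5  1.  ]]
```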
def clip_feature_surgery(image_features, text_features, redundant_feats=None, t=2):

    if redundant_feats is not None:
        similarity = image_features @ (text_features - redundant_feats).t()

    else:
        # weights to restrain influence of obvious classes on others
        prob = image_features[:, :1, :] @ text_features.t()
        prob = (prob * 2).softmax(-1)
        w = prob / prob.mean(-1, keepdim=True)

        # element-wise multiplied features
        b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2]
        feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c)
        feats *= w.reshape(1, 1, n_t, 1)
        redundant_feats = feats.mean(2, keepdim=True)  # along cls dim
        feats = feats - redundant_feats

        # sum the element-wise multiplied features as cosine similarity
        similarity = feats.sum(-1)

    return similarity


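The redundant-feature subtraction at the heart of `clip_feature_surgery()` reduces to: subtract the per-channel mean across classes, then sum over channels. A NumPy sketch with a hypothetical 1×1×2×2 feature tensor:

```python
import numpy as np

# element-wise image-text features: batch=1, tokens=1, classes=2, channels=2
feats = np.array([[[[1.0, 2.0], [3.0, 6.0]]]])
# mean over the class dimension gives the "redundant" features
redundant_feats = feats.mean(axis=2, keepdims=True)
# subtracting them and summing over channels yields the similarity
similarity = (feats - redundant_feats).sum(-1)
print(similarity)  # → [[[-3.  3.]]]
```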
# sm shape N_t
def similarity_map_to_points(sm, shape, t=0.8, down_sample=2):
    side = int(sm.shape[0] ** 0.5)
    sm = sm.reshape(1, 1, side, side)

    # down sample to smooth results
    down_side = side // down_sample
    sm = torch.nn.functional.interpolate(sm, (down_side, down_side), mode='bilinear')[0, 0, :, :]
    h, w = sm.shape
    sm = sm.reshape(-1)

    sm = (sm - sm.min()) / (sm.max() - sm.min())
    rank = sm.sort(0)[1]
    scale_h = float(shape[0]) / h
    scale_w = float(shape[1]) / w

    num = min((sm >= t).sum(), sm.shape[0] // 2)
    labels = np.ones(num * 2).astype('uint8')
    labels[num:] = 0
    points = []

    # positives
    for idx in rank[-num:]:
        x = min((idx % w + 0.5) * scale_w, shape[1] - 1)  # +0.5 to center
        y = min((idx // w + 0.5) * scale_h, shape[0] - 1)
        points.append([int(x.item()), int(y.item())])

    # negatives
    for idx in rank[:num]:
        x = min((idx % w + 0.5) * scale_w, shape[1] - 1)
        y = min((idx // w + 0.5) * scale_h, shape[0] - 1)
        points.append([int(x.item()), int(y.item())])

    return points, labels
@@ -0,0 +1,396 @@
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC

        side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
        new_side = int((x.shape[0] - 1) ** 0.5)

        # update the position embedding during inference for varied input size
        if side != new_side:
            new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
            new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
            new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
            self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)

        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # return x[0]
        return x.transpose(0, 1)  # return both cls token and image tokens, B,N,C


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.need_weights = need_weights

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        if not self.need_weights:
            return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
        else:
            return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)

    def forward(self, x: torch.Tensor):
        if not self.need_weights:
            x = x + self.attention(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
            return x
        else:
            y, attn = self.attention(self.ln_1(x))
            x = x + y
            x = x + self.mlp(self.ln_2(x))
            return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, need_weights if i == layers - 1 else False) for i in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads, need_weights=True)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # x = self.ln_post(x[:, 0, :])
        x = self.ln_post(x)  # return both cls token and image tokens

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
@@ -0,0 +1,487 @@
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out


# implement attention module for v-v self-attention
class Attention(nn.Module):
    def __init__(self, out_dim, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., settings=''):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(out_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.settings = settings

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # original self-attention for the original path
        attn_ori = (q @ k.transpose(-2, -1)) * self.scale
        attn_ori = attn_ori.softmax(dim=-1)
        attn_ori = self.attn_drop(attn_ori)

        # replace k & q by v
        k = v
        q = k

        # resnets have only one self-attention, norm and larger scale perform better
        if self.settings == 'resnet':
            k = k / (k.norm(p=2, dim=-1, keepdim=True) + 1e-6)
            q = k
            scale = self.scale * 8
        else:
            scale = self.scale

        # self-attention, higher temperature for resnets performs better
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x_ori = (attn_ori @ v).transpose(1, 2).reshape(B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # clip_surgery
        # x = v.transpose(1, 2).reshape(B, N, C)  # mask_clip
        x = self.proj_drop(self.proj(x))
        x_ori = self.proj_drop(self.proj(x_ori))
        return [x, x_ori]


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

        self.attn = None
        self.embed_dim = embed_dim
        self.output_dim = output_dim

    def forward(self, x):
        # reform transformer layer after init and load weights, using v only
        if self.attn is None:
            self.attn = Attention(self.output_dim, self.embed_dim, self.num_heads, True)
            self.attn.qkv.weight = torch.nn.Parameter(torch.cat([self.v_proj.weight, self.v_proj.weight, self.v_proj.weight], 0))
            self.attn.qkv.bias = torch.nn.Parameter(torch.cat([self.v_proj.bias, self.v_proj.bias, self.v_proj.bias]))
            self.attn.proj.weight = self.c_proj.weight
            self.attn.proj.bias = self.c_proj.bias

        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC

        side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
        new_side = int((x.shape[0] - 1) ** 0.5)

        # update the position embedding during inference for varied input size
        if side != new_side:
            new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
            new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
            new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
            self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)

        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, x_ori = self.attn(x.transpose(0, 1))

        # cls token from the original path, and img tokens from the new path
        x[:, 0, :] = x_ori[:, 0, :]
        return x


class ModifiedResNet(nn.Module):
|
||||
"""
|
||||
A ResNet class that is similar to torchvision's but contains the following changes:
|
||||
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
||||
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
||||
- The final pooling layer is a QKV attention instead of an average pool
|
||||
"""
|
||||
|
||||
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
||||
super().__init__()
|
||||
self.output_dim = output_dim
|
||||
self.input_resolution = input_resolution
|
||||
|
||||
# the 3-layer stem
|
||||
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width // 2)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(width // 2)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(width)
|
||||
self.relu3 = nn.ReLU(inplace=True)
|
||||
self.avgpool = nn.AvgPool2d(2)
|
||||
|
||||
# residual layers
|
||||
self._inplanes = width # this is a *mutable* variable used during construction
|
||||
self.layer1 = self._make_layer(width, layers[0])
|
||||
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
||||
|
||||
embed_dim = width * 32 # the ResNet feature dimension
|
||||
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
||||
|
||||
def _make_layer(self, planes, blocks, stride=1):
|
||||
layers = [Bottleneck(self._inplanes, planes, stride)]
|
||||
|
||||
self._inplanes = planes * Bottleneck.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(Bottleneck(self._inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
def stem(x):
|
||||
x = self.relu1(self.bn1(self.conv1(x)))
|
||||
x = self.relu2(self.bn2(self.conv2(x)))
|
||||
x = self.relu3(self.bn3(self.conv3(x)))
|
||||
x = self.avgpool(x)
|
||||
return x
|
||||
|
||||
x = x.type(self.conv1.weight.dtype)
|
||||
x = stem(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.attnpool(x)
|
||||
|
||||
# shape BNC
|
||||
return x
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.sigmoid(1.702 * x)
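QuickGELU approximates GELU as x·sigmoid(1.702x). A dependency-free sketch of the same formula (plain `math`, no torch; the comparison tolerance is my own choice):

```python
import math

def quick_gelu(x: float) -> float:
    # x * sigmoid(1.702 * x), the approximation used by QuickGELU
    return x * (1.0 / (1.0 + math.exp(-1.702 * x)))

def exact_gelu(x: float) -> float:
    # reference GELU via the Gauss error function
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

print(quick_gelu(1.0), exact_gelu(1.0))
```

The two curves stay within a few thousandths of each other around zero, which is why the cheap sigmoid form is used in CLIP's MLP blocks.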


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        if isinstance(self.attn, Attention):
            x = x.transpose(0, 1)
            x, x_ori = self.attn(x)
            return [x.transpose(0, 1), x_ori.transpose(0, 1)]
        else:
            return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x):

        # dual paths for blocks deeper than "d"
        if isinstance(self.attn, Attention):
            if isinstance(x, list):
                x, x_ori = x
                x_res = self.attention(self.ln_1(x_ori))
                x_res, x_ori_res = x_res
                x_ori += x_ori_res
                x_ori = x_ori + self.mlp(self.ln_2(x_ori))
                x += x_res  # skip ffn for the new path
                return [x, x_ori]

            # start of dual path
            else:
                x_res = self.attention(self.ln_1(x))
                if isinstance(x_res, list):
                    x_res, x_ori_res = x_res
                    x_ori = x + x_ori_res
                    x_ori = x_ori + self.mlp(self.ln_2(x_ori))
                    x += x_res
                    return [x, x_ori]

        # single path before "d"
        else:
            x = x + self.attention(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
            return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for i in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads, need_weights=True)
        self.attn = None
        self.embed_dim = width
        self.num_heads = heads

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    @torch.no_grad()
    def forward(self, x: torch.Tensor):

        # reform the architecture during first inference
        if self.attn is None:

            # apply architecture surgery on the last 6 blocks
            for i in range(1, 7):  # surgery 7, maskclip 2
                self.attn = Attention(self.embed_dim, self.embed_dim, self.num_heads, True)
                self.attn.qkv.weight.data = self.transformer.resblocks[-i].attn.in_proj_weight.clone()
                self.attn.qkv.bias.data = self.transformer.resblocks[-i].attn.in_proj_bias.clone()
                self.attn.proj.weight.data = self.transformer.resblocks[-i].attn.out_proj.weight.clone()
                self.attn.proj.bias.data = self.transformer.resblocks[-i].attn.out_proj.bias.clone()
                self.transformer.resblocks[-i].attn = self.attn

        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
        new_side = int((x.shape[1] - 1) ** 0.5)

        # update the position embedding during inference for varied input size
        if side != new_side:
            new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
            new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
            new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
            self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)

        pos = self.positional_embedding.to(x.dtype)
        x = x + pos
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x, x_ori = self.transformer(x)
        x[0, :, :] = x_ori[0, :, :]  # clip_surgery
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x)
        x = x @ self.proj

        return x


class CLIPSurgery(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask
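The same causal mask can be sketched without torch: an additive mask that is 0 on and below the diagonal and -inf strictly above it, so each text token may only attend to itself and earlier positions (pure-Python lists here; the context length of 4 is arbitrary):

```python
def build_causal_mask(n: int):
    # additive attention mask: 0.0 where attention is allowed (j <= i), -inf above the diagonal
    neg_inf = float("-inf")
    return [[0.0 if j <= i else neg_inf for j in range(n)] for i in range(n)]

for row in build_causal_mask(4):
    print(row)
```

Adding this matrix to the raw attention scores before the softmax zeroes out every "future" position, which is exactly what `triu_(1)` on a -inf-filled tensor achieves above.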

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text

@ -0,0 +1,132 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
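The byte-to-unicode table maps every possible byte to a distinct printable character, so the BPE stage never has to handle raw whitespace or control bytes. A standalone sketch of the same construction, renamed to avoid clashing with the function above:

```python
def bytes_to_unicode_table():
    # printable byte ranges keep their own codepoints
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            # remaining bytes (whitespace/control/etc.) are shifted above 255
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

table = bytes_to_unicode_table()
print(len(table))  # one distinct printable character per byte value
```

Because the mapping is a bijection over all 256 byte values, `decode` can invert it exactly via `byte_decoder`.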


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
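`get_pairs` enumerates the adjacent symbol bigrams that the BPE loop then ranks. A quick self-contained demonstration:

```python
def get_pairs(word):
    # word is a tuple of symbols; return the set of adjacent pairs
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs

print(sorted(get_pairs(("l", "o", "w", "er</w>"))))
```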


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text
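`whitespace_clean` collapses any run of whitespace into a single space. For this pattern the stdlib `re` module behaves the same as the third-party `regex` package imported above (which is only needed for the `\p{L}`-style tokenizer pattern), so a dependency-free sketch is:

```python
import re

def whitespace_clean(text):
    # collapse tabs, newlines, and repeated spaces into single spaces
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

print(whitespace_clean("  a\tphoto\n\nof a  cat "))
```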


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
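The `bpe` method repeatedly merges the highest-ranked adjacent pair until no ranked pair remains. A self-contained toy run using a simplified merge scan (instead of the `index`/`try` loop above) and a hand-invented rank table:

```python
def get_pairs(word):
    return {(word[i], word[i + 1]) for i in range(len(word) - 1)}

def bpe(token, bpe_ranks):
    # append the end-of-word marker to the final symbol
    word = tuple(token[:-1]) + (token[-1] + '</w>',)
    pairs = get_pairs(word)
    if not pairs:
        return token + '</w>'
    while True:
        # pick the adjacent pair with the lowest (best) merge rank
        bigram = min(pairs, key=lambda pair: bpe_ranks.get(pair, float('inf')))
        if bigram not in bpe_ranks:
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        if len(word) == 1:
            break
        pairs = get_pairs(word)
    return ' '.join(word)

ranks = {('l', 'o'): 0, ('lo', 'w'): 1, ('e', 'r</w>'): 2}
print(bpe("lower", ranks))  # → low er</w>
```

The real tokenizer loads ~49k ranked merges from `bpe_simple_vocab_16e6.txt.gz`; the three ranks here are purely illustrative.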

@ -10,7 +10,7 @@ import sys

class Unprompted:
    def __init__(self, base_dir="."):
-       self.VERSION = "8.2.0"
+       self.VERSION = "8.3.0"

        print(f"Loading Unprompted v{self.VERSION} by Therefore Games")
        self.log("Initializing Unprompted object...",False,"SETUP")

@ -113,7 +113,8 @@ class Unprompted:

    def parse_advanced(self,string,context=None):
        """First runs the string through parse_alt_tags, the result of which then goes through simpleeval"""
-       if (len(string) < 1): return string
+       if string is None: return ""
+       if (len(string) < 1): return ""
        string = self.parse_alt_tags(string,context)
        if self.Config.advanced_expressions:
            try:

@ -35,6 +35,7 @@ Unprompted.wizard_dropdown = None

Unprompted.wizard_template_files = []
Unprompted.wizard_template_names = []
+Unprompted.wizard_template_kwargs = []

def do_dry_run(string):
    Unprompted.log(string)

@ -109,6 +110,9 @@ def wizard_generate_template(option,is_img2img,prepend="",append=""):

    result = parse_children(group,result)

+   for kwarg in Unprompted.wizard_template_kwargs[option]:
+       result += f" {kwarg}='{Unprompted.wizard_template_kwargs[option][kwarg]}'"
+
    # Closing bracket
    result += Unprompted.Config.syntax.tag_end

@ -261,6 +265,7 @@ class Scripts(scripts.Script):
            content = content.replace("\\r\\n", "<br>") + "<br><br>"
            gr.Label(label="Options",value=f"{self.dropdown_item_name}")
            gr.Markdown(value=content)
+           self.wizard_template_kwargs = kwargs
            return("")
        wizard_shortcode_parser.register(handler,"template",f"{Unprompted.Config.syntax.tag_close}template")

@ -306,6 +311,7 @@ class Scripts(scripts.Script):
        wizard_add_template(is_first)
        Unprompted.wizard_template_names.append(self.dropdown_item_name)
        Unprompted.wizard_template_files.append(filename)
+       Unprompted.wizard_template_kwargs.append(self.wizard_template_kwargs)
        if (is_first):
            templates_dropdown.value = self.dropdown_item_name
            is_first = False

@ -435,6 +441,8 @@ class Scripts(scripts.Script):

        # Extra vars
        Unprompted.shortcode_user_vars["batch_index"] = 0
+       original_model = opts.sd_model_checkpoint
+       Unprompted.shortcode_user_vars["sd_model"] = opts.sd_model_checkpoint

        # Set up system var support - copy relevant p attributes into shortcode var object
        for att in dir(p):

@ -451,7 +459,7 @@ class Scripts(scripts.Script):
        # Special handling of vars
        for att in Unprompted.shortcode_user_vars:
            # change models
-           if att == "sd_model":
+           if att == "sd_model" and Unprompted.shortcode_user_vars[att] != original_model:
                info = sd_models.get_closet_checkpoint_match(Unprompted.shortcode_user_vars["sd_model"])
                if (info): sd_models.load_model(info,None,None)
            # control controlnet

@ -510,8 +518,8 @@ class Scripts(scripts.Script):
    # After routines
    def postprocess(self, p, processed, is_enabled=True, unprompted_seed=-1, match_main_seed=True):
        if not self.allow_postprocess or not is_enabled:
-           Unprompted.log("Bypassing After routine")
-           self.allow_postprocess = True
+           Unprompted.log("Bypassing After routine to avoid infinite loop")
+           # self.allow_postprocess = True
            return False  # Prevents endless loop with some shortcodes
        self.allow_postprocess = False
        Unprompted.log("Entering After routine...")

@ -8,6 +8,7 @@ class Shortcode():
    def run_block(self, pargs, kwargs, context, content):
        index = int(self.Unprompted.parse_advanced(pargs[0])) if len(pargs) > 0 else 0
        if self.last_index != index or "allow_dupe_index" in pargs:
            self.Unprompted.log(f"Queueing up content: {content}")
            self.after_content.insert(index,content)
            self.last_index = index
        else: self.Unprompted.log("Duplicate [after] content detected, skipping - include allow_dupe_index to bypass this check")

@ -8,6 +8,9 @@ class Shortcode():
        self.description = "Processes the file content of 'path.'"

    def run_atomic(self, pargs, kwargs, context):
+       if "_bypass_if" in kwargs:
+           if self.Unprompted.parse_advanced(kwargs["_bypass_if"],context): return ""
+
        file_string = self.Unprompted.parse_alt_tags(pargs[0],context)
        this_encoding = self.Unprompted.parse_advanced(kwargs["_encoding"],context) if "_encoding" in kwargs else "utf-8"
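The `_bypass_if` addition is a guard-clause pattern: evaluate the kwarg as an expression and return early when it is truthy. A minimal dependency-free sketch of the same flow, where `evaluate` is a hypothetical stand-in for Unprompted's `parse_advanced` expression evaluator:

```python
def evaluate(expression: str, variables: dict):
    # hypothetical stand-in for parse_advanced; restricted eval for illustration only
    return eval(expression, {"__builtins__": {}}, variables)

def run_atomic(pargs, kwargs, variables):
    # skip file processing entirely when the bypass expression is truthy
    if "_bypass_if" in kwargs:
        if evaluate(kwargs["_bypass_if"], variables):
            return ""
    return f"processed {pargs[0]}"

print(run_atomic(["some_template"], {"_bypass_if": "batch_index > 0"}, {"batch_index": 1}))
```

This matches the changelog's intent: a Wizard template can pass `_bypass_if` so a `[file]` block only runs on, say, the first batch item.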

@ -0,0 +1,100 @@
class Shortcode():
    def __init__(self,Unprompted):
        self.Unprompted = Unprompted
        self.description = "Applies color correction to a resulting image."
        self.wizard_prepend = Unprompted.Config.syntax.tag_start + "after" + Unprompted.Config.syntax.tag_end + Unprompted.Config.syntax.tag_start_alt + "color_correct"
        self.wizard_append = Unprompted.Config.syntax.tag_end_alt + Unprompted.Config.syntax.tag_start + Unprompted.Config.syntax.tag_close + "after" + Unprompted.Config.syntax.tag_end

    def run_atomic(self, pargs, kwargs, context):
        from PIL import Image
        def autocrop_image(image, border=0):
            # Get the bounding box
            bbox = image.getbbox()
            # Crop the image to the contents of the bounding box
            image = image.crop(bbox)
            # Determine the width and height of the cropped image
            (width, height) = image.size
            # Add border
            width += border * 2
            height += border * 2
            # Create a new image object for the output image
            cropped_image = Image.new("RGBA", (width, height), (0,0,0,0))
            # Paste the cropped image onto the new image
            cropped_image.paste(image, (border, border))
            # Done!
            return cropped_image

        from blendmodes.blend import blendLayers, BlendType

        color_correct_method = kwargs["method"] if "method" in kwargs else "mkl"
        blend_lum = True if "blend_lum" in pargs else False
        debug = True if "debug" in pargs else False

        try:
            image_to_fix = self.Unprompted.after_processed.images[0].copy()
        except:
            self.Unprompted.log("This must be used inside of an [after] block",context="ERROR")
            return("")
        starting_image = self.Unprompted.p_copy.init_images[0]

        orig_image = image_to_fix.copy()
        if self.Unprompted.shortcode_user_vars["image_mask"]:
            mask = self.Unprompted.shortcode_user_vars["image_mask"]
            mask = mask.convert("L")
        else: mask = None

        if "source" in kwargs:
            set_kwargs = kwargs
            set_pargs = pargs
            set_pargs.insert(0,"return_image")
            set_kwargs["txt2mask_init_image"] = starting_image
            set_kwargs["precision"] = "150"
            set_kwargs["padding"] = "0"
            set_kwargs["method"] = "clipseg"
            self.Unprompted.shortcode_user_vars["image_mask"] = None
            source_mask = self.Unprompted.shortcode_objects["txt2mask"].run_block(set_pargs,set_kwargs,None,kwargs["source"])
            source_mask = source_mask.convert("L")
            starting_image.putalpha(source_mask)
            starting_image = autocrop_image(starting_image)

            import cv2
            import numpy
            from sklearn.cluster import KMeans
            np_img = numpy.array(starting_image)
            all_pixels = np_img.reshape(-1, 4)
            just_non_alpha = all_pixels[all_pixels[:, 3] == 255]
            avg_model = KMeans(3)
            reshaped_arr = np_img[:, :, :3].reshape(-1, 3)
            avg_model.fit(reshaped_arr)
            avg_color = avg_model.cluster_centers_[0]

            if debug: self.Unprompted.log(f"Average color: {avg_color}")

            new_image = Image.new("RGB", starting_image.size, (int(avg_color[0]),int(avg_color[1]),int(avg_color[2])))
            new_image.paste(starting_image, (0, 0), starting_image)
            if debug: new_image.save("color_correct_alpha_test.png")
            starting_image = new_image.copy()

        strength = float(kwargs["strength"]) if "strength" in kwargs else 1.0

        fixed_image = self.Unprompted.color_match(starting_image,image_to_fix,color_correct_method,1).convert("RGBA")

        if blend_lum:
            fixed_image = blendLayers(fixed_image, orig_image, BlendType.LUMINOSITY)

        if strength < 1.0:
            fixed_image.putalpha(int(255 * strength))
        else: self.Unprompted.after_processed.images[0] = fixed_image

        orig_image.paste(fixed_image,(0,0), fixed_image)
        orig_image = orig_image.resize((self.Unprompted.after_processed.images[0].size[0],self.Unprompted.after_processed.images[0].size[1]))
        self.Unprompted.after_processed.images[0].paste(orig_image,(0,0),mask)

        # self.Unprompted.after_processed.images[0] = fixed_image
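The KMeans call above estimates a dominant color to fill the transparent margin left by autocropping. As a rough, dependency-free point of comparison, a plain mean over the opaque pixels yields an average (not dominant) color; the pixel data below is invented:

```python
def average_opaque_color(pixels):
    # pixels: list of (r, g, b, a) tuples; only fully opaque pixels contribute
    opaque = [p for p in pixels if p[3] == 255]
    n = len(opaque)
    if n == 0:
        return (0, 0, 0)
    return tuple(sum(p[c] for p in opaque) // n for c in range(3))

sample = [(200, 100, 50, 255), (100, 50, 150, 255), (0, 0, 0, 0)]
print(average_opaque_color(sample))  # → (150, 75, 100)
```

A mean can be skewed by multi-colored subjects, which is presumably why the shortcode clusters with KMeans and takes a cluster center instead.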

@ -5,8 +5,16 @@ class Shortcode():
        self.image_mask = None
        self.show = False
        self.description = "Creates an image mask from the content for use with inpainting."
+       try:
+           del self.cached_model
+           del self.cached_transform
+           del self.cached_model_method
+           del self.cached_predictor
+       except AttributeError: pass
        self.cached_model = -1
        self.cached_transform = -1
        self.cached_model_method = ""
        self.cached_predictor = -1

    def run_block(self, pargs, kwargs, context, content):
        from PIL import ImageChops, Image, ImageOps

@ -16,45 +24,80 @@ class Shortcode():
        from matplotlib import pyplot as plt
        import cv2
        import numpy
        import gc
        from modules.images import flatten
        from modules.shared import opts
        from torchvision.transforms.functional import pil_to_tensor, to_pil_image

        gc.collect()

        if "txt2mask_init_image" in kwargs:
-           self.init_image = kwargs["txt2mask_init_image"]
+           self.init_image = kwargs["txt2mask_init_image"].copy()
        elif "init_images" not in self.Unprompted.shortcode_user_vars:
            self.Unprompted.log("No init_images found...")
            return
-       else: self.init_image = self.Unprompted.shortcode_user_vars["init_images"][0]
+       else: self.init_image = self.Unprompted.shortcode_user_vars["init_images"][0].copy()

        method = self.Unprompted.parse_advanced(kwargs["method"],context) if "method" in kwargs else "clipseg"

        if method == "clipseg":
            mask_width = 512
            mask_height = 512
-       elif method == "sam":
-           import launch
-           if not launch.is_installed("groundingdino"):
-               self.Unprompted.log("Attempting to install GroundingDINO library. Buckle up bro")
-               try:
-                   launch.run_pip("install git+https://github.com/IDEA-Research/GroundingDINO","requirements for Unprompted - txt2mask SAM method")
-               except Exception as e:
-                   self.Unprompted.log(f"GroundingDINO problem: {e}",context="ERROR")
-                   self.Unprompted.log(f"Please open an issue on their repo, not mine.",context="ERROR")
-                   return ""
+       else:
+           if method == "grounded_sam":
+               import launch
+               if not launch.is_installed("groundingdino"):
+                   self.Unprompted.log("Attempting to install GroundingDINO library. Buckle up bro")
+                   try:
+                       launch.run_pip("install git+https://github.com/IDEA-Research/GroundingDINO","requirements for Unprompted - txt2mask SAM method")
+                   except Exception as e:
+                       self.Unprompted.log(f"GroundingDINO problem: {e}",context="ERROR")
+                       self.Unprompted.log(f"Please open an issue on their repo, not mine.",context="ERROR")
+                       return ""

-           mask_width = self.Unprompted.shortcode_user_vars["width"]
-           mask_height = self.Unprompted.shortcode_user_vars["height"]
+           mask_width = self.init_image.size[0]
+           mask_height = self.init_image.size[1]

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if device.type == "cuda": torch.cuda.empty_cache()

        if "stamp" in kwargs:
            stamps = (self.Unprompted.parse_advanced(kwargs["stamp"],context)).split(self.Unprompted.Config.syntax.delimiter)

            stamp_x = int(float(self.Unprompted.parse_advanced(kwargs["stamp_x"],context))) if "stamp_x" in kwargs else 0
            stamp_y = int(float(self.Unprompted.parse_advanced(kwargs["stamp_y"],context))) if "stamp_y" in kwargs else 0
            stamp_x_orig = stamp_x
            stamp_y_orig = stamp_y
            stamp_method = self.Unprompted.parse_advanced(kwargs["stamp_method"],context) if "stamp_method" in kwargs else "stretch"

            for stamp in stamps:
                # Checks for file in images/stamps, otherwise assumes absolute path
                stamp_path = f"{self.Unprompted.base_dir}/images/stamps/{stamp}.png"
                if not os.path.exists(stamp_path): stamp_path = stamp
                if not os.path.exists(stamp_path):
                    self.Unprompted.log(f"Stamp not found: {stamp_path}",context="ERROR")
                    continue

                stamp_img = Image.open(stamp_path).convert("RGBA")

                if stamp_method == "stretch":
                    stamp_img = stamp_img.resize((self.init_image.size[0],self.init_image.size[1]))
                elif stamp_method == "center":
                    stamp_x = stamp_x_orig + int((mask_width - stamp_img.size[0]) / 2)
                    stamp_y = stamp_y_orig + int((mask_height - stamp_img.size[1]) / 2)

                stamp_blur = int(float(self.Unprompted.parse_advanced(kwargs["stamp_blur"],context))) if "stamp_blur" in kwargs else 0
                if stamp_blur:
                    from PIL import ImageFilter
                    blur = ImageFilter.GaussianBlur(stamp_blur)
                    stamp_img = stamp_img.filter(blur)

                self.init_image.paste(stamp_img,(stamp_x,stamp_y),stamp_img)
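The "center" `stamp_method` offsets the user-supplied `stamp_x`/`stamp_y` by whatever is needed to center the stamp on the mask canvas. The arithmetic in isolation (all dimensions invented, and floor division standing in for the `int(x / 2)` truncation above):

```python
def center_offset(canvas_w, canvas_h, stamp_w, stamp_h, stamp_x=0, stamp_y=0):
    # user offsets are applied relative to the centered position
    x = stamp_x + (canvas_w - stamp_w) // 2
    y = stamp_y + (canvas_h - stamp_h) // 2
    return x, y

print(center_offset(512, 512, 128, 64))  # → (192, 224)
```

So on a 512x512 canvas, a 128x64 stamp with no extra offset lands at (192, 224), which places its center at (256, 256).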
|
||||
|
||||
brush_mask_mode = self.Unprompted.parse_advanced(kwargs["mode"],context) if "mode" in kwargs else "add"
|
||||
self.show = True if "show" in pargs else False
|
||||
|
||||
|
||||
box_thresh = float(self.Unprompted.parse_advanced(kwargs["box_threshold"],context)) if "box_threshold" in kwargs else 0.3
|
||||
text_thresh = float(self.Unprompted.parse_advanced(kwargs["text_threshold"],context)) if "text_threshold" in kwargs else 0.25
|
||||
|
||||
self.legacy_weights = True if "legacy_weights" in pargs else False
|
||||
smoothing = int(self.Unprompted.parse_advanced(kwargs["smoothing"],context)) if "smoothing" in kwargs else 20
|
||||
smoothing_kernel = None
|
||||
|
|
@@ -103,15 +146,16 @@ class Shortcode():
for i, mask in enumerate(masks):

    filename = f"mask_{mode}_{i}.png"

    if method == "clipseg":
        plt.imsave(filename,torch.sigmoid(mask[0]))
        img = cv2.imread(filename)
        # TODO: Figure out how to convert the plot above to numpy instead of re-loading image
    else:
        plt.imsave(filename,mask)
        import random
        img = cv2.imread(filename)
    img = cv2.resize(img,(mask_width,mask_height))

    if padding_dilation_kernel is not None:
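For the clipseg branch, the raw mask is a tensor of logits, so `torch.sigmoid` squashes each value into (0, 1) before the map is saved. The elementwise operation, sketched with the stdlib only:

```python
import math

def sigmoid(x):
    """Map a logit to (0, 1), as torch.sigmoid does elementwise to the clipseg mask."""
    return 1.0 / (1.0 + math.exp(-x))

# Large positive logits approach 1 (confident mask), large negative approach 0.
print(sigmoid(0), sigmoid(10), sigmoid(-10))
```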
@@ -119,7 +163,11 @@ class Shortcode():
    else: img = cv2.erode(img,padding_dilation_kernel,iterations=1)
    if smoothing_kernel is not None: img = cv2.filter2D(img,-1,smoothing_kernel)

    #if method == "clip_surgery":
    #    gray_image = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_BGR2LUV), cv2.COLOR_BGR2GRAY)
    #else: gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    Image.fromarray(gray_image).save("mask_gray_test.png")
    (thresh, bw_image) = cv2.threshold(gray_image, mask_precision, 255, cv2.THRESH_BINARY)

    if (mode == "discard"): bw_image = numpy.invert(bw_image)
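`cv2.threshold` with `THRESH_BINARY` maps every pixel strictly above `mask_precision` to 255 and the rest to 0; "discard" mode then inverts the result. A pure-Python sketch of that behavior on a flat pixel list (illustrative only, not the cv2 API):

```python
def binarize(gray, precision, discard=False):
    """Mimic cv2.threshold(..., THRESH_BINARY) on a flat list of 0-255 values."""
    bw = [255 if px > precision else 0 for px in gray]
    if discard:
        # "discard" mode inverts the mask, like numpy.invert on uint8 (255 - px)
        bw = [255 - px for px in bw]
    return bw

# Pixels at exactly the threshold are NOT selected (strict greater-than).
print(binarize([0, 100, 200], 100))
```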
@@ -132,9 +180,140 @@ class Shortcode():
    return(final_img)

def get_mask():
    preds = []
    negative_preds = []
    image_pil = flatten(self.init_image, opts.img2img_background_color)

-   if method == "sam":
+   if method == "clip_surgery":
        from lib_unprompted import clip_surgery as clip
        import cv2
        import numpy as np
        from PIL import Image
        from matplotlib import pyplot as plt
        from torchvision.transforms import Compose, Resize, ToTensor, Normalize
        from torchvision.transforms import InterpolationMode
        BICUBIC = InterpolationMode.BICUBIC
        from segment_anything import sam_model_registry, SamPredictor

        # Default ImageNet redundant features
        redundants = ['a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.', 'a photo of the hard to see {}.', 'a low resolution photo of the {}.', 'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.', 'a photo of a hard to see {}.', 'a bright photo of a {}.', 'a photo of a clean {}.', 'a photo of a dirty {}.', 'a dark photo of the {}.', 'a drawing of a {}.', 'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a painting of the {}.', 'a painting of a {}.', 'a pixelated photo of the {}.', 'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a rendering of the {}.', 'a {} in a video game.', 'a photo of one {}.', 'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.', 'a doodle of the {}.', 'a origami {}.', 'a low resolution photo of a {}.', 'the toy {}.', 'a rendition of the {}.', 'a photo of the clean {}.', 'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.', 'a photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.', 'art of a {}.', 'a sketch of the {}.', 'a embroidered {}.', 'a pixelated photo of a {}.', 'itap of the {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.', 'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.', 'art of the {}.', 'a drawing of the {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'the plushie {}.', 'a dark photo of a {}.', 'itap of a {}.', 'graffiti of the {}.', 'a toy {}.', 'itap of my {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']

        if "redundant_features" in kwargs: redundants.extend(kwargs["redundant_features"].split(self.Unprompted.Config.syntax.delimiter))
        self.bypass_sam = True if "bypass_sam" in pargs else False

        ### Init CLIP and data
        if self.cached_model == -1 or self.cached_model_method != method:
            model, preprocess = clip.load("CS-ViT-B/16", device=device)
            model.eval()
            # Cache for future runs
            self.cached_model = model
            self.cached_transform = preprocess
        else:
            self.Unprompted.log("Using cached model(s) for CLIP Surgery method")
            model = self.cached_model
            preprocess = self.cached_transform

        image = preprocess(image_pil).unsqueeze(0).to(device)
        cv2_img = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)

        ### CLIP Surgery for a single text, without fixed label sets
        with torch.no_grad():
            # CLIP architecture surgery acts on the image encoder
            image_features = model.encode_image(image)
            image_features = image_features / image_features.norm(dim=1, keepdim=True)

            # Prompt ensemble for text features with normalization
            text_features = clip.encode_text_with_prompt_ensemble(model, prompts, device)

            if (negative_prompts):
                negative_text_features = clip.encode_text_with_prompt_ensemble(model, negative_prompts, device)

            # Extract redundant features from an empty string
            redundant_features = clip.encode_text_with_prompt_ensemble(model, [""], device, redundants)

        # No SAM
        if self.bypass_sam:
            def reg_inference(text_features):
                preds = []
                # Apply feature surgery for single text
                similarity = clip.clip_feature_surgery(image_features, text_features, redundant_features)
                similarity_map = clip.get_similarity_map(similarity[:, 1:, :], cv2_img.shape[:2])

                # Draw similarity map
                for b in range(similarity_map.shape[0]):
                    for n in range(similarity_map.shape[-1]):
                        vis = (similarity_map[b, :, :, n].cpu().numpy() * 255).astype('uint8')
                        preds.append(vis)
                return(preds)
            preds = reg_inference(text_features)
            if (negative_prompts): negative_preds = reg_inference(negative_text_features)
        else:
            point_thresh = float(self.Unprompted.parse_advanced(kwargs["point_threshold"],context)) if "point_threshold" in kwargs else 0.98
            multimask_output = True if "multimask_output" in pargs else False

            # Init SAM
            if self.cached_predictor == -1 or self.cached_model_method != method:
                sam_model_dir = f"{self.Unprompted.base_dir}/models/segment_anything"
                os.makedirs(sam_model_dir, exist_ok=True)
                sam_filename = "sam_vit_h_4b8939.pth"
                sam_file = f"{sam_model_dir}/{sam_filename}"
                # Download model weights if we don't have them yet
                if not os.path.exists(sam_file):
                    print("Downloading SAM model weights...")
                    self.Unprompted.download_file(sam_file,f"https://dl.fbaipublicfiles.com/segment_anything/{sam_filename}")

                model_type = "vit_h"
                sam = sam_model_registry[model_type](checkpoint=sam_file)
                sam.to(device=device)
                predictor = SamPredictor(sam)

                self.cached_predictor = predictor
            else:
                predictor = self.cached_predictor

            predictor.set_image(np.array(image_pil))
            self.cached_model_method = method

            def sam_inference(text_features):
                preds = []

                # Combine features after removing redundant features and min-max norm
                sm = clip.clip_feature_surgery(image_features, text_features, redundant_features)[0, 1:, :]
                sm_norm = (sm - sm.min(0, keepdim=True)[0]) / (sm.max(0, keepdim=True)[0] - sm.min(0, keepdim=True)[0])
                sm_mean = sm_norm.mean(-1, keepdim=True)
                # Get positive points from individual maps, and negative points from the mean map
                p, l = clip.similarity_map_to_points(sm_mean, cv2_img.shape[:2], t=point_thresh)
                num = len(p) // 2
                points = p[num:]  # negatives in the second half
                labels = [l[num:]]
                for i in range(sm.shape[-1]):
                    p, l = clip.similarity_map_to_points(sm[:, i], cv2_img.shape[:2], t=point_thresh)
                    num = len(p) // 2
                    points = points + p[:num]  # positives in the first half
                    labels.append(l[:num])
                labels = np.concatenate(labels, 0)

                # Inference SAM with points from CLIP Surgery
                masks, scores, logits = predictor.predict(point_labels=labels, point_coords=np.array(points), multimask_output=multimask_output)
                mask = masks[np.argmax(scores)]
                mask = mask.astype('uint8')

                vis = cv2_img.copy()
                vis[mask > 0] = np.array([255, 255, 255], dtype=np.uint8)
                vis[mask == 0] = np.array([0, 0, 0], dtype=np.uint8)
                preds.append(vis)
                if self.show:
                    for idx,mask in enumerate(masks):
                        plt.imsave(f"mask{idx}.png",mask)

                return(preds)

            preds = sam_inference(text_features)
            if negative_prompts: negative_preds = sam_inference(negative_text_features)

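`similarity_map_to_points` returns a point list whose first half are high-similarity (positive) points and whose second half are low-similarity (negative) points; the loop above keeps negatives from the mean map and positives from each per-prompt map before handing them to SAM. A stdlib sketch of that accumulation (hypothetical data, `gather_sam_points` is an illustrative helper):

```python
def gather_sam_points(mean_pts, mean_lbls, per_prompt):
    """Negatives come from the mean map's second half; positives from each
    per-prompt map's first half, mirroring the sam_inference loop."""
    num = len(mean_pts) // 2
    points = list(mean_pts[num:])      # negatives in the second half
    label_groups = [list(mean_lbls[num:])]
    for pts, lbls in per_prompt:
        n = len(pts) // 2
        points += list(pts[:n])        # positives in the first half
        label_groups.append(list(lbls[:n]))
    labels = [l for group in label_groups for l in group]  # flatten, like np.concatenate
    return points, labels
```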
    elif method == "grounded_sam":
        box_thresh = float(self.Unprompted.parse_advanced(kwargs["box_threshold"],context)) if "box_threshold" in kwargs else 0.3
        text_thresh = float(self.Unprompted.parse_advanced(kwargs["text_threshold"],context)) if "text_threshold" in kwargs else 0.25
        # Grounding DINO
        import groundingdino.datasets.transforms as T
        from groundingdino.models import build_model
@@ -205,7 +384,7 @@ class Shortcode():
        model_config_path = f"{self.Unprompted.base_dir}/lib_unprompted/groundingdino/config/GroundingDINO_SwinT_OGC.py"

        # load model
-       if self.cached_model == -1:
+       if self.cached_model == -1 or self.cached_model_method != method:
            args = SLConfig.fromfile(model_config_path)
            args.device = device
            model = build_model(args)
@@ -223,6 +402,7 @@ class Shortcode():
            )
            self.cached_model = model
            self.cached_transform = transform
+           self.cached_model_method = method

        else:
            self.Unprompted.log("Using cached GroundingDINO model.")
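Several branches above guard on `self.cached_model == -1 or self.cached_model_method != method`, so switching mask methods mid-session reloads the correct weights instead of reusing a model cached by a different method. The caching pattern, sketched in plain Python (the loader string is a stand-in for a real model load):

```python
class ModelCache:
    """Sketch of the method-keyed cache used by the txt2mask shortcode."""
    def __init__(self):
        self.cached_model = -1       # -1 means "nothing cached", as in the shortcode
        self.cached_model_method = ""
        self.loads = 0

    def get(self, method):
        # Reload when the cache is empty OR it was filled by a different method
        if self.cached_model == -1 or self.cached_model_method != method:
            self.loads += 1
            self.cached_model = f"weights-for-{method}"  # stand-in for a real load
            self.cached_model_method = method
        return self.cached_model
```

Without the `cached_model_method` check, a run that switched from `clipseg` to `clip_surgery` would silently reuse the clipseg weights.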
@@ -248,10 +428,13 @@ class Shortcode():

        preds = []
        value = 0
-       mask_img = torch.zeros(masks.shape[-2:])
+       mask_img = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
        for idx, mask in enumerate(masks):
-           mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
-           preds.append(mask_img.numpy())
+           # mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
+           mask_img[mask.cpu().numpy()[0] >= 1] = np.array([255, 255, 255], dtype=np.uint8)
+           mask_img[mask.cpu().numpy()[0] < 1] = np.array([0, 0, 0], dtype=np.uint8)
+           # TODO: Figure out if we can take advantage of individual mask layers rather than stacking as composite
+           preds.append(mask_img)

        return(preds)

@@ -293,7 +476,7 @@ class Shortcode():

        # load model
-       if self.cached_model == -1:
+       if self.cached_model == -1 or self.cached_model_method != method:
            self.Unprompted.log("Loading clipseg model...")
            model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=not self.legacy_weights)

@@ -310,6 +493,7 @@ class Shortcode():
            # Cache for future runs
            self.cached_model = model
            self.cached_transform = transform
+           self.cached_model_method = method
        else:
            self.Unprompted.log("Using cached clipseg model.")
            model = self.cached_model
@@ -319,11 +503,17 @@ class Shortcode():

        # predict
        with torch.no_grad():
-           preds = model(img.repeat(prompt_parts,1,1,1).to(device=device), prompts)[0].cpu()
+           if "image_prompt" in kwargs:
+               from PIL import Image
+               img_mask = flatten(Image.open(r"A:/inbox/test_mask.png"), opts.img2img_background_color)
+               img_mask = transform(img_mask).unsqueeze(0)
+               preds = model(img.to(device=device), img_mask.to(device=device))[0].cpu()
+           else:
+               preds = model(img.repeat(prompt_parts,1,1,1).to(device=device), prompts)[0].cpu()

            if (negative_prompts): negative_preds = model(img.repeat(negative_prompt_parts,1,1,1).to(device=device), negative_prompts)[0].cpu()

    # All of the below logic applies to both clipseg and sam

    if "image_mask" not in self.Unprompted.shortcode_user_vars: self.Unprompted.shortcode_user_vars["image_mask"] = None

    if (brush_mask_mode == "add" and self.Unprompted.shortcode_user_vars["image_mask"] is not None):
@@ -411,6 +601,8 @@ class Shortcode():
    if "unload_model" in pargs:
-       self.model = -1
+       self.cached_model = -1
+       self.cached_model_method = ""
+       self.cached_predictor = -1

    return final_img

@@ -450,7 +642,7 @@ class Shortcode():

def ui(self,gr):
    gr.Radio(label="Mask blend mode 🡢 mode",choices=["add","subtract","discard"],value="add",interactive=True)
-   gr.Radio(label="Masking tech method 🡢 method",choices=["sam","clipseg"],value="sam",interactive=True)
+   gr.Radio(label="Masking tech method 🡢 method",choices=["clipseg","clip_surgery","grounded_sam"],value="clipseg",interactive=True)
    gr.Checkbox(label="Show mask in output 🡢 show")
    gr.Checkbox(label="Use clipseg legacy weights 🡢 legacy_weights")
    gr.Number(label="Precision of selected area 🡢 precision",value=100,interactive=True)

@@ -55,7 +55,7 @@ class Shortcode():

    debug = True if "debug" in pargs else False
    show_original = True if "show_original" in pargs else False
-   color_correct_method = self.Unprompted.parse_alt_tags(kwargs["color_correct_method"],context) if "color_correct_method" in kwargs else "mkl"
+   color_correct_method = self.Unprompted.parse_alt_tags(kwargs["color_correct_method"],context) if "color_correct_method" in kwargs else "none"
    color_correct_timing = self.Unprompted.parse_alt_tags(kwargs["color_correct_timing"],context) if "color_correct_timing" in kwargs else "pre"
    color_correct_strength = int(float(self.Unprompted.parse_advanced(kwargs["color_correct_strength"],context))) if "color_correct_strength" in kwargs else 1
    manual_mask_mode = self.Unprompted.parse_alt_tags(kwargs["mode"],context) if "mode" in kwargs else "subtract"

@@ -108,7 +108,7 @@ class Shortcode():
    cfg_min = float(self.Unprompted.parse_advanced(kwargs["cfg_scale_min"],context)) if "cfg_scale_min" in kwargs else 7.0
    target_size_max = float(self.Unprompted.parse_advanced(kwargs["mask_size_max"],context)) if "mask_size_max" in kwargs else 0.5
    target_size_max_orig = target_size_max
-   cfg_max = self.Unprompted.p_copy.cfg_scale - cfg_min
+   cfg_max = max(cfg_min,self.Unprompted.p_copy.cfg_scale - cfg_min)

    padding_original = int(float(self.Unprompted.parse_advanced(kwargs["contour_padding"],context))) if "contour_padding" in kwargs else 0
    min_area = int(float(self.Unprompted.parse_advanced(kwargs["min_area"],context))) if "min_area" in kwargs else 50

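The change to `cfg_max = max(cfg_min, ...)` keeps the dynamic CFG range non-negative even when the user's base CFG scale sits below `cfg_scale_min`; the final value is later interpolated as `cfg_min + sig * cfg_max`. A quick sketch of that behavior (assumes `sig` is a 0-1 blend factor, as in the shortcode):

```python
def dynamic_cfg(base_cfg, cfg_min, sig):
    """Interpolate CFG between cfg_min and cfg_min + cfg_max, with a clamped range."""
    cfg_max = max(cfg_min, base_cfg - cfg_min)  # old code could produce a negative range here
    return cfg_min + sig * cfg_max
```

With `base_cfg=4.0` and `cfg_min=7.0`, the old formula's range would have been negative, letting the interpolated CFG dip below `cfg_min`.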
@@ -141,7 +141,7 @@ class Shortcode():
    target_size_max = target_size_max_orig * target_multiplier
    sd_unit = 64

-   denoise_unit = (denoising_max / 2) * 0.125
+   denoise_unit = (denoising_max / 4) * 0.125
    cfg_min_unit = (cfg_min / 2) * 0.125
    cfg_max_unit = (cfg_max / 2) * 0.125
    step_unit = math.ceil(self.Unprompted.p_copy.steps * 0.125)

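These `*_unit` values are eighth-steps used to progressively raise denoising, CFG, and step count as the upscale size grows; dividing `denoising_max` by 4 instead of 2 halves how aggressively denoising ramps. A stdlib sketch of the ramp (hypothetical `steps=8` retries, illustrative only):

```python
def ramp(value, divisor, steps=8):
    """Raise `value` by (value / divisor) * 0.125 per step, as the units above do."""
    unit = (value / divisor) * 0.125
    return value + unit * steps
```

After eight steps, the old divisor of 2 grows a 0.4 denoising ceiling to 0.6, while the new divisor of 4 grows it only to 0.5.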
@@ -154,6 +154,7 @@ class Shortcode():
        cfg_min += cfg_min_unit
        cfg_max += cfg_max_unit
        sharpen_amount += 0.125
+       denoising_max += denoise_unit
        self.Unprompted.p_copy.steps += step_unit

        upscale_width = min(hires_size_max,upscale_width)

@@ -170,6 +171,7 @@ class Shortcode():

    if "include_original" in pargs:
        append_originals.append(image_pil.copy())
+   if "mask_method" in kwargs: set_kwargs["method"] = kwargs["mask_method"]

    set_kwargs["txt2mask_init_image"] = image_pil
    mask_image = self.Unprompted.shortcode_objects["txt2mask"].run_block(set_pargs,set_kwargs,None,target_mask)

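The new line forwards `[zoom_enhance]`'s `mask_method` argument to the inner `[txt2mask]` call as its `method` kwarg. The pass-through pattern in plain Python (`build_txt2mask_kwargs` is a hypothetical helper for illustration):

```python
def build_txt2mask_kwargs(outer_kwargs, image):
    """Forward zoom_enhance's mask_method to txt2mask as `method`, if present."""
    set_kwargs = {"txt2mask_init_image": image}
    if "mask_method" in outer_kwargs:
        set_kwargs["method"] = outer_kwargs["mask_method"]
    return set_kwargs
```

Omitting the key entirely (rather than forwarding `None`) lets `[txt2mask]` fall back to its own default method.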
@@ -257,7 +259,7 @@ class Shortcode():
        self.Unprompted.log(f"Denoising strength is {self.Unprompted.p_copy.denoising_strength}")
        if "cfg_scale" not in kwargs:
            self.Unprompted.p_copy.cfg_scale = cfg_min + sig * cfg_max
-           self.Unprompted.log(f"CFG Scale is {self.Unprompted.shortcode_user_vars['cfg_scale']} (min {cfg_min}, max {cfg_min+cfg_max})")
+           self.Unprompted.log(f"CFG Scale is {self.Unprompted.p_copy.cfg_scale} (min {cfg_min}, max {cfg_min+cfg_max})")
    else:
        self.Unprompted.log("Humongous target detected. Skipping zoom_enhance...")
        continue

@@ -312,16 +314,18 @@ class Shortcode():

    # run img2img now to improve face

    if is_img2img:
        fixed_image = process_images_inner_(self.Unprompted.p_copy)
        fixed_image = fixed_image.images[0]
    else:
-       #workaround for txt2img
+       # workaround for txt2img, not sure if compatible with controlnet
        for att in dir(self.Unprompted.p_copy):
            if not att.startswith("__") and att != "sd_model":
                self.Unprompted.shortcode_user_vars[att] = getattr(self.Unprompted.p_copy,att)
        fixed_image = self.Unprompted.shortcode_objects["img2img"].run_atomic(set_pargs,None,None)
    if debug: fixed_image.save("zoom_enhance_4after.png")

    if color_correct_method != "none" and starting_image:
        try:

@@ -352,7 +356,7 @@ class Shortcode():
            current_mask = current_mask.resize((width,height))
            if debug: current_mask.save("zoom_enhance_5d_current_main_mask.png")
            image_pil.paste(corrected_main_img,(0,0),current_mask)
-           image_pil.save("zoom_enhance_5e_corrected_main_image.png")
+           if debug: image_pil.save("zoom_enhance_5e_corrected_main_image.png")
        except Exception as e:
            self.Unprompted.log(f"{e}",context="ERROR")

@@ -369,13 +373,17 @@ class Shortcode():
    # Slap our new image back onto the original
    image_pil.paste(fixed_image, (x1 - padding, y1 - padding),sub_mask)

    # self.Unprompted.shortcode_user_vars["init_images"].append(image_pil)
-   if show_original: append_originals.append(image_pil.copy())
-   else: self.Unprompted.after_processed.images[image_idx] = image_pil

    # test outside after block, WIP pls don't use yet
    self.Unprompted.log(f"Adding zoom_enhance result for image_idx {image_idx}")
    if context != "after":
        self.Unprompted.log("Attempting to use zoom_enhance outside of an after block... good luck")
        self.Unprompted.shortcode_user_vars["init_images"] = image_pil
    # main return
    else:
        try:
            if show_original: append_originals.append(image_pil.copy())
            else: self.Unprompted.after_processed.images[image_idx] = image_pil
        except Exception as e:
            self.Unprompted.log(f"Could not append zoom_enhance result: {e}",context="ERROR")

    # Remove temp image key in case [img2img] is used later
    if "img2img_init_image" in self.Unprompted.shortcode_user_vars: del self.Unprompted.shortcode_user_vars["img2img_init_image"]

@@ -1,4 +1,4 @@
-[template name="Bodysnatcher v1.0.0"]
+[template name="Bodysnatcher v1.1.0"]

## ⚠️ Important info, please read carefully:

@@ -40,32 +40,35 @@ Always bodysnatch responsibly.

[set prefix _new _label="Prefix" _info="For example, the visual medium"]photo of[/set]
[set subject _new _label="New subject"]mona lisa[/set]
-[set simple_description _new _label="Simple Description" _info="These terms will apply to both the full image and the cropped face, less is more"][/set]
+[set simple_description _new _label="Simple Description" _info="These terms will apply to both the full image and the cropped face, 1-3 words are usually plenty"][/set]
[set class _new _label="Class" _info="The search term that determines the inpainting mask"]woman[/set]

+[set background_mode _new _label="Background Mode" _ui="checkbox" _info="Inverts the class mask and disables the zoom_enhance step (note: you'll probably want to increase the mask precision)"]0[/set]
+
[set keep_hands _new _label="Keep original hands" _ui="checkbox" _info="You don't really want Stable Diffusion to remake those hands, do you?"]1[/set]
[set keep_feet _new _label="Keep original feet" _ui="checkbox"]1[/set]

[set use_optimized_inference_settings _new _label="Use optimized inference settings" _ui="checkbox" _info="Locks CFG scale, denoising strength, etc. to recommended values"]1[/set]
-[set use_controlnet_preset _new _info="Loads multiple ControlNet units, make sure you have 'Allow other scripts to control this extension' enabled" _label="ControlNet preset" _ui="dropdown" _choices="none|photo_general_v1|dev"]none[/set]
+[set use_controlnet_preset _new _info="Loads multiple ControlNet units, please make sure you have 'Allow other scripts to control this extension' enabled (note: the 'dev' preset is for internal testing)" _label="ControlNet preset" _ui="dropdown" _choices="none|photo_general_v1|dev"]none[/set]

[wizard_ui_accordion _label="⚙️ Advanced Options"]
-{set fix_bodypart _new _label="Fix a body part"}face{/set}
+{set fix_bodypart _new _label="Fix a body part" _info="Note: currently not compatible with Background Mode"}face{/set}
{set color_correct_method _new _label="Color correct method" _ui="dropdown" _choices="none|hm|mvgd|mkl|hm-mvgd-hm|hm-mkl-hm"}hm-mkl-hm{/set}
{set color_correct_timing _new _label="Color correct timing" _info="Post may produce more accurate colors, but it tends to look a bit posterized" _ui="dropdown" _choices="pre|post"}pre{/set}
{set color_correct_strength _new _label="Color correct strength" _ui="slider" _minimum=1 _maximum=5}1{/set}
-{set mask_method _new _label="Masking method (sam requires manual setup)" _ui="radio" _choices="clipseg|sam"}clipseg{/set}
+{set mask_method _new _label="Masking method" _ui="radio" _choices="clipseg|clip_surgery|grounded_sam"}clipseg{/set}
{set manual_mask_mode _new _label="Manual masking mode" _ui="radio" _choices="add|subtract|discard"}subtract{/set}
{set mask_precision _new _label="Mask precision"}75{/set}
+{set stamp _new _label="Stamp" _info="Paste a temporary image on the init image for the purpose of masking (check unprompted/images/stamps for default stamps)"}{/set}
{set zoom_enhance_denoising_max _new}0.30{/set}
-{set zoom_enhance_base_cfg _new _ui="slider" _minimum="1" _maximum="30"}10{/set}
+{set zoom_enhance_base_cfg _new _ui="slider" _minimum="1" _maximum="30"}7{/set}
{set show_original _new _label="Show unenhanced image in output window" _ui="checkbox"}0{/set}
{set debug _new _label="Save debug images" _ui="checkbox"}0{/set}
[/wizard_ui_accordion]

[sets neg_mask=""]
-[if keep_hands=1]{set neg_mask}fingers{/set}[/if]
-[if keep_feet=1]{set neg_mask _append}|feet{/set}[/if]
+[if "(keep_hands==1 and background_mode==0) or (keep_hands==0 and background_mode==1)"]{set neg_mask}fingers{/set}[/if]
+[if "(keep_feet==1 and background_mode==0) or (keep_feet==0 and background_mode==1)"]{set neg_mask _append}|feet{/set}[/if]

[if use_optimized_inference_settings=1]
{sets cfg_scale=7.5 sampler_name="Euler a" steps=25 denoising_strength=0.75 mask_blur=10}

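The new `[if]` conditions keep the hand/feet terms in the negative mask when `keep_hands`/`keep_feet` is set in normal mode, and flip that behavior in Background Mode (since the class mask itself is inverted there). The condition `(keep==1 and bg==0) or (keep==0 and bg==1)` is simply exclusive-or; sketched in Python:

```python
def include_neg_term(keep, background_mode):
    """True when the term belongs in neg_mask, per the [if] conditions above."""
    return (keep == 1 and background_mode == 0) or (keep == 0 and background_mode == 1)

# Equivalent to keep XOR background_mode across all four flag combinations.
table = {(k, b): include_neg_term(k, b) for k in (0, 1) for b in (0, 1)}
print(table)
```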
@@ -75,8 +78,8 @@ Always bodysnatch responsibly.
{{sets controlnet_0_enabled=1 controlnet_0_module=softedge_hed controlnet_0_model=controlnet11Models_softedge controlnet_0_weight=0.25 controlnet_1_enabled=1 controlnet_1_module=mediapipe_face controlnet_1_model=control_mediapipe_face_sd15_v2 controlnet_1_weight=1.0 controlnet_2_enabled=1 controlnet_2_module=openpose_full controlnet_2_model=controlnet11Models_openpose}}
{/case}
{case "dev"}
-{{sets controlnet_0_enabled=1 controlnet_0_module=softedge_hed controlnet_0_model=controlnet11Models_softedge controlnet_0_weight=0.5 controlnet_1_enabled=1 controlnet_1_module=mediapipe_face controlnet_1_model=control_mediapipe_face_sd15_v2 controlnet_1_weight=1.0 controlnet_2_enabled=1 controlnet_2_module=t2ia_color_grid controlnet_2_model=coadapter-color-sd15v1 controlnet_2_weight=1.0 controlnet_3_enabled=1 controlnet_3_module=openpose_full controlnet_3_model=controlnet11Models_openpose controlnet_3_weight=1.0}}
+{{sets controlnet_0_enabled=1 controlnet_0_module=softedge_hed controlnet_0_model=controlnet11Models_softedge controlnet_0_weight=0.25 controlnet_1_enabled=1 controlnet_1_module=mediapipe_face controlnet_1_model=control_mediapipe_face_sd15_v2 controlnet_1_weight=1.0 controlnet_2_enabled=1 controlnet_2_module=t2ia_color_grid controlnet_2_model=coadapter-color-sd15v1 controlnet_2_weight=1.0 controlnet_3_enabled=1 controlnet_3_module=openpose_full controlnet_3_model=controlnet11Models_openpose controlnet_3_weight=1.0}}
{/case}
[/switch]

-[img2img_autosize][txt2mask precision="{get mask_precision}" method="{get mask_method}" mode="{get manual_mask_mode}" negative_mask="{get neg_mask}" padding=10 mask_blur=20][get class][/txt2mask][after]{zoom_enhance color_correct_method="[get color_correct_method]" color_correct_timing="[get color_correct_timing]" color_correct_strength="[get color_correct_strength]" [if show_original=1]show_original[/if] sharpen_amount=0.0 mode="subtract" [if debug=1]debug[/if] mask="[get class] [get fix_bodypart]" replacement="[get prefix] [get subject] [get fix_bodypart] [get simple_description _before=' ']" cfg_scale="[get zoom_enhance_base_cfg]" denoising_max="[get zoom_enhance_denoising_max]"}[/after][get prefix] [get subject][get simple_description _before=" "]
+[img2img_autosize][if batch_index=0]{txt2mask precision="{{get mask_precision}}" method="{{get mask_method}}" mode="{{get manual_mask_mode}}" negative_mask="{{get neg_mask}}" padding=10 mask_blur=20}{get class}{/txt2mask}{if background_mode=1}{{invert_mask}}{/if}[/if][if "background_mode==0 and batch_index==0"]{after}{{zoom_enhance mask_method="[get mask_method]" color_correct_method="[get color_correct_method]" color_correct_timing="[get color_correct_timing]" color_correct_strength="[get color_correct_strength]" [if show_original=1]show_original[/if] sharpen_amount=0.0 mode="subtract" [if debug=1]debug[/if] mask="[get class] [get fix_bodypart]" replacement="[get prefix] [get subject] [get fix_bodypart] [get simple_description _before=' ']" cfg_scale_min="[get zoom_enhance_base_cfg]" cfg_scale_max="15" denoising_max="[get zoom_enhance_denoising_max]"}}{/after}[/if][get prefix] [get subject][get simple_description _before=" "]