pull/146/head
ThereforeGames 2023-04-21 10:36:08 -04:00
parent 8f6edd9d28
commit a264855b49
29 changed files with 1921 additions and 94 deletions


@ -3,6 +3,37 @@ All notable changes to this project will be documented in this file.
For more details on new features, please check the [Manual](./MANUAL.md).
<details open><summary>8.3.0 - 21 April 2023</summary>
### Added
- New shortcode `[color_correct]`: provides the same automatic color grading features as Bodysnatcher, but in the form of a standalone block
- `[color_correct]`: supports the `source` argument, which is a string that processes the initial image with `[txt2mask]` and uses the resulting masked image as a source for color correction, as opposed to the entire image
- `[txt2mask]`: implemented [CLIP Surgery](https://github.com/xmed-lab/CLIP_Surgery) as a new method type ("clip_surgery") which optionally supports Segment Anything (dev comment: this is better than `clipseg` at certain tasks but worse at others - `clipseg` is still default for the time being)
- `[txt2mask]`: new argument `stamp` that pastes a temporary PNG onto the init image before running mask processing; useful, for example, for redacting a portion of the image
- `[txt2mask]`: supports `stamp_method` to choose sizing and positioning logic
- `[txt2mask]`: supports `stamp_x` and `stamp_y` for precise positioning of the stamp
- `[txt2mask]`: supports `stamp_blur` radius to engage optional gaussian filter
- `[txt2mask]`: 10 basic stamps are included by default
- `[zoom_enhance]`: now supports `mask_method`
- `[template]`: any kwargs in the Wizard template block will be passed to the constructed `[file]` block
- `[file]`: experimental new argument `_bypass_if` that skips file processing if the given value evaluates to true (intended to be used with Wizard templates)
- `[get sd_model]` should now work as expected
- Bodysnatcher: new option `background_mode` that inverts the mask and disables the zoom_enhance step
- Bodysnatcher: new setting `stamp`
### Changed
- `[zoom_enhance]`: the `color_correct_method` default value is now `none`
- `[zoom_enhance]`: fix for adaptive CFG scaling
- `[zoom_enhance]`: minor tweaks to the adaptive scaling algorithm
- `[zoom_enhance]`: speculative fix for an issue with batch processing, which may also resolve an infinite loop that could occur with Bodysnatcher
- `[txt2mask]`: the "sam" `method` has been renamed to "grounded_sam"
- `[txt2mask]`: fixed a crash related to switching back and forth between `method` types
- Moved legacy shortcodes into their own `legacy` folder
- Fixed a crash related to empty shortcode arguments
- Updated the manual
</details>
<details><summary>8.2.0 - 18 April 2023</summary>
### Added


@ -289,7 +289,7 @@ Pressing **"Generate Shortcode"** will assemble a ready-to-use block of code tha
Alternatively, you can enable `Auto-include this in prompt` which will add the shortcode to your prompts behind the scenes. This essentially lets you use Unprompted shortcodes as if they were standalone scripts. You can enable/disable this setting on a per-shortcode basis.
The Wizard includes two distinct modes: Shortcodes and Templates.
</details>
@ -323,17 +323,17 @@ There are a few reserved argument names that will modify the Wizard's behavior:
</details>
<details><summary>Templates Mode</summary>
This mode presents you with a list of txt files inside your `Unprompted/templates` directory that begin with a `[template]` block.
By including this block in your file, Unprompted will parse the file for its `[set x _new]` statements and adapt those into a custom Wizard UI.
The `_new` argument means *"only set this variable if it doesn't already exist,"* which are generally the variables we want to show in a UI.
The `[template]` block supports the optional `name` argument, which is a friendly name for your template shown in the Templates dropdown menu.
The content of `[template]` is a description of your script to be rendered with [Markdown](https://www.markdownguide.org/basic-syntax/), which means you can include rich content like pictures or links. It will show up at the top of your UI.
The `[set]` block supports `_ui` which determines the type of UI element to render your variable as. Defaults to `textbox`. Here are the possible types:
@ -343,6 +343,10 @@ The `[set]` block supports `_ui` which determines the type of UI element to rend
- `dropdown`: A dropdown menu that is populated by the `_choices` argument, constructed as a delimited list.
- `slider`: Limits selection to a range of numbers. You must also specify `_minimum`, `_maximum` and `_step` (step size, normally 1) for this element to work properly.
The `[set]` block supports `_info` which is descriptive text that will appear near the UI element.
Supports the `[wizard_ui_accordion]` shortcode which will group the inner `[set]` blocks into a collapsible UI element.
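Putting these pieces together, a template might declare its UI like this (the variable names and values below are purely illustrative, not from a stock template, and the exact block syntax may differ slightly):

```
[template name="Example Template"]A short **Markdown** description shown at the top of the UI.[/template]
[wizard_ui_accordion]
[set style _new _ui="dropdown" _choices="photo|painting|sketch" _info="Overall art style"]photo[/set]
[set strength _new _ui="slider" _minimum=0 _maximum=1 _step=0.1 _info="Effect strength"]0.5[/set]
[/wizard_ui_accordion]
```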
</details>
## ⚙️ Shortcodes
@ -639,7 +643,7 @@ RESULT: She said
<details><summary>[do until(str)]</summary>
Do-until style loop. The content is processed, then the `until` expression is evaluated - if it's false, the content is processed again. Repeat until `until` is true.
```
[sets my_var=0]
@ -1028,6 +1032,8 @@ Returns a slice of the content as determined by the keyword arguments.
`end` is the last position of the slice. Defaults to 0.
Alternatively, you can pass a string into `start` or `end` and it will find the index of that string within the `content`.
`step` is the skip interval. Defaults to 1 (in other words, a continuous substring.)
`unit` is either `characters` or `words` and refers to the unit of the aforementioned arguments. Defaults to `characters`.
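As a rough Python analogy of the `characters` unit (a sketch, not the shortcode's actual implementation; treating `end=0` as "through the end of the content" is an assumption):

```python
def slice_content(content: str, start=0, end=0, step: int = 1) -> str:
    # Strings passed as start/end are resolved to their index within the content
    if isinstance(start, str):
        start = content.find(start)
    if isinstance(end, str):
        end = content.find(end)
    end = end if end else len(content)  # assumption: end=0 means "to the end"
    return content[start:end:step]
```

For example, `slice_content("hello world", start="world")` resolves the string `"world"` to index 6 and returns the substring from there.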
@ -1129,38 +1135,6 @@ Note that variables are automatically deleted at the end of each run - you do **
This section describes all of the included shortcodes which are specifically designed for use with the A1111 WebUI.
<details><summary>[controlnet]</summary>
Enables support for [ControlNet](https://github.com/lllyasviel/ControlNet) models in img2img mode. ControlNet is a neural network structure to control diffusion models by adding extra conditions.
**NOTE:** This is a "wrapper" implementation of the original ControlNet code. For a more robust solution, you can check out [the dedicated ControlNet extension by Mikubill](https://github.com/Mikubill/sd-webui-controlnet).
You need a bare minimum of 8 GB of VRAM to use this shortcode, although 12 GB is recommended.
Supports the `model` argument, which is the name of a ControlNet checkpoint in your `models/Stable-diffusion` directory (do not include the file extension.) You can download ControlNet checkpoints from [the official HuggingFace page](https://huggingface.co/lllyasviel/ControlNet/tree/main/models).
For each model, you also need a copy of the [cldm_v15.yaml](https://github.com/lllyasviel/ControlNet/tree/main/models) config file. Rename it to match the name of the ControlNet model, e.g. `control_sd15_normal.yaml`.
For each model, you also need the associated [annotator files available here](https://huggingface.co/lllyasviel/ControlNet/tree/main/annotator/ckpts). Place these into your `extensions/unprompted/lib_unprompted/stable_diffusion/controlnet/annotator/ckpts` folder.
If you run into any errors, please triple-check your filepaths before opening a bug report.
You can use ControlNet with custom SD 1.5 models [by merging checkpoints as described here](https://github.com/lllyasviel/ControlNet/issues/4#issuecomment-1426877944).
Please be aware that the last part of your model's filename indicates which type of ControlNet model it is. The following ControlNet model types are supported: `openpose`, `scribble`, `mlsd`, `depth`, `normal`, `hed`, `canny`, `seg`
ControlNet models should **not** be loaded manually from your WebUI dropdown.
Supports the `save_memory` argument to minimize VRAM requirements.
Supports the `detect_resolution` argument which is the size of the detected map. Defaults to 512. Some models may perform better at 384. Lowering this value to 256 may help with VRAM requirements.
Supports the `eta` argument.
Supports the following model-specific arguments: `value_threshold`, `distance_threshold`, `bg_threshold`, `low_threshold`, `high_threshold`
</details>
<details><summary>[file2mask]</summary>
Allows you to modify or replace your img2img mask with arbitrary files.
@ -1391,6 +1365,8 @@ This is a helper shortcode that should be used if multiple init images, multiple
A port of [the script](https://github.com/ThereforeGames/txt2mask) by the same name, `[txt2mask]` allows you to create a region for inpainting based only on the text content (as opposed to the brush tool.) This shortcode only works in the img2img tab of the A1111 WebUI.
Supports the `method` argument which determines the technology to use for masking. Defaults to `clipseg`. Can be changed to `grounded_sam`, which utilizes [Segment Anything](https://segment-anything.com/), or to `clip_surgery`, which uses [CLIP Surgery](https://github.com/xmed-lab/CLIP_Surgery).
Supports the `mode` argument which determines how the text mask will behave alongside a brush mask:
- `add` will overlay the two masks. This is the default value.
- `discard` will ignore the brush mask entirely.
@ -1424,6 +1400,8 @@ Supports the optional `show` positional argument which will append the final mas
Supports the optional `legacy_weights` positional argument which will utilize the original CLIPseg weights. By default, `[txt2mask]` will use the [refined weights](https://github.com/timojl/clipseg#new-fine-grained-weights).
Supports the `unload_model` argument, which will unload the clipseg model after processing. On my RTX 3090, this adds about 3 seconds to inference time. Defaults to `False`, and should only be enabled on devices with low memory.
The content and `negative_mask` both support the vertical pipe delimiter (`|`) which allows you to specify multiple subjects for masking.
```
@ -1434,7 +1412,7 @@ The content and `negative_mask` both support the vertical pipe delimiter (`|`) w
<details><summary>[zoom_enhance]</summary>
Upscales a selected portion of an image via `[img2img]` and `[txt2mask]`, then pastes it seamlessly back onto the original.
Greatly improves low-resolution details like faces and hands. It is significantly faster than Hires Fix and more flexible than the "Restore Faces" option.
@ -1446,11 +1424,23 @@ Supports the `replacement` keyword argument which is the prompt that will be use
Supports the `negative_replacement` keyword argument, which is the negative prompt that will be used on the mask region via `[img2img]`. Defaults to an empty string.
Both `replacement` and `negative_replacement` support multiple, delimited search terms via `Unprompted.config.syntax.delimiter`.
Supports `mask_sort_method` which is used when multiple, non-contiguous masks are detected. Defaults to `left-to-right`. Options include: `left-to-right`, `right-to-left`, `top-to-bottom`, `bottom-to-top`, `big-to-small`, `small-to-big`, `unsorted`.
Supports the `mode` keyword argument, which determines how the shortcode will interact with a pre-existing image mask. Defaults to `subtract`, which will remove your masked pixels from the shortcode's calculations. Options include: `add`, `subtract`, `discard`.
Supports the `bypass_adaptive_hires` positional argument. By default, the shortcode will scale up some inference settings such as CFG scale and sharpness depending on the resolution of the init image. Include this argument to disable that behavior.
Supports the `hires_size_max` keyword argument which is a hard limit on the size of the upscaled image, in order to avoid OOM errors. Defaults to 1024.
Supports the `blur_size` keyword argument, which corresponds to the radius of the gaussian blur that will be applied to the mask of the upscaled image - this helps it blend seamlessly back into your original image. Defaults to `0.03`. Note: this is a float that is a percentage of the total canvas size; 0.03 means 3% of the total canvas.
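That percentage-to-pixels relationship can be sketched as follows (which canvas dimension anchors the percentage is an assumption here, not documented behavior):

```python
def blur_radius_px(blur_size: float, width: int, height: int) -> int:
    # blur_size is a fraction of canvas size; 0.03 -> 3% of the larger dimension
    return round(blur_size * max(width, height))
```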
Supports the `sharpen_amount` argument, which is a float that determines the strength of the unsharp filter that is applied in post-processing.
Supports the `denoising_max` keyword argument. The `[zoom_enhance]` shortcode is equipped with **dynamic denoising strength** based on a simple idea: the smaller the mask region, the higher denoise we should apply. This argument lets you set the upper limit of that feature.
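The idea behind dynamic denoising strength can be sketched as a simple inverse relationship (the actual curve `[zoom_enhance]` uses is not documented here; this is only illustrative):

```python
def dynamic_denoise(mask_fraction: float, denoising_max: float = 0.65) -> float:
    # Smaller mask region (as a fraction of the canvas) -> stronger denoising,
    # never exceeding denoising_max
    return denoising_max * (1.0 - min(mask_fraction, 1.0))
```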
Supports the `mask_size_max` keyword argument. Defaults to `0.5`. If a mask region is determined to be greater than this value, it will not be processed by `[zoom_enhance]`. The reason is that large objects generally do not benefit from upscaling.
Supports the `min_area` keyword argument. Defaults to `50`. If the pixel area of a mask is smaller than this, it may be a false-positive mask selection or at least not worth upscaling.
@ -1460,6 +1450,16 @@ Supports the `upscale_width` and `upscale_height` arguments. Default to `512`. T
Supports the `include_original` positional argument. This will append the original, "non-zoom-enhanced" image to your output window. Useful for before-after comparisons.
Supports the `upscale_method` and `downscale_method` arguments which determine the algorithms for image rescaling. Upscale defaults to `Nearest Neighbor`. Downscale defaults to `Lanczos`. Options include: `Nearest Neighbor`, `Box`, `Bilinear`, `Hamming`, `Bicubic`, `Lanczos`.
Supports the `color_correct_method` argument which will attempt to match the color grading of the upscaled image to the original. Defaults to `none`. Options include: `none`, `mvgd`, `mkl`, `hm-mvgd-hm`, `hm-mkl-hm`.
Supports the `color_correct_strength` argument which is an integer that determines how many times to run the color correction algorithm. Defaults to 1.
Supports the `color_correct_timing` argument which determines when to run the color correction algorithm. Defaults to `pre`, which will run color correction before upscaling. Options include `pre` and `post`.
Supports the `debug` positional argument, which will output a series of images to your WebUI folder over the course of processing.
This shortcode is compatible with batch count and batch size.
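Putting several of these arguments together, a call might look like the following (argument values are illustrative, not recommended defaults):

```
[zoom_enhance replacement="face" negative_replacement="blurry, deformed" mask_sort_method="big-to-small" blur_size=0.03 denoising_max=0.65 include_original]
```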
@ -1509,4 +1509,50 @@ The `cleanup` function runs at the end of the processing chain. You can free any
For more details, please examine the code of the stock shortcodes.
</details>
<details><summary>Legacy Shortcodes</summary>
Legacy shortcodes are those which are no longer officially supported. Please be aware that they may not work as expected and could be removed from future versions of Unprompted without warning.
<details><summary>[controlnet]</summary>
**Reason for legacy status:** The popular [ControlNet extension by Mikubill](https://github.com/Mikubill/sd-webui-controlnet) was released less than 24 hours after this shortcode and is much more robust. ControlNet is a complicated, time-consuming feature to support and I cannot justify further development when the alternative software is already so good.
Enables support for [ControlNet](https://github.com/lllyasviel/ControlNet) models in img2img mode. ControlNet is a neural network structure to control diffusion models by adding extra conditions.
You need a bare minimum of 8 GB of VRAM to use this shortcode, although 12 GB is recommended.
Supports the `model` argument, which is the name of a ControlNet checkpoint in your `models/Stable-diffusion` directory (do not include the file extension.) You can download ControlNet checkpoints from [the official HuggingFace page](https://huggingface.co/lllyasviel/ControlNet/tree/main/models).
For each model, you also need a copy of the [cldm_v15.yaml](https://github.com/lllyasviel/ControlNet/tree/main/models) config file. Rename it to match the name of the ControlNet model, e.g. `control_sd15_normal.yaml`.
For each model, you also need the associated [annotator files available here](https://huggingface.co/lllyasviel/ControlNet/tree/main/annotator/ckpts). Place these into your `extensions/unprompted/lib_unprompted/stable_diffusion/controlnet/annotator/ckpts` folder.
If you run into any errors, please triple-check your filepaths before opening a bug report.
You can use ControlNet with custom SD 1.5 models [by merging checkpoints as described here](https://github.com/lllyasviel/ControlNet/issues/4#issuecomment-1426877944).
Please be aware that the last part of your model's filename indicates which type of ControlNet model it is. The following ControlNet model types are supported: `openpose`, `scribble`, `mlsd`, `depth`, `normal`, `hed`, `canny`, `seg`
ControlNet models should **not** be loaded manually from your WebUI dropdown.
Supports the `save_memory` argument to minimize VRAM requirements.
Supports the `detect_resolution` argument which is the size of the detected map. Defaults to 512. Some models may perform better at 384. Lowering this value to 256 may help with VRAM requirements.
Supports the `eta` argument.
Supports the following model-specific arguments: `value_threshold`, `distance_threshold`, `bg_threshold`, `low_threshold`, `high_threshold`
</details>
<details><summary>[pix2pix_zero]</summary>
**Reason for legacy status:** the pix2pix-zero method [was not what I originally thought](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/7711#discussioncomment-4952579) which was sort of a buzzkill. As it stands, I believe ControlNet is better suited at most tasks, but I can't make any definitive claims about that - pix2pix-zero went under the radar and does merit further testing.
If you wish to use this shortcode, you will need to modify the hardcoded path to a diffusers model on line 33.
</details>
</details>

10 new binary image files (stamp images) added, including `images/stamps/gigachad.png` (190 KiB); previews not shown.


@ -0,0 +1 @@
from .clip import *


@ -0,0 +1,74 @@
from torch import nn
from .clip_model import CLIP
from .clip_surgery_model import CLIPSurgery
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
if isinstance(l, nn.MultiheadAttention):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
def build_model(name: str, state_dict: dict):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_resolution = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
if 'CS-' in name:
model = CLIPSurgery(
embed_dim,
image_resolution, vision_layers, vision_width, vision_patch_size,
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
)
else:
model = CLIP(
embed_dim,
image_resolution, vision_layers, vision_width, vision_patch_size,
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
)
for key in ["input_resolution", "context_length", "vocab_size"]:
if key in state_dict:
del state_dict[key]
# convert_weights(model)  # fp16 conversion is left disabled in this copy
model.load_state_dict(state_dict)
return model.eval()


@ -0,0 +1,344 @@
import hashlib
import os
import urllib
import warnings
from typing import Union, List
from pkg_resources import packaging
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from tqdm import tqdm
import numpy as np
from .build_model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
__all__ = ["available_models", "load", "tokenize", "encode_text_with_prompt_ensemble",
"get_similarity_map", "clip_feature_surgery", "similarity_map_to_points"]
_tokenizer = _Tokenizer()
_MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
"CS-RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"CS-RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"CS-RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"CS-RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
"CS-RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
"CS-ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
"CS-ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
"CS-ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
"CS-ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}
def _download(url: str, root: str):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
return download_target
def _convert_image_to_rgb(image):
return image.convert("RGB")
def _transform(n_px):
return Compose([
Resize((n_px, n_px), interpolation=BICUBIC),
#CenterCrop(n_px), # rm center crop to explain whole image
_convert_image_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
def available_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model or more hackable non-JIT model (default).
download_root: str
path to download the model files; by default, it uses "~/.cache/clip"
Returns
-------
model : torch.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if name in _MODELS:
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
with open(model_path, 'rb') as opened_file:
try:
# loading JIT archive
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(opened_file, map_location="cpu")
if not jit:
model = build_model(name, state_dict or model.state_dict()).to(device)
if str(device) == "cpu":
model.float()
return model, _transform(model.visual.input_resolution)
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def patch_device(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if str(device) == "cpu":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
return model, _transform(model.input_resolution.item())
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
truncate: bool
Whether to truncate the text in case its encoding is longer than the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder["<|startoftext|>"]
eot_token = _tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
else:
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
if truncate:
tokens = tokens[:context_length]
tokens[-1] = eot_token
else:
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
result[i, :len(tokens)] = torch.tensor(tokens)
return result
def encode_text_with_prompt_ensemble(model, texts, device, prompt_templates=None):
# using default prompt templates for ImageNet
if prompt_templates is None:
prompt_templates = ['a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.', 'a photo of the hard to see {}.', 'a low resolution photo of the {}.', 'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.', 'a photo of a hard to see {}.', 'a bright photo of a {}.', 'a photo of a clean {}.', 'a photo of a dirty {}.', 'a dark photo of the {}.', 'a drawing of a {}.', 'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a painting of the {}.', 'a painting of a {}.', 'a pixelated photo of the {}.', 'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a rendering of the {}.', 'a {} in a video game.', 'a photo of one {}.', 'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.', 'a doodle of the {}.', 'a origami {}.', 'a low resolution photo of a {}.', 'the toy {}.', 'a rendition of the {}.', 'a photo of the clean {}.', 'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.', 'a photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.', 'art of a {}.', 'a sketch of the {}.', 'a embroidered {}.', 'a pixelated photo of a {}.', 'itap of the {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.', 'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.', 'art of the {}.', 'a drawing of the {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'the plushie {}.', 'a dark photo of a {}.', 'itap of a {}.', 'graffiti of the {}.', 'a toy {}.', 'itap of my {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
text_features = []
for t in texts:
prompted_t = [template.format(t) for template in prompt_templates]
prompted_t = tokenize(prompted_t).to(device)
class_embeddings = model.encode_text(prompted_t)
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
class_embedding = class_embeddings.mean(dim=0)
class_embedding /= class_embedding.norm()
text_features.append(class_embedding)
text_features = torch.stack(text_features, dim=1).to(device).t()
return text_features
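The prompt-ensemble step above encodes every template for a class, normalises each embedding, averages them, and re-normalises the mean. A minimal NumPy sketch of that averaging (the per-prompt embeddings here are random stand-in data, not real encoder output):

```python
import numpy as np

def ensemble_embedding(embeddings):
    # embeddings: (P, C), one encoded prompt per row (stand-in for model.encode_text output)
    embeddings = embeddings / np.linalg.norm(embeddings, axis=-1, keepdims=True)
    mean = embeddings.mean(axis=0)        # average over prompt templates
    return mean / np.linalg.norm(mean)    # re-normalise the class embedding

e = ensemble_embedding(np.random.randn(5, 8))
```

The final class embedding is unit-length regardless of how many templates contribute.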
def get_similarity_map(sm, shape):
# min-max norm
sm = (sm - sm.min(1, keepdim=True)[0]) / (sm.max(1, keepdim=True)[0] - sm.min(1, keepdim=True)[0])
# reshape
side = int(sm.shape[1] ** 0.5) # square output
sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2)
# interpolate
sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear')
sm = sm.permute(0, 2, 3, 1)
return sm
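`get_similarity_map` min-max normalises each map over the patch axis, reshapes the flat tokens back into a square grid, and upsamples to the target shape. A NumPy sketch of the same flow, with nearest-neighbour upsampling standing in for the bilinear interpolation used above:

```python
import numpy as np

def similarity_to_map(sm, out_side):
    # sm: (B, N, C) similarity scores, N = side * side patch tokens
    mn = sm.min(axis=1, keepdims=True)
    mx = sm.max(axis=1, keepdims=True)
    sm = (sm - mn) / (mx - mn)                       # min-max norm per map
    side = int(sm.shape[1] ** 0.5)                   # square patch grid
    sm = sm.reshape(sm.shape[0], side, side, -1)
    scale = out_side // side
    # nearest-neighbour upsample stands in for bilinear interpolation here
    return np.kron(sm.transpose(0, 3, 1, 2), np.ones((scale, scale))).transpose(0, 2, 3, 1)

maps = similarity_to_map(np.random.rand(1, 16, 2), 8)
print(maps.shape)  # (1, 8, 8, 2)
```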
def clip_feature_surgery(image_features, text_features, redundant_feats=None, t=2):
    if redundant_feats is not None:
similarity = image_features @ (text_features - redundant_feats).t()
else:
# weights to restrain influence of obvious classes on others
prob = image_features[:, :1, :] @ text_features.t()
prob = (prob * 2).softmax(-1)
w = prob / prob.mean(-1, keepdim=True)
# element-wise multiplied features
b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2]
feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c)
feats *= w.reshape(1, 1, n_t, 1)
redundant_feats = feats.mean(2, keepdim=True) # along cls dim
feats = feats - redundant_feats
# sum the element-wise multiplied features as cosine similarity
similarity = feats.sum(-1)
return similarity
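The core of `clip_feature_surgery` is the redundant-feature subtraction: the element-wise image-text feature products share a component common to all classes, and removing the per-class mean suppresses it. A simplified NumPy sketch (it omits the softmax class re-weighting `w` used above):

```python
import numpy as np

def feature_surgery(image_feats, text_feats):
    # image_feats: (B, N, C) image tokens; text_feats: (T, C) class embeddings
    feats = image_feats[:, :, None, :] * text_feats[None, None, :, :]  # (B, N, T, C)
    redundant = feats.mean(axis=2, keepdims=True)   # component shared by all classes
    return (feats - redundant).sum(-1)              # (B, N, T) similarity

sim = feature_surgery(np.random.randn(1, 5, 8), np.random.randn(3, 8))
print(sim.shape)  # (1, 5, 3)
```

After subtracting the class-mean, the similarities of each token sum to zero across classes, so only class-discriminative signal remains.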
# sm shape N_t
def similarity_map_to_points(sm, shape, t=0.8, down_sample=2):
side = int(sm.shape[0] ** 0.5)
sm = sm.reshape(1, 1, side, side)
# down sample to smooth results
down_side = side // down_sample
sm = torch.nn.functional.interpolate(sm, (down_side, down_side), mode='bilinear')[0, 0, :, :]
h, w = sm.shape
sm = sm.reshape(-1)
sm = (sm - sm.min()) / (sm.max() - sm.min())
rank = sm.sort(0)[1]
scale_h = float(shape[0]) / h
scale_w = float(shape[1]) / w
num = min((sm >= t).sum(), sm.shape[0] // 2)
labels = np.ones(num * 2).astype('uint8')
labels[num:] = 0
points = []
# positives
for idx in rank[-num:]:
x = min((idx % w + 0.5) * scale_w, shape[1] - 1) # +0.5 to center
y = min((idx // w + 0.5) * scale_h, shape[0] - 1)
points.append([int(x.item()), int(y.item())])
# negatives
for idx in rank[:num]:
x = min((idx % w + 0.5) * scale_w, shape[1] - 1)
y = min((idx // w + 0.5) * scale_h, shape[0] - 1)
points.append([int(x.item()), int(y.item())])
return points, labels
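`similarity_map_to_points` converts flat grid indices of the top- and bottom-ranked cells into image coordinates. A small self-contained sketch of just the index-to-point arithmetic (positives only, square grid assumed):

```python
import numpy as np

def flat_topk_points(sm, shape, num=2):
    # sm: flat (H*W,) similarity over a square grid; shape: (H_out, W_out) image size
    h = w = int(len(sm) ** 0.5)
    scale_h, scale_w = shape[0] / h, shape[1] / w
    rank = np.argsort(sm)
    pts = []
    for idx in rank[-num:]:                              # highest-scoring cells
        x = min((idx % w + 0.5) * scale_w, shape[1] - 1)  # +0.5 centres the cell
        y = min((idx // w + 0.5) * scale_h, shape[0] - 1)
        pts.append((int(x), int(y)))
    return pts

print(flat_topk_points(np.arange(16.0), (64, 64), num=1))  # [(56, 56)]
```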


@@ -0,0 +1,396 @@
from collections import OrderedDict
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.relu1(self.bn1(self.conv1(x)))
out = self.relu2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
new_side = int((x.shape[0] - 1) ** 0.5)
# update the position embedding during inference for varied input size
if side != new_side:
new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x, key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
#return x[0]
return x.transpose(0, 1) # return both cls token and image tokens, B,N,C
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
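`QuickGELU` is CLIP's cheap GELU approximation, `x * sigmoid(1.702 * x)`. A scalar sketch using only the standard library shows the same shape: near-identity for large positive inputs, near-zero for large negative ones:

```python
import math

def quick_gelu(x):
    # x * sigmoid(1.702 * x), the GELU approximation used by CLIP
    return x / (1.0 + math.exp(-1.702 * x))

print(quick_gelu(0.0))  # 0.0
```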
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
self.need_weights = need_weights
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        if not self.need_weights:
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
else:
return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)
def forward(self, x: torch.Tensor):
        if not self.need_weights:
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
else:
y, attn = self.attention(self.ln_1(x))
x = x + y
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, need_weights if i == layers - 1 else False) for i in range(layers)])
def forward(self, x: torch.Tensor):
return self.resblocks(x)
class VisionTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads, need_weights=True)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
#x = self.ln_post(x[:, 0, :])
x = self.ln_post(x) # return both cls token and image tokens
if self.proj is not None:
x = x @ self.proj
return x
class CLIP(nn.Module):
def __init__(self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int
):
super().__init__()
self.context_length = context_length
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width
)
else:
vision_heads = vision_width // 64
self.visual = VisionTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim
)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask()
)
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.initialize_parameters()
def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
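`build_attention_mask` produces an additive causal mask: zeros on and below the diagonal, `-inf` strictly above, so each text token can only attend to earlier positions. The same construction in NumPy:

```python
import numpy as np

def causal_mask(n):
    # additive mask: 0 on and below the diagonal, -inf strictly above
    mask = np.full((n, n), float("-inf"))
    return np.triu(mask, k=1)  # zero out the lower triangle and diagonal

m = causal_mask(3)
print(m[0, 1], m[1, 0])  # -inf 0.0
```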
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image):
return self.visual(image.type(self.dtype))
def encode_text(self, text):
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
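The `forward` pass above ends with the standard CLIP logit computation: L2-normalise both feature sets, then take a scaled cosine similarity. A NumPy sketch of that final step (features here are random stand-ins, and 100.0 is a typical value of `exp(logit_scale)`, not taken from a trained model):

```python
import numpy as np

def clip_logits(img, txt, logit_scale=100.0):
    # normalise rows, then scaled cosine similarity, as in CLIP.forward
    img = img / np.linalg.norm(img, axis=1, keepdims=True)
    txt = txt / np.linalg.norm(txt, axis=1, keepdims=True)
    logits_per_image = logit_scale * img @ txt.T
    return logits_per_image, logits_per_image.T

li, lt = clip_logits(np.random.randn(2, 4), np.random.randn(3, 4))
print(li.shape, lt.shape)  # (2, 3) (3, 2)
```

The two outputs are transposes of each other, and every logit is bounded by the scale since cosine similarity lies in [-1, 1].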


@@ -0,0 +1,487 @@
from collections import OrderedDict
from typing import Tuple, Union
import numpy as np
import torch
from torch import nn
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.relu1(self.bn1(self.conv1(x)))
out = self.relu2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu3(out)
return out
# implement attention module for v-v self-attention
class Attention(nn.Module):
def __init__(self, out_dim, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., settings=''):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(out_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.settings = settings
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
# original self-attention for the original path
attn_ori = (q @ k.transpose(-2, -1)) * self.scale
attn_ori = attn_ori.softmax(dim=-1)
attn_ori = self.attn_drop(attn_ori)
# replace k & q by v
k = v
q = k
# resnets have only one self-attention, norm and larger scale perform better
if self.settings == 'resnet':
k = k / (k.norm(p=2, dim=-1, keepdim=True) + 1e-6)
q = k
scale = self.scale * 8
else:
scale = self.scale
        # self-attention, higher temperature for resnets performs better
attn = (q @ k.transpose(-2, -1)) * scale
attn = (attn).softmax(dim=-1)
attn = self.attn_drop(attn)
x_ori = (attn_ori @ v).transpose(1, 2).reshape(B, N, C)
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # clip_surgery
#x = v.transpose(1, 2).reshape(B, N, C) # mask_clip
x = self.proj_drop(self.proj(x))
x_ori = self.proj_drop(self.proj(x_ori))
return [x, x_ori]
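The `Attention` module above replaces both query and key with the value projection, so attention weights come from `softmax(v @ v.T * scale)` rather than `q @ k.T`. A single-head NumPy sketch of that v-v path (no projections or dropout):

```python
import numpy as np

def vv_attention(v, scale):
    # v-v self-attention: attn = softmax(v @ v^T * scale), output = attn @ v
    attn = v @ v.T * scale
    attn = np.exp(attn - attn.max(axis=-1, keepdims=True))  # stable softmax
    attn /= attn.sum(axis=-1, keepdims=True)
    return attn @ v

out = vv_attention(np.random.randn(4, 8), 8 ** -0.5)
print(out.shape)  # (4, 8)
```

Because a token's value is most similar to itself, the v-v weights emphasise self-correspondence, which is what sharpens the per-token localisation in CLIP Surgery.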
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
self.attn = None
self.embed_dim = embed_dim
self.output_dim = output_dim
def forward(self, x):
# reform transformer layer after init and load weights, using v only
        if self.attn is None:
self.attn = Attention(self.output_dim, self.embed_dim, self.num_heads, True)
self.attn.qkv.weight = torch.nn.Parameter(torch.cat([self.v_proj.weight, self.v_proj.weight, self.v_proj.weight], 0))
self.attn.qkv.bias = torch.nn.Parameter(torch.cat([self.v_proj.bias, self.v_proj.bias, self.v_proj.bias]))
self.attn.proj.weight = self.c_proj.weight
self.attn.proj.bias = self.c_proj.bias
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
new_side = int((x.shape[0] - 1) ** 0.5)
# update the position embedding during inference for varied input size
if side != new_side:
new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, x_ori = self.attn(x.transpose(0, 1))
# cls token from the original path, and img tokens from the new path
x[:, 0, :] = x_ori[:, 0, :]
return x
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
# shape BNC
return x
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
if isinstance(self.attn, Attention):
x = x.transpose(0, 1)
x, x_ori = self.attn(x)
return [x.transpose(0, 1), x_ori.transpose(0, 1)]
else:
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x):
# dual paths for blocks deeper than "d"
if isinstance(self.attn, Attention):
if isinstance(x, list):
x, x_ori = x
x_res = self.attention(self.ln_1(x_ori))
x_res, x_ori_res = x_res
x_ori += x_ori_res
x_ori = x_ori + self.mlp(self.ln_2(x_ori))
x += x_res # skip ffn for the new path
return [x, x_ori]
# start of dual path
else:
x_res = self.attention(self.ln_1(x))
if isinstance(x_res, list):
x_res, x_ori_res = x_res
x_ori = x + x_ori_res
x_ori = x_ori + self.mlp(self.ln_2(x_ori))
x += x_res
return [x, x_ori]
        # single path before "d"
else:
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for i in range(layers)])
def forward(self, x: torch.Tensor):
return self.resblocks(x)
class VisionTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads, need_weights=True)
self.attn = None
self.embed_dim = width
self.num_heads = heads
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
@torch.no_grad()
def forward(self, x: torch.Tensor):
# reform the architecture during first inference
        if self.attn is None:
# apply architecture surgery on the last 6 blocks
for i in range(1, 7): # surgery 7, maskclip 2
self.attn = Attention(self.embed_dim, self.embed_dim, self.num_heads, True)
self.attn.qkv.weight.data = self.transformer.resblocks[-i].attn.in_proj_weight.clone()
self.attn.qkv.bias.data = self.transformer.resblocks[-i].attn.in_proj_bias.clone()
self.attn.proj.weight.data = self.transformer.resblocks[-i].attn.out_proj.weight.clone()
self.attn.proj.bias.data = self.transformer.resblocks[-i].attn.out_proj.bias.clone()
self.transformer.resblocks[-i].attn = self.attn
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
new_side = int((x.shape[1] - 1) ** 0.5)
# update the position embedding during inference for varied input size
if side != new_side:
new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
pos = self.positional_embedding.to(x.dtype)
x = x + pos
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x, x_ori = self.transformer(x)
x[0, :, :] = x_ori[0, :, :] # clip_surgery
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x)
x = x @ self.proj
return x
class CLIPSurgery(nn.Module):
def __init__(self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int
):
super().__init__()
self.context_length = context_length
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width
)
else:
vision_heads = vision_width // 64
self.visual = VisionTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim
)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask()
)
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.initialize_parameters()
def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image):
return self.visual(image.type(self.dtype))
def encode_text(self, text):
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
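The normalize-then-scaled-dot-product step in `forward()` can be sketched without torch; a pure-Python equivalent for illustration (not the project's API):

```python
import math

def cosine_logits(image_feats, text_feats, logit_scale):
    """Mirror of CLIP's forward(): L2-normalize each feature vector,
    then return logit_scale * (image @ text.T)."""
    def l2_normalize(v):
        n = math.sqrt(sum(x * x for x in v))
        return [x / n for x in v]
    imgs = [l2_normalize(v) for v in image_feats]
    txts = [l2_normalize(v) for v in text_feats]
    return [[logit_scale * sum(a * b for a, b in zip(i, t)) for t in txts]
            for i in imgs]

# A vector compared against a parallel vector scores the full logit_scale,
# regardless of magnitude; orthogonal vectors score zero.
```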


@@ -0,0 +1,132 @@
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns a mapping of utf-8 bytes to corresponding unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
This also avoids mapping to whitespace/control characters that the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
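The docstring's claim — a reversible, one-to-one mapping covering all 256 byte values — is easy to sanity-check; the function is restated here so the snippet runs on its own:

```python
def bytes_to_unicode():
    # Printable bytes map to themselves; the remaining byte values are
    # remapped to code points >= 256 so nothing lands on whitespace/control.
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

enc = bytes_to_unicode()
dec = {v: k for k, v in enc.items()}  # the byte_decoder built in SimpleTokenizer.__init__
```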
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
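The merge loop inside `bpe()` is the densest part of the tokenizer. A simplified, self-contained sketch of the same greedy lowest-rank merging (the toy rank table below is hypothetical, not from the real vocab):

```python
def bpe_merge(word, bpe_ranks):
    # Greedily merge the adjacent pair with the lowest rank until
    # no mergeable pair remains -- the core of SimpleTokenizer.bpe().
    word = tuple(word)
    while True:
        pairs = {(word[i], word[i + 1]) for i in range(len(word) - 1)}
        candidates = [p for p in pairs if p in bpe_ranks]
        if not candidates:
            break
        first, second = min(candidates, key=lambda p: bpe_ranks[p])
        merged, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                merged.append(first + second)
                i += 2
            else:
                merged.append(word[i])
                i += 1
        word = tuple(merged)
    return word

ranks = {("l", "o"): 0, ("lo", "w</w>"): 1}  # hypothetical merge ranks
```

Each pass fuses every occurrence of the single best-ranked pair, which is exactly the loop structure of `bpe()` minus the caching and word-boundary bookkeeping.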


@@ -10,7 +10,7 @@ import sys
class Unprompted:
def __init__(self, base_dir="."):
self.VERSION = "8.2.0"
self.VERSION = "8.3.0"
print(f"Loading Unprompted v{self.VERSION} by Therefore Games")
self.log("Initializing Unprompted object...",False,"SETUP")
@@ -113,7 +113,8 @@ class Unprompted:
def parse_advanced(self,string,context=None):
"""First runs the string through parse_alt_tags, the result of which then goes through simpleeval"""
if (len(string) < 1): return string
if string is None: return ""
if (len(string) < 1): return ""
string = self.parse_alt_tags(string,context)
if self.Config.advanced_expressions:
try:


@@ -35,6 +35,7 @@ Unprompted.wizard_dropdown = None
Unprompted.wizard_template_files = []
Unprompted.wizard_template_names = []
Unprompted.wizard_template_kwargs = []
def do_dry_run(string):
Unprompted.log(string)
@@ -109,6 +110,9 @@ def wizard_generate_template(option,is_img2img,prepend="",append=""):
result = parse_children(group,result)
for kwarg in Unprompted.wizard_template_kwargs[option]:
result += f" {kwarg}='{Unprompted.wizard_template_kwargs[option][kwarg]}'"
# Closing bracket
result += Unprompted.Config.syntax.tag_end
@@ -261,6 +265,7 @@ class Scripts(scripts.Script):
content = content.replace("\\r\\n", "<br>") + "<br><br>"
gr.Label(label="Options",value=f"{self.dropdown_item_name}")
gr.Markdown(value=content)
self.wizard_template_kwargs = kwargs
return("")
wizard_shortcode_parser.register(handler,"template",f"{Unprompted.Config.syntax.tag_close}template")
@@ -306,6 +311,7 @@ class Scripts(scripts.Script):
wizard_add_template(is_first)
Unprompted.wizard_template_names.append(self.dropdown_item_name)
Unprompted.wizard_template_files.append(filename)
Unprompted.wizard_template_kwargs.append(self.wizard_template_kwargs)
if (is_first):
templates_dropdown.value = self.dropdown_item_name
is_first = False
@@ -435,6 +441,8 @@
# Extra vars
Unprompted.shortcode_user_vars["batch_index"] = 0
original_model = opts.sd_model_checkpoint
Unprompted.shortcode_user_vars["sd_model"] = opts.sd_model_checkpoint
# Set up system var support - copy relevant p attributes into shortcode var object
for att in dir(p):
@@ -451,7 +459,7 @@
# Special handling of vars
for att in Unprompted.shortcode_user_vars:
# change models
if att == "sd_model":
if att == "sd_model" and Unprompted.shortcode_user_vars[att] != original_model:
info = sd_models.get_closet_checkpoint_match(Unprompted.shortcode_user_vars["sd_model"])
if (info): sd_models.load_model(info,None,None)
# control controlnet
@@ -510,8 +518,8 @@
# After routines
def postprocess(self, p, processed, is_enabled=True, unprompted_seed=-1, match_main_seed=True):
if not self.allow_postprocess or not is_enabled:
Unprompted.log("Bypassing After routine")
self.allow_postprocess = True
Unprompted.log("Bypassing After routine to avoid infinite loop")
# self.allow_postprocess = True
return False # Prevents endless loop with some shortcodes
self.allow_postprocess = False
Unprompted.log("Entering After routine...")


@@ -8,6 +8,7 @@ class Shortcode():
def run_block(self, pargs, kwargs, context, content):
index = int(self.Unprompted.parse_advanced(pargs[0])) if len(pargs) > 0 else 0
if self.last_index != index or "allow_dupe_index" in pargs:
self.Unprompted.log(f"Queueing up content: {content}")
self.after_content.insert(index,content)
self.last_index = index
else: self.Unprompted.log("Duplicate [after] content detected, skipping - include allow_dupe_index to bypass this check")


@@ -8,6 +8,9 @@ class Shortcode():
self.description = "Processes the file content of 'path.'"
def run_atomic(self, pargs, kwargs, context):
if "_bypass_if" in kwargs:
if self.Unprompted.parse_advanced(kwargs["_bypass_if"],context): return ""
file_string = self.Unprompted.parse_alt_tags(pargs[0],context)
this_encoding = self.Unprompted.parse_advanced(kwargs["_encoding"],context) if "_encoding" in kwargs else "utf-8"


@@ -0,0 +1,100 @@
class Shortcode():
def __init__(self,Unprompted):
self.Unprompted = Unprompted
self.description = "Applies color correction to a resulting image."
self.wizard_prepend = Unprompted.Config.syntax.tag_start + "after" + Unprompted.Config.syntax.tag_end + Unprompted.Config.syntax.tag_start_alt + "color_correct"
self.wizard_append = Unprompted.Config.syntax.tag_end_alt + Unprompted.Config.syntax.tag_start + Unprompted.Config.syntax.tag_close + "after" + Unprompted.Config.syntax.tag_end
def run_atomic(self, pargs, kwargs, context):
from PIL import Image
def autocrop_image(image, border = 0):
# Get the bounding box
bbox = image.getbbox()
# Crop the image to the contents of the bounding box
image = image.crop(bbox)
# Determine the width and height of the cropped image
(width, height) = image.size
# Add border
width += border * 2
height += border * 2
# Create a new image object for the output image
cropped_image = Image.new("RGBA", (width, height), (0,0,0,0))
# Paste the cropped image onto the new image
cropped_image.paste(image, (border, border))
# Done!
return cropped_image
from blendmodes.blend import blendLayers, BlendType
color_correct_method = kwargs["method"] if "method" in kwargs else "mkl"
blend_lum = True if "blend_lum" in pargs else False
debug = True if "debug" in pargs else False
try:
image_to_fix = self.Unprompted.after_processed.images[0].copy()
except (AttributeError, IndexError):
self.Unprompted.log("This must be used inside of an [after] block",context="ERROR")
return("")
starting_image = self.Unprompted.p_copy.init_images[0]
orig_image = image_to_fix.copy()
if self.Unprompted.shortcode_user_vars.get("image_mask"):
mask = self.Unprompted.shortcode_user_vars["image_mask"]
mask = mask.convert("L")
else: mask = None
if "source" in kwargs:
set_kwargs = kwargs.copy()
set_pargs = pargs.copy()
set_pargs.insert(0,"return_image")
set_kwargs["txt2mask_init_image"] = starting_image
set_kwargs["precision"] = "150"
set_kwargs["padding"] = "0"
set_kwargs["method"] = "clipseg"
self.Unprompted.shortcode_user_vars["image_mask"] = None
source_mask = self.Unprompted.shortcode_objects["txt2mask"].run_block(set_pargs,set_kwargs,None,kwargs["source"])
source_mask = source_mask.convert("L")
starting_image.putalpha(source_mask)
starting_image = autocrop_image(starting_image)
import cv2
import numpy
from sklearn.cluster import KMeans
np_img = numpy.array(starting_image)
all_pixels = np_img.reshape(-1, 4)
#all_pixels.shape
just_non_alpha = all_pixels[all_pixels[:, 3] == 255]
#just_non_alpha.shape
avg_model = KMeans(3)
reshaped_arr = np_img[:, :, :3].reshape(-1, 3)
avg_model.fit(reshaped_arr)
if debug: self.Unprompted.log(f"KMeans cluster centers: {avg_model.cluster_centers_}")
avg_color = avg_model.cluster_centers_[0]
if debug: self.Unprompted.log(f"Average color: {avg_color}")
new_image = Image.new("RGB", starting_image.size, (int(avg_color[0]),int(avg_color[1]),int(avg_color[2])))
new_image.paste(starting_image, (0, 0), starting_image)
if debug: new_image.save("color_correct_alpha_test.png")
starting_image = new_image.copy()
strength = float(kwargs["strength"]) if "strength" in kwargs else 1.0
fixed_image = self.Unprompted.color_match(starting_image,image_to_fix,color_correct_method,1).convert("RGBA")
if blend_lum:
fixed_image = blendLayers(fixed_image, orig_image, BlendType.LUMINOSITY)
if strength < 1.0:
fixed_image.putalpha(int(255 * strength))
else: self.Unprompted.after_processed.images[0] = fixed_image
orig_image.paste(fixed_image,(0,0), fixed_image)
orig_image = orig_image.resize((self.Unprompted.after_processed.images[0].size[0],self.Unprompted.after_processed.images[0].size[1]))
self.Unprompted.after_processed.images[0].paste(orig_image,(0,0),mask)
# self.Unprompted.after_processed.images[0] = fixed_image


@@ -5,8 +5,16 @@ class Shortcode():
self.image_mask = None
self.show = False
self.description = "Creates an image mask from the content for use with inpainting."
try:
del self.cached_model
del self.cached_transform
del self.cached_model_method
del self.cached_predictor
except AttributeError: pass
self.cached_model = -1
self.cached_transform = -1
self.cached_model_method = ""
self.cached_predictor = -1
def run_block(self, pargs, kwargs, context, content):
from PIL import ImageChops, Image, ImageOps
@@ -16,45 +24,80 @@ class Shortcode():
from matplotlib import pyplot as plt
import cv2
import numpy
import gc
from modules.images import flatten
from modules.shared import opts
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
gc.collect()
if "txt2mask_init_image" in kwargs:
self.init_image = kwargs["txt2mask_init_image"]
self.init_image = kwargs["txt2mask_init_image"].copy()
elif "init_images" not in self.Unprompted.shortcode_user_vars:
self.Unprompted.log("No init_images found...")
return ""
else: self.init_image = self.Unprompted.shortcode_user_vars["init_images"][0]
else: self.init_image = self.Unprompted.shortcode_user_vars["init_images"][0].copy()
method = self.Unprompted.parse_advanced(kwargs["method"],context) if "method" in kwargs else "clipseg"
if method == "clipseg":
mask_width = 512
mask_height = 512
elif method == "sam":
import launch
if not launch.is_installed("groundingdino"):
self.Unprompted.log("Attempting to install GroundingDINO library. Buckle up bro")
try:
launch.run_pip("install git+https://github.com/IDEA-Research/GroundingDINO","requirements for Unprompted - txt2mask SAM method")
except Exception as e:
self.Unprompted.log(f"GroundingDINO problem: {e}",context="ERROR")
self.Unprompted.log(f"Please open an issue on their repo, not mine.",context="ERROR")
return ""
else:
if method == "grounded_sam":
import launch
if not launch.is_installed("groundingdino"):
self.Unprompted.log("Attempting to install GroundingDINO library. Buckle up bro")
try:
launch.run_pip("install git+https://github.com/IDEA-Research/GroundingDINO","requirements for Unprompted - txt2mask SAM method")
except Exception as e:
self.Unprompted.log(f"GroundingDINO problem: {e}",context="ERROR")
self.Unprompted.log(f"Please open an issue on their repo, not mine.",context="ERROR")
return ""
mask_width = self.Unprompted.shortcode_user_vars["width"]
mask_height = self.Unprompted.shortcode_user_vars["height"]
mask_width = self.init_image.size[0]
mask_height = self.init_image.size[1]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == "cuda": torch.cuda.empty_cache()
if "stamp" in kwargs:
stamps = (self.Unprompted.parse_advanced(kwargs["stamp"],context)).split(self.Unprompted.Config.syntax.delimiter)
stamp_x = int(float(self.Unprompted.parse_advanced(kwargs["stamp_x"],context))) if "stamp_x" in kwargs else 0
stamp_y = int(float(self.Unprompted.parse_advanced(kwargs["stamp_y"],context))) if "stamp_y" in kwargs else 0
stamp_x_orig = stamp_x
stamp_y_orig = stamp_y
stamp_method = self.Unprompted.parse_advanced(kwargs["stamp_method"],context) if "stamp_method" in kwargs else "stretch"
for stamp in stamps:
# Checks for file in images/stamps, otherwise assumes absolute path
stamp_path = f"{self.Unprompted.base_dir}/images/stamps/{stamp}.png"
if not os.path.exists(stamp_path): stamp_path = stamp
if not os.path.exists(stamp_path):
self.Unprompted.log(f"Stamp not found: {stamp_path}",context="ERROR")
continue
stamp_img = Image.open(stamp_path).convert("RGBA")
if stamp_method == "stretch":
stamp_img = stamp_img.resize((self.init_image.size[0],self.init_image.size[1]))
elif stamp_method == "center":
stamp_x = stamp_x_orig + int((mask_width - stamp_img.size[0]) / 2)
stamp_y = stamp_y_orig + int((mask_height - stamp_img.size[1]) / 2)
stamp_blur = int(float(self.Unprompted.parse_advanced(kwargs["stamp_blur"],context))) if "stamp_blur" in kwargs else 0
if stamp_blur:
from PIL import ImageFilter
blur = ImageFilter.GaussianBlur(stamp_blur)
stamp_img = stamp_img.filter(blur)
self.init_image.paste(stamp_img,(stamp_x,stamp_y),stamp_img)
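The `stamp_method == "center"` branch above is plain offset arithmetic; a minimal sketch of that positioning logic (sizes below are hypothetical):

```python
def center_offset(canvas_w, canvas_h, stamp_w, stamp_h, extra_x=0, extra_y=0):
    # Mirrors the "center" branch: place the stamp in the middle of the
    # canvas, then apply any user-supplied stamp_x/stamp_y offsets.
    return (extra_x + (canvas_w - stamp_w) // 2,
            extra_y + (canvas_h - stamp_h) // 2)
```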
brush_mask_mode = self.Unprompted.parse_advanced(kwargs["mode"],context) if "mode" in kwargs else "add"
self.show = True if "show" in pargs else False
box_thresh = float(self.Unprompted.parse_advanced(kwargs["box_threshold"],context)) if "box_threshold" in kwargs else 0.3
text_thresh = float(self.Unprompted.parse_advanced(kwargs["text_threshold"],context)) if "text_threshold" in kwargs else 0.25
self.legacy_weights = True if "legacy_weights" in pargs else False
smoothing = int(self.Unprompted.parse_advanced(kwargs["smoothing"],context)) if "smoothing" in kwargs else 20
smoothing_kernel = None
@@ -103,15 +146,16 @@ class Shortcode():
for i, mask in enumerate(masks):
filename = f"mask_{mode}_{i}.png"
if method == "clipseg":
plt.imsave(filename,torch.sigmoid(mask[0]))
img = cv2.imread(filename)
# TODO: Figure out how to convert the plot above to numpy instead of re-loading image
else:
plt.imsave(filename,mask)
img = cv2.imread(filename)
img = cv2.resize(img,(mask_width,mask_height))
if padding_dilation_kernel is not None:
@@ -119,7 +163,11 @@
else: img = cv2.erode(img,padding_dilation_kernel,iterations=1)
if smoothing_kernel is not None: img = cv2.filter2D(img,-1,smoothing_kernel)
#if method == "clip_surgery":
#gray_image = cv2.cvtColor(cv2.cvtColor(img, cv2.COLOR_BGR2LUV), cv2.COLOR_BGR2GRAY)
#else: gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
(thresh, bw_image) = cv2.threshold(gray_image, mask_precision, 255, cv2.THRESH_BINARY)
if (mode == "discard"): bw_image = numpy.invert(bw_image)
@@ -132,9 +180,140 @@
return(final_img)
def get_mask():
preds = []
negative_preds = []
image_pil = flatten(self.init_image, opts.img2img_background_color)
if method == "sam":
if method == "clip_surgery":
from lib_unprompted import clip_surgery as clip
import cv2
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
from segment_anything import sam_model_registry, SamPredictor
# default imagenet redundant features
redundants = ['a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.', 'a photo of the hard to see {}.', 'a low resolution photo of the {}.', 'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.', 'a photo of a hard to see {}.', 'a bright photo of a {}.', 'a photo of a clean {}.', 'a photo of a dirty {}.', 'a dark photo of the {}.', 'a drawing of a {}.', 'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a painting of the {}.', 'a painting of a {}.', 'a pixelated photo of the {}.', 'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a rendering of the {}.', 'a {} in a video game.', 'a photo of one {}.', 'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.', 'a doodle of the {}.', 'a origami {}.', 'a low resolution photo of a {}.', 'the toy {}.', 'a rendition of the {}.', 'a photo of the clean {}.', 'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.', 'a photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.', 'art of a {}.', 'a sketch of the {}.', 'a embroidered {}.', 'a pixelated photo of a {}.', 'itap of the {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.', 'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.', 'art of the {}.', 'a drawing of the {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'the plushie {}.', 'a dark photo of a {}.', 'itap of a {}.', 'graffiti of the {}.', 'a toy {}.', 'itap of my {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
if "redundant_features" in kwargs: redundants.extend(kwargs["redundant_features"].split(self.Unprompted.Config.syntax.delimiter))
self.bypass_sam = True if "bypass_sam" in pargs else False
### Init CLIP and data
if self.cached_model == -1 or self.cached_model_method != method:
model, preprocess = clip.load("CS-ViT-B/16", device=device)
model.eval()
# Cache for future runs
self.cached_model = model
self.cached_transform = preprocess
else:
self.Unprompted.log("Using cached model(s) for CLIP_Surgery method")
model = self.cached_model
preprocess = self.cached_transform
image = preprocess(image_pil).unsqueeze(0).to(device)
cv2_img = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
### CLIP Surgery for a single text, without fixed label sets
with torch.no_grad():
# CLIP architecture surgery acts on the image encoder
image_features = model.encode_image(image)
image_features = image_features / image_features.norm(dim=1, keepdim=True)
# Prompt ensemble for text features with normalization
text_features = clip.encode_text_with_prompt_ensemble(model, prompts, device)
if (negative_prompts):
negative_text_features = clip.encode_text_with_prompt_ensemble(model, negative_prompts, device)
# Extract redundant features from an empty string
redundant_features = clip.encode_text_with_prompt_ensemble(model, [""], device, redundants)
# no sam
if self.bypass_sam:
def reg_inference(text_features):
preds = []
# Apply feature surgery for single text
similarity = clip.clip_feature_surgery(image_features, text_features, redundant_features)
similarity_map = clip.get_similarity_map(similarity[:, 1:, :], cv2_img.shape[:2])
# Draw similarity map
for b in range(similarity_map.shape[0]):
for n in range(similarity_map.shape[-1]):
vis = (similarity_map[b, :, :, n].cpu().numpy() * 255).astype('uint8')
preds.append(vis)
return(preds)
preds = reg_inference(text_features)
if (negative_prompts): negative_preds = reg_inference(negative_text_features)
else:
point_thresh = float(self.Unprompted.parse_advanced(kwargs["point_threshold"],context)) if "point_threshold" in kwargs else 0.98
multimask_output = True if "multimask_output" in pargs else False
# Init SAM
if self.cached_predictor == -1 or self.cached_model_method != method:
sam_model_dir = f"{self.Unprompted.base_dir}/models/segment_anything"
os.makedirs(sam_model_dir, exist_ok=True)
sam_filename = "sam_vit_h_4b8939.pth"
sam_file = f"{sam_model_dir}/{sam_filename}"
# Download model weights if we don't have them yet
if not os.path.exists(sam_file):
print("Downloading SAM model weights...")
self.Unprompted.download_file(sam_file,f"https://dl.fbaipublicfiles.com/segment_anything/{sam_filename}")
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_file)
sam.to(device=device)
predictor = SamPredictor(sam)
self.cached_predictor = predictor
else:
predictor = self.cached_predictor
predictor.set_image(np.array(image_pil))
self.cached_model_method = method
def sam_inference(text_features):
preds = []
# Combine features after removing redundant features and min-max norm
sm = clip.clip_feature_surgery(image_features, text_features, redundant_features)[0, 1:, :]
sm_norm = (sm - sm.min(0, keepdim=True)[0]) / (sm.max(0, keepdim=True)[0] - sm.min(0, keepdim=True)[0])
sm_mean = sm_norm.mean(-1, keepdim=True)
# get positive points from individual maps, and negative points from the mean map
p, l = clip.similarity_map_to_points(sm_mean, cv2_img.shape[:2], t=point_thresh)
num = len(p) // 2
points = p[num:] # negatives in the second half
labels = [l[num:]]
for i in range(sm.shape[-1]):
p, l = clip.similarity_map_to_points(sm[:, i], cv2_img.shape[:2], t=point_thresh)
num = len(p) // 2
points = points + p[:num] # positive in first half
labels.append(l[:num])
labels = np.concatenate(labels, 0)
# Inference SAM with points from CLIP Surgery
masks, scores, logits = predictor.predict(point_labels=labels, point_coords=np.array(points), multimask_output=multimask_output)
mask = masks[np.argmax(scores)]
mask = mask.astype('uint8')
vis = cv2_img.copy()
vis[mask > 0] = np.array([255, 255, 255], dtype=np.uint8)
vis[mask == 0] = np.array([0, 0, 0], dtype=np.uint8)
preds.append(vis)
if self.show:
for idx,mask in enumerate(masks):
plt.imsave(f"mask{idx}.png",mask)
return(preds)
preds = sam_inference(text_features)
if negative_prompts: negative_preds = sam_inference(negative_text_features)
elif method == "grounded_sam":
box_thresh = float(self.Unprompted.parse_advanced(kwargs["box_threshold"],context)) if "box_threshold" in kwargs else 0.3
text_thresh = float(self.Unprompted.parse_advanced(kwargs["text_threshold"],context)) if "text_threshold" in kwargs else 0.25
# Grounding DINO
import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
@@ -205,7 +384,7 @@ class Shortcode():
model_config_path = f"{self.Unprompted.base_dir}/lib_unprompted/groundingdino/config/GroundingDINO_SwinT_OGC.py"
# load model
if self.cached_model == -1:
if self.cached_model == -1 or self.cached_model_method != method:
args = SLConfig.fromfile(model_config_path)
args.device = device
model = build_model(args)
@@ -223,6 +402,7 @@ class Shortcode():
)
self.cached_model = model
self.cached_transform = transform
self.cached_model_method = method
else:
self.Unprompted.log("Using cached GroundingDINO model.")
@@ -248,10 +428,13 @@
preds = []
value = 0
mask_img = torch.zeros(masks.shape[-2:])
mask_img = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
for idx, mask in enumerate(masks):
mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
preds.append(mask_img.numpy())
# mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
mask_img[mask.cpu().numpy()[0] >= 1] = np.array([255, 255, 255], dtype=np.uint8)
mask_img[mask.cpu().numpy()[0] < 1] = np.array([0, 0, 0], dtype=np.uint8)
# TODO: Figure out if we can take advantage of individual mask layers rather than stacking as composite
preds.append(mask_img)
return(preds)
@@ -293,7 +476,7 @@ class Shortcode():
# load model
if self.cached_model == -1:
if self.cached_model == -1 or self.cached_model_method != method:
self.Unprompted.log("Loading clipseg model...")
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=not self.legacy_weights)
@@ -310,6 +493,7 @@ class Shortcode():
# Cache for future runs
self.cached_model = model
self.cached_transform = transform
self.cached_model_method = method
else:
self.Unprompted.log("Using cached clipseg model.")
model = self.cached_model
@@ -319,11 +503,17 @@
# predict
with torch.no_grad():
preds = model(img.repeat(prompt_parts,1,1,1).to(device=device), prompts)[0].cpu()
if "image_prompt" in kwargs:
from PIL import Image
# NOTE: assumes the image_prompt kwarg holds a path to the reference image
img_mask = flatten(Image.open(self.Unprompted.parse_alt_tags(kwargs["image_prompt"],context)), opts.img2img_background_color)
img_mask = transform(img_mask).unsqueeze(0)
preds = model(img.to(device=device), img_mask.to(device=device))[0].cpu()
else:
preds = model(img.repeat(prompt_parts,1,1,1).to(device=device), prompts)[0].cpu()
if (negative_prompts): negative_preds = model(img.repeat(negative_prompt_parts,1,1,1).to(device=device), negative_prompts)[0].cpu()
# All of the below logic applies to both clipseg and sam
if "image_mask" not in self.Unprompted.shortcode_user_vars: self.Unprompted.shortcode_user_vars["image_mask"] = None
if (brush_mask_mode == "add" and self.Unprompted.shortcode_user_vars["image_mask"] is not None):
@@ -411,6 +601,8 @@
if "unload_model" in pargs:
self.model = -1
self.cached_model = -1
self.cached_model_method = ""
self.cached_predictor = -1
return final_img
@@ -450,7 +642,7 @@
def ui(self,gr):
gr.Radio(label="Mask blend mode 🡢 mode",choices=["add","subtract","discard"],value="add",interactive=True)
gr.Radio(label="Masking tech method 🡢 method",choices=["sam","clipseg"],value="sam",interactive=True)
gr.Radio(label="Masking tech method 🡢 method",choices=["clipseg","clip_surgery","grounded_sam"],value="clipseg",interactive=True)
gr.Checkbox(label="Show mask in output 🡢 show")
gr.Checkbox(label="Use clipseg legacy weights 🡢 legacy_weights")
gr.Number(label="Precision of selected area 🡢 precision",value=100,interactive=True)


@@ -55,7 +55,7 @@ class Shortcode():
debug = True if "debug" in pargs else False
show_original = True if "show_original" in pargs else False
color_correct_method = self.Unprompted.parse_alt_tags(kwargs["color_correct_method"],context) if "color_correct_method" in kwargs else "mkl"
color_correct_method = self.Unprompted.parse_alt_tags(kwargs["color_correct_method"],context) if "color_correct_method" in kwargs else "none"
color_correct_timing = self.Unprompted.parse_alt_tags(kwargs["color_correct_timing"],context) if "color_correct_timing" in kwargs else "pre"
color_correct_strength = int(float(self.Unprompted.parse_advanced(kwargs["color_correct_strength"],context))) if "color_correct_strength" in kwargs else 1
manual_mask_mode = self.Unprompted.parse_alt_tags(kwargs["mode"],context) if "mode" in kwargs else "subtract"
@@ -108,7 +108,7 @@
cfg_min = float(self.Unprompted.parse_advanced(kwargs["cfg_scale_min"],context)) if "cfg_scale_min" in kwargs else 7.0
target_size_max = float(self.Unprompted.parse_advanced(kwargs["mask_size_max"],context)) if "mask_size_max" in kwargs else 0.5
target_size_max_orig = target_size_max
cfg_max = self.Unprompted.p_copy.cfg_scale - cfg_min
cfg_max = max(cfg_min,self.Unprompted.p_copy.cfg_scale - cfg_min)
padding_original = int(float(self.Unprompted.parse_advanced(kwargs["contour_padding"],context))) if "contour_padding" in kwargs else 0
min_area = int(float(self.Unprompted.parse_advanced(kwargs["min_area"],context))) if "min_area" in kwargs else 50
@@ -141,7 +141,7 @@
target_size_max = target_size_max_orig * target_multiplier
sd_unit = 64
denoise_unit = (denoising_max / 2) * 0.125
denoise_unit = (denoising_max / 4) * 0.125
cfg_min_unit = (cfg_min / 2) * 0.125
cfg_max_unit = (cfg_max / 2) * 0.125
step_unit = math.ceil(self.Unprompted.p_copy.steps * 0.125)
@@ -154,6 +154,7 @@
cfg_min += cfg_min_unit
cfg_max += cfg_max_unit
sharpen_amount += 0.125
denoising_max += denoise_unit
self.Unprompted.p_copy.steps += step_unit
upscale_width = min(hires_size_max,upscale_width)
@@ -170,6 +171,7 @@
if "include_original" in pargs:
append_originals.append(image_pil.copy())
if "mask_method" in kwargs: set_kwargs["method"] = kwargs["mask_method"]
set_kwargs["txt2mask_init_image"] = image_pil
mask_image = self.Unprompted.shortcode_objects["txt2mask"].run_block(set_pargs,set_kwargs,None,target_mask)
@@ -257,7 +259,7 @@
self.Unprompted.log(f"Denoising strength is {self.Unprompted.p_copy.denoising_strength}")
if "cfg_scale" not in kwargs:
self.Unprompted.p_copy.cfg_scale = cfg_min + sig * cfg_max
self.Unprompted.log(f"CFG Scale is {self.Unprompted.shortcode_user_vars['cfg_scale']} (min {cfg_min}, max {cfg_min+cfg_max})")
self.Unprompted.log(f"CFG Scale is {self.Unprompted.p_copy.cfg_scale} (min {cfg_min}, max {cfg_min+cfg_max})")
else:
self.Unprompted.log("Humongous target detected. Skipping zoom_enhance...")
continue
@@ -312,16 +314,18 @@
# run img2img now to improve face
if is_img2img:
fixed_image = process_images_inner_(self.Unprompted.p_copy)
fixed_image = fixed_image.images[0]
else:
#workaround for txt2img
# workaround for txt2img, not sure if compatible with controlnet
for att in dir(self.Unprompted.p_copy):
if not att.startswith("__") and att != "sd_model":
self.Unprompted.shortcode_user_vars[att] = getattr(self.Unprompted.p_copy,att)
fixed_image = self.Unprompted.shortcode_objects["img2img"].run_atomic(set_pargs,None,None)
if debug: fixed_image.save("zoom_enhance_4after.png")
if color_correct_method != "none" and starting_image:
try:
@@ -352,7 +356,7 @@ class Shortcode():
current_mask = current_mask.resize((width,height))
if debug: current_mask.save("zoom_enhance_5d_current_main_mask.png")
image_pil.paste(corrected_main_img,(0,0),current_mask)
image_pil.save("zoom_enhance_5e_corrected_main_image.png")
if debug: image_pil.save("zoom_enhance_5e_corrected_main_image.png")
except Exception as e:
self.Unprompted.log(f"{e}",context="ERROR")
@@ -369,13 +373,17 @@ class Shortcode():
# Slap our new image back onto the original
image_pil.paste(fixed_image, (x1 - padding, y1 - padding),sub_mask)
# self.Unprompted.shortcode_user_vars["init_images"].append(image_pil)
if show_original: append_originals.append(image_pil.copy())
else: self.Unprompted.after_processed.images[image_idx] = image_pil
# test outside after block, WIP pls don't use yet
self.Unprompted.log(f"Adding zoom_enhance result for image_idx {image_idx}")
if context != "after":
self.Unprompted.log("Attempting to use zoom_enhance outside of an after block... good luck")
self.Unprompted.shortcode_user_vars["init_images"] = image_pil
# main return
else:
try:
if show_original: append_originals.append(image_pil.copy())
else: self.Unprompted.after_processed.images[image_idx] = image_pil
except Exception as e:
self.Unprompted.log(f"Could not append zoom_enhance result: {e}",context="ERROR")
# Remove temp image key in case [img2img] is used later
if "img2img_init_image" in self.Unprompted.shortcode_user_vars: del self.Unprompted.shortcode_user_vars["img2img_init_image"]
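For readers skimming the hunk above: the corrected log line reports a CFG scale computed as `cfg_min + sig * cfg_max`, where `sig` is a 0-to-1 weight derived from the size of the detected region. A minimal, hypothetical sketch of that interpolation pattern — the logistic curve and its steepness are assumptions for illustration, not taken from the extension:

```python
import math

def interpolate_cfg(cfg_min: float, cfg_max: float, size_ratio: float) -> float:
    """Map a 0..1 region-size ratio onto a CFG scale.

    Mirrors the pattern in the diff: cfg = cfg_min + sig * cfg_max,
    so the result sweeps from cfg_min up to cfg_min + cfg_max.
    """
    # Logistic weight centered at 0.5; the steepness of 10 is illustrative.
    sig = 1.0 / (1.0 + math.exp(-10.0 * (size_ratio - 0.5)))
    return cfg_min + sig * cfg_max

# A tiny region stays near cfg_min; a large one approaches cfg_min + cfg_max.
assert interpolate_cfg(7.0, 8.0, 0.0) < interpolate_cfg(7.0, 8.0, 1.0)
```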


@@ -1,4 +1,4 @@
[template name="Bodysnatcher v1.0.0"]
[template name="Bodysnatcher v1.1.0"]
![Preview]([base_dir]/bodysnatcher.png)
## ⚠️ Important info, please read carefully:
@@ -40,32 +40,35 @@ Always bodysnatch responsibly.
[set prefix _new _label="Prefix" _info="For example, the visual medium"]photo of[/set]
[set subject _new _label="New subject"]mona lisa[/set]
[set simple_description _new _label="Simple Description" _info="These terms will apply to both the full image and the cropped face, less is more"][/set]
[set simple_description _new _label="Simple Description" _info="These terms will apply to both the full image and the cropped face, 1-3 words are usually plenty"][/set]
[set class _new _label="Class" _info="The search term that determines the inpainting mask"]woman[/set]
[set background_mode _new _label="Background Mode" _ui="checkbox" _info="Inverts the class mask and disables the zoom_enhance step (note: you'll probably want to increase the mask precision)"]0[/set]
[set keep_hands _new _label="Keep original hands" _ui="checkbox" _info="You don't really want Stable Diffusion to remake those hands, do you?"]1[/set]
[set keep_feet _new _label="Keep original feet" _ui="checkbox"]1[/set]
[set use_optimized_inference_settings _new _label="Use optimized inference settings" _ui="checkbox" _info="Locks CFG scale, denoising strength, etc. to recommended values"]1[/set]
[set use_controlnet_preset _new _info="Loads multiple ControlNet units, make sure you have 'Allow other scripts to control this extension' enabled" _label="ControlNet preset" _ui="dropdown" _choices="none|photo_general_v1|dev"]none[/set]
[set use_controlnet_preset _new _info="Loads multiple ControlNet units, please make sure you have 'Allow other scripts to control this extension' enabled (note: the 'dev' preset is for internal testing)" _label="ControlNet preset" _ui="dropdown" _choices="none|photo_general_v1|dev"]none[/set]
[wizard_ui_accordion _label="⚙️ Advanced Options"]
{set fix_bodypart _new _label="Fix a body part"}face{/set}
{set fix_bodypart _new _label="Fix a body part" _info="Note: currently not compatible with Background Mode"}face{/set}
{set color_correct_method _new _label="Color correct method" _ui="dropdown" _choices="none|hm|mvgd|mkl|hm-mvgd-hm|hm-mkl-hm"}hm-mkl-hm{/set}
{set color_correct_timing _new _label="Color correct timing" _info="Post may produce more accurate colors, but it tends to look a bit posterized" _ui="dropdown" _choices="pre|post"}pre{/set}
{set color_correct_strength _new _label="Color correct strength" _ui="slider" _minimum=1 _maximum=5}1{/set}
{set mask_method _new _label="Masking method (sam requires manual setup)" _ui="radio" _choices="clipseg|sam"}clipseg{/set}
{set mask_method _new _label="Masking method" _ui="radio" _choices="clipseg|clip_surgery|grounded_sam"}clipseg{/set}
{set manual_mask_mode _new _label="Manual masking mode" _ui="radio" _choices="add|subtract|discard"}subtract{/set}
{set mask_precision _new _label="Mask precision"}75{/set}
{set stamp _new _label="Stamp" _info="Paste a temporary image on the init image for the purpose of masking (check unprompted/images/stamps for default stamps)"}{/set}
{set zoom_enhance_denoising_max _new}0.30{/set}
{set zoom_enhance_base_cfg _new _ui="slider" _minimum="1" _maximum="30"}10{/set}
{set zoom_enhance_base_cfg _new _ui="slider" _minimum="1" _maximum="30"}7{/set}
{set show_original _new _label="Show unenhanced image in output window" _ui="checkbox"}0{/set}
{set debug _new _label="Save debug images" _ui="checkbox"}0{/set}
[/wizard_ui_accordion]
[sets neg_mask=""]
[if keep_hands=1]{set neg_mask}fingers{/set}[/if]
[if keep_feet=1]{set neg_mask _append}|feet{/set}[/if]
[if "(keep_hands==1 and background_mode==0) or (keep_hands==0 and background_mode==1)"]{set neg_mask}fingers{/set}[/if]
[if "(keep_feet==1 and background_mode==0) or (keep_feet==0 and background_mode==1)"]{set neg_mask _append}|feet{/set}[/if]
[if use_optimized_inference_settings=1]
{sets cfg_scale=7.5 sampler_name="Euler a" steps=25 denoising_strength=0.75 mask_blur=10}
@@ -75,8 +78,8 @@ Always bodysnatch responsibly.
{{sets controlnet_0_enabled=1 controlnet_0_module=softedge_hed controlnet_0_model=controlnet11Models_softedge controlnet_0_weight=0.25 controlnet_1_enabled=1 controlnet_1_module=mediapipe_face controlnet_1_model=control_mediapipe_face_sd15_v2 controlnet_1_weight=1.0 controlnet_2_enabled=1 controlnet_2_enabled=1 controlnet_2_module=openpose_full controlnet_2_model=controlnet11Models_openpose}}
{/case}
{case "dev"}
{{sets controlnet_0_enabled=1 controlnet_0_module=softedge_hed controlnet_0_model=controlnet11Models_softedge controlnet_0_weight=0.5 controlnet_1_enabled=1 controlnet_1_module=mediapipe_face controlnet_1_model=control_mediapipe_face_sd15_v2 controlnet_1_weight=1.0 controlnet_2_enabled=1 controlnet_2_module=t2ia_color_grid controlnet_2_model=coadapter-color-sd15v1 controlnet_2_weight=1.0 controlnet_3_enabled=1 controlnet_3_module=openpose_full controlnet_3_model=controlnet11Models_openpose controlnet_3_weight=1.0}}
{{sets controlnet_0_enabled=1 controlnet_0_module=softedge_hed controlnet_0_model=controlnet11Models_softedge controlnet_0_weight=0.25 controlnet_1_enabled=1 controlnet_1_module=mediapipe_face controlnet_1_model=control_mediapipe_face_sd15_v2 controlnet_1_weight=1.0 controlnet_2_enabled=1 controlnet_2_module=t2ia_color_grid controlnet_2_model=coadapter-color-sd15v1 controlnet_2_weight=1.0 controlnet_3_enabled=1 controlnet_3_module=openpose_full controlnet_3_model=controlnet11Models_openpose controlnet_3_weight=1.0}}
{/case}
[/switch]
[img2img_autosize][txt2mask precision="{get mask_precision}" method="{get mask_method}" mode="{get manual_mask_mode}" negative_mask="{get neg_mask}" padding=10 mask_blur=20][get class][/txt2mask][after]{zoom_enhance color_correct_method="[get color_correct_method]" color_correct_timing="[get color_correct_timing]" color_correct_strength="[get color_correct_strength]" [if show_original=1]show_original[/if] sharpen_amount=0.0 mode="subtract" [if debug=1]debug[/if] mask="[get class] [get fix_bodypart]" replacement="[get prefix] [get subject] [get fix_bodypart] [get simple_description _before=' ']" cfg_scale="[get zoom_enhance_base_cfg]" denoising_max="[get zoom_enhance_denoising_max]"}[/after][get prefix] [get subject][get simple_description _before=" "]
[img2img_autosize][if batch_index=0]{txt2mask precision="{{get mask_precision}}" method="{{get mask_method}}" mode="{{get manual_mask_mode}}" negative_mask="{{get neg_mask}}" padding=10 mask_blur=20}{get class}{/txt2mask}{if background_mode=1}{{invert_mask}}{/if}[/if][if "background_mode==0 and batch_index==0"]{after}{{zoom_enhance mask_method="[get mask_method]" color_correct_method="[get color_correct_method]" color_correct_timing="[get color_correct_timing]" color_correct_strength="[get color_correct_strength]" [if show_original=1]show_original[/if] sharpen_amount=0.0 mode="subtract" [if debug=1]debug[/if] mask="[get class] [get fix_bodypart]" replacement="[get prefix] [get subject] [get fix_bodypart] [get simple_description _before=' ']" cfg_scale_min="[get zoom_enhance_base_cfg]" cfg_scale_max="15" denoising_max="[get zoom_enhance_denoising_max]"}}{/after}[/if][get prefix] [get subject][get simple_description _before=" "]
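The reworked `keep_hands`/`keep_feet` conditions in this template amount to an XOR with `background_mode`: once the class mask is inverted, each keep-flag must flip too. A standalone Python sketch of that logic — the function name and exact string handling are illustrative, though the template likewise joins terms with `|`:

```python
def build_negative_mask(keep_hands: bool, keep_feet: bool, background_mode: bool) -> str:
    """Collect negative-mask terms, flipping each keep-flag when
    background_mode inverts the class mask (an XOR)."""
    terms = []
    if keep_hands != background_mode:  # keep hands XOR inverted mask
        terms.append("fingers")
    if keep_feet != background_mode:
        terms.append("feet")
    return "|".join(terms)

assert build_negative_mask(True, True, False) == "fingers|feet"
assert build_negative_mask(True, True, True) == ""  # inverted: nothing excluded
```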