Merge branch 'dev' into temp

pull/2427/head
Vladimir Mandic 2023-10-31 08:48:52 -04:00 committed by GitHub
commit 0d7807acd3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
310 changed files with 43425 additions and 421 deletions

2
.gitignore vendored
View File

@ -43,7 +43,6 @@ cache
!package.json
# all dynamic stuff
/repositories/**/*
/extensions/**/*
/outputs/**/*
/embeddings/**/*
@ -60,6 +59,5 @@ cache
/localizations
# unexcluded so folders get created
!/repositories/.placeholder
!/models/VAE-approx
!/models/VAE-approx/model.pt

4
.gitmodules vendored
View File

@ -32,3 +32,7 @@
path = extensions-builtin/sd-extension-chainner
url = https://github.com/vladmandic/sd-extension-chainner
ignore = dirty
[submodule "modules/k-diffusion"]
path = modules/k-diffusion
url = https://github.com/crowsonkb/k-diffusion
ignore = dirty

View File

@ -151,6 +151,7 @@ disable=bad-inline-option,
missing-function-docstring,
missing-module-docstring,
no-else-return,
not-callable,
pointless-string-statement,
raw-checker-failed,
simplifiable-if-expression,

View File

@ -1,67 +1,79 @@
# Change Log for SD.Next
## Update for 2023-10-25
## Update for 2023-10-30
*Note*: Pending release of `diffusers==0.22.0`
Mostly service release with support for several new models and additional optimizations...
Another pretty big release, this time with focus on
new models, new backends and optimizations and tons of fixes
Also, [Wiki](https://github.com/vladmandic/automatic/wiki) has been updated with new content, so check it out!
Some highlights: [OpenVINO](https://github.com/vladmandic/automatic/wiki/OpenVINO), [IntelArc](https://github.com/vladmandic/automatic/wiki/Intel-ARC), [DirectML](https://github.com/vladmandic/automatic/wiki/DirectML), [ONNX/Olive>](https://github.com/vladmandic/automatic/wiki/ONNX-Runtime)
- **Diffusers**
- new model type: [SegMind SSD-1B](https://huggingface.co/segmind/SSD-1B)
its a distilled model, this time 50% smaller and faster version of SD-XL!
- new model type: [Segmind SSD-1B](https://huggingface.co/segmind/SSD-1B)
its a *distilled* model, this time 50% smaller and faster version of SD-XL!
(and quality does not suffer, its just more optimized)
test shows batch-size:4 with 1k images used less than 6.5GB of VRAM
download using built-in **Huggingface** downloader: `segmind/SSD-1B`
- new model type: [LCM: Latent Consistency Models](https://github.com/openai/consistency_models)
near-instant generate in a as little as 3 steps!
combined with OpenVINO, generate on CPU takes less than 10 seconds: <https://www.youtube.com/watch?v=b90ESUTLsRo>
download using built-in **Huggingface** downloader: `SimianLuo/LCM_Dreamshaper_v7`
- support for **Custom pipelines**, thanks @disty0
download using built-in **Huggingface** downloader
think of them as plugins for diffusers not unlike original extensions that modify behavior of `ldm` backend
list of community pipelines: <https://github.com/huggingface/diffusers/tree/main/examples/community>
and make sure to check our reference one: `Disty0/zero123plus-pipeline`
which generates 4 output images with different camera positions: front, side, top, back!
list of community pipelines: <https://github.com/huggingface/diffusers/blob/main/examples/community/README.md>
- new custom pipeline: `Disty0/zero123plus-pipeline`
generate 4 output images with different camera positions: front, side, top, back!
for more details, see <https://github.com/vladmandic/automatic/discussions/2421>
- new backend: **ONNX/Olive** (experimental)
for details, see WiKi
- extend support for [Free-U](https://github.com/ChenyangSi/FreeU)
improve generations quality at no cost (other than finding params that work for you)
- **General**
- add **Lora OFT** support, thanks @antis0007 and @ai-casanova
- **Upscalers**
- **compile compile** option, thanks @disty0
- **compile** option, thanks @disty0
- **chaiNNer** add high quality models from [Helaman](https://openmodeldb.info/users/helaman)
- redesigned **progress bar** with full details on current operation
- redesigned **Progress bar** with full details on current operation
- **Extra networks** sort by name, size, date, etc.
- new option: *settings -> images -> keep incomplete*
can be used to skip vae decode on aborted/skipped/interrupted image generations
- new option: *settings -> system paths -> models*
can be used to set custom base path for *all* models (previously only as cli option)
- remove external clone of items in `/repositories`
- switch core font in default theme to **noto-sans**
previously default font was simply *system-ui*, but it lead to too much variations between browsers and platforms
- **Fixes**
- fix **freeu** for backend original and add it to xyz grid
- fix loading diffuser models in huggingface format from non-standard location
- fix default styles looking in wrong location
- fix missing upscaler folder on initial startup
- fix handling of relative path for models
- fix simple live preview device mismatch
- fix batch img2img
- fix diffusers samplers: dpm++ 2m, dpm++ 1s, deis
- fix new style filename template
- fix image name template using model name
- fix image name sequence
- fix model path using relative path
- fix `torch-rocm` and `tensorflow-rocm` version detection, thanks @xangelix
- fix **chainner** upscalers color clipping
- **Fixes**
- fix **freeu** for backend original and add it to xyz grid
- fix loading diffuser models in huggingface format from non-standard location
- fix default styles looking in wrong location
- fix missing upscaler folder on initial startup
- fix handling of relative path for models
- fix simple live preview device mismatch
- fix batch img2img
- fix diffusers samplers: dpm++ 2m, dpm++ 1s, deis
- fix new style filename template
- fix image name template using model name
- fix image name sequence
- fix model path using relative path
- fix `torch-rocm` and `tensorflow-rocm` version detection, thanks @xangelix
- fix **chainner** upscalers color clipping
- fix for base+refiner workflow in diffusers mode: number of steps, diffuser pipe mode
- fix for prompt encoder with refiner in diffusers mode
- fix prompts-from-file saving incorrect metadata
- fix before-hires step
- fix diffusers switch from invalid model
- **directml** and **ipex** updates
- force second requirements check on startup
- remove lyco, multiple_tqdm
- fix for prompt encoder with refiner in diffusers mode
- fix prompts-from-file saving incorrect metadata
- fix before-hires step
- fix diffusers switch from invalid model
- **directml** and **ipex** updates
- force second requirements check on startup
- remove **lyco**, multiple_tqdm
- enhance extension compatibility for exensions directly importing codeformers
- enhance extension compatibility for exensions directly accessing processing params
- css fixes
- clearly mark external themes in ui
- update `openvino`, thanks @disty0
- update `typing-extensions`
- **css** fixes
- clearly mark external themes in ui
- update `openvino`, thanks @disty0
- update `typing-extensions`
## Update for 2023-10-17

View File

@ -56,21 +56,23 @@ Additional models will be added as they become available and there is public int
- [Stable Diffusion](https://github.com/Stability-AI/stablediffusion/) 1.x and 2.x *(all variants)*
- [Stable Diffusion XL](https://github.com/Stability-AI/generative-models)
- [Kandinsky](https://github.com/ai-forever/Kandinsky-2) 2.1 and 2.2
- [DeepFloyd IF](https://github.com/deep-floyd/IF)
- [UniDiffusion](https://github.com/thu-ml/unidiffuser)
- [SD-Distilled](https://huggingface.co/blog/sd_distillation) *(all variants)*
- [Wuerstchen](https://huggingface.co/blog/wuertschen)
- [Segmind SSD-1B](https://huggingface.co/segmind/SSD-1B)
- [UniDiffusion](https://github.com/thu-ml/unidiffuser)
- [DeepFloyd IF](https://github.com/deep-floyd/IF)
## Platform support
- *nVidia* GPUs using **CUDA** libraries on both *Windows and Linux*
- *AMD* GPUs using **ROCm** libraries on *Linux*.
Support will be extended to *Windows* once AMD releases ROCm for Windows
- *Intel Arc* GPUs using **OneAPI** with *IPEX XPU* libraries on both *Windows and Linux*
- Any GPU compatible with *DirectX* on *Windows* using **DirectML** libraries.
This includes support for AMD GPUs that are not supported by native ROCm libraries
- Any GPU or device compatible with **OpenVINO** libraries on both *Windows and Linux*
- *Apple M1/M2* on *OSX* using built-in support in Torch with **MPS** optimizations
- *nVidia* GPUs using **CUDA** libraries on both *Windows and Linux*
- *AMD* GPUs using **ROCm** libraries on *Linux*
Support will be extended to *Windows* once AMD releases ROCm for Windows
- *Intel Arc* GPUs using **OneAPI** with *IPEX XPU* libraries on both *Windows and Linux*
- Any GPU compatible with *DirectX* on *Windows* using **DirectML** libraries
This includes support for AMD GPUs that are not supported by native ROCm libraries
- Any GPU or device compatible with **OpenVINO** libraries on both *Windows and Linux*
- *Apple M1/M2* on *OSX* using built-in support in Torch with **MPS** optimizations
- *ONNX/Olive* (experimental)
## Install & Run

View File

@ -0,0 +1,80 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
params:
embedding_dropout: 0.25
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 96
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn-adm
scale_factor: 0.18215
monitor: val/loss_simple_ema
use_ema: False
embedder_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
noise_aug_config:
target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
params:
timestep_dim: 1024
noise_schedule_config:
timesteps: 1000
beta_schedule: squaredcos_cap_v2
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
num_classes: "sequential"
adm_in_channels: 2048
use_checkpoint: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View File

@ -0,0 +1,83 @@
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
params:
embedding_dropout: 0.25
parameterization: "v"
linear_start: 0.00085
linear_end: 0.0120
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 96
channels: 4
cond_stage_trainable: false
conditioning_key: crossattn-adm
scale_factor: 0.18215
monitor: val/loss_simple_ema
use_ema: False
embedder_config:
target: ldm.modules.encoders.modules.ClipImageEmbedder
params:
model: "ViT-L/14"
noise_aug_config:
target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
params:
clip_stats_path: "checkpoints/karlo_models/ViT-L-14_stats.th"
timestep_dim: 768
noise_schedule_config:
timesteps: 1000
beta_schedule: squaredcos_cap_v2
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
num_classes: "sequential"
adm_in_channels: 1536
use_checkpoint: True
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View File

@ -0,0 +1,74 @@
model:
base_learning_rate: 5.0e-07
target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false
conditioning_key: hybrid
scale_factor: 0.18215
monitor: val/loss_simple_ema
finetune_keys: null
use_ema: False
depth_stage_config:
target: ldm.modules.midas.api.MiDaSInference
params:
model_type: "dpt_hybrid"
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
use_checkpoint: True
image_size: 32 # unused
in_channels: 5
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_head_channels: 64 # need to fix for flash-attn
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
#attn_type: "vanilla-xformers"
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"

View File

@ -112,11 +112,12 @@ class KeyConvert:
self.converter = self.diffusers
self.is_sdxl = True if shared.sd_model_type == "sdxl" else False
self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
self.LORA_PREFIX_UNET = "lora_unet"
self.LORA_PREFIX_TEXT_ENCODER = "lora_te"
self.LORA_PREFIX_UNET = "lora_unet_"
self.LORA_PREFIX_TEXT_ENCODER = "lora_te_"
self.OFT_PREFIX_UNET = "oft_unet_"
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1_"
self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2_"
def original(self, key):
key = convert_diffusers_name_to_compvis(key, self.is_sd2)
@ -142,13 +143,12 @@ class KeyConvert:
if self.is_sdxl:
map_keys = list(self.UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
map_keys.sort()
search_key = key.replace(self.LORA_PREFIX_UNET + "_", "").replace(self.LORA_PREFIX_TEXT_ENCODER1 + "_",
"").replace(
self.LORA_PREFIX_TEXT_ENCODER2 + "_", "")
search_key = key.replace(self.LORA_PREFIX_UNET, "").replace(self.OFT_PREFIX_UNET, "").replace(self.LORA_PREFIX_TEXT_ENCODER1, "").replace(self.LORA_PREFIX_TEXT_ENCODER2, "")
position = bisect.bisect_right(map_keys, search_key)
map_key = map_keys[position - 1]
if search_key.startswith(map_key):
key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key]) # pylint: disable=unsubscriptable-object
key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key]).replace("oft","lora") # pylint: disable=unsubscriptable-object
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
return key, sd_module

View File

@ -0,0 +1,49 @@
import torch
import diffusers.models.lora as diffusers_lora
import network
from modules import devices
class ModuleTypeOFT(network.ModuleType):
def create_module(self, net: network.Network, weights: network.NetworkWeights):
"""
weights.w.items()
alpha : tensor(0.0010, dtype=torch.bfloat16)
oft_blocks : tensor([[[ 0.0000e+00, 1.4400e-04, 1.7319e-03, ..., -8.8882e-04,
5.7373e-03, -4.4250e-03],
[-1.4400e-04, 0.0000e+00, 8.6594e-04, ..., 1.5945e-03,
-8.5449e-04, 1.9684e-03], ...etc...
, dtype=torch.bfloat16)"""
if "oft_blocks" in weights.w.keys():
module = NetworkModuleOFT(net, weights)
return module
else:
return None
class NetworkModuleOFT(network.NetworkModule):
def __init__(self, net: network.Network, weights: network.NetworkWeights):
super().__init__(net, weights)
self.weights = weights.w.get("oft_blocks").to(device=devices.device)
self.dim = self.weights.shape[0] # num blocks
self.alpha = self.multiplier()
self.block_size = self.weights.shape[-1]
def get_weight(self):
block_Q = self.weights - self.weights.transpose(1, 2)
I = torch.eye(self.block_size, device=devices.device).unsqueeze(0).repeat(self.dim, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
block_R_weighted = self.alpha * block_R + (1 - self.alpha) * I
R = torch.block_diag(*block_R_weighted)
return R
def calc_updown(self, orig_weight):
R = self.get_weight().to(device=devices.device, dtype=orig_weight.dtype)
if orig_weight.dim() == 4:
updown = torch.einsum("oihw, op -> pihw", orig_weight, R) * self.calc_scale()
else:
updown = torch.einsum("oi, op -> pi", orig_weight, R) * self.calc_scale()
return self.finalize_updown(updown, orig_weight, orig_weight.shape)

View File

@ -7,6 +7,7 @@ import network
import network_lora
import network_hada
import network_ia3
import network_oft
import network_lokr
import network_full
import network_norm
@ -32,6 +33,7 @@ module_types = [
network_lora.ModuleTypeLora(),
network_hada.ModuleTypeHada(),
network_ia3.ModuleTypeIa3(),
network_oft.ModuleTypeOFT(),
network_lokr.ModuleTypeLokr(),
network_full.ModuleTypeFull(),
network_norm.ModuleTypeNorm(),

View File

@ -53,6 +53,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": json.dumps(l.metadata, indent=4) if l.metadata else None,
"tags": tags,
"mtime": os.path.getmtime(l.filename),
"size": os.path.getsize(l.filename),
}
return item
except Exception as e:

@ -1 +1 @@
Subproject commit e382d1618593ae05a7115006de8680d3ddbd9777
Subproject commit f2aafcf2beb99a03cbdf7db73852228ccd6bd1d6

@ -1 +1 @@
Subproject commit 7fb6a263d66c4167f386bf6f66b12f187d44505e
Subproject commit 7f57729626503837a70ad9eed92313bc36db7bf3

View File

@ -16,6 +16,7 @@
{"id":"","label":"⇰","localized":"","hint":"Apply selected styles to current prompt"},
{"id":"","label":"⇩","localized":"","hint":"Save parameters from last generated image as style template"},
{"id":"","label":"🕮","localized":"","hint":"Save parameters from last generated image as style template"},
{"id":"","label":"⇕","localized":"","hint":"Sort by: Name asc/desc, Size largest/smallest, Time newest/oldest"},
{"id":"","label":"⟲","localized":"","hint":"Refresh"},
{"id":"","label":"✕","localized":"","hint":"Close"},
{"id":"","label":"⊜","localized":"","hint":"Fill"},

View File

@ -591,6 +591,7 @@ def install_packages():
# clone required repositories
def install_repositories():
"""
if args.profile:
pr = cProfile.Profile()
pr.enable()
@ -615,6 +616,7 @@ def install_repositories():
clone(blip_repo, d('BLIP'), blip_commit)
if args.profile:
print_profile(pr, 'Repositories')
"""
# run extension installer
@ -659,7 +661,7 @@ def install_extensions():
pkg_resources._initialize_master_working_set() # pylint: disable=protected-access
pkgs = [f'{p.project_name}=={p._version}' for p in pkg_resources.working_set] # pylint: disable=protected-access,not-an-iterable
log.debug(f'Installed packages: {len(pkgs)}')
from modules.paths_internal import extensions_builtin_dir, extensions_dir
from modules.paths import extensions_builtin_dir, extensions_dir
extensions_duplicates = []
extensions_enabled = []
extension_folders = [extensions_builtin_dir] if args.safe else [extensions_builtin_dir, extensions_dir]
@ -793,7 +795,7 @@ def set_environment():
def check_extensions():
newest_all = os.path.getmtime('requirements.txt')
from modules.paths_internal import extensions_builtin_dir, extensions_dir
from modules.paths import extensions_builtin_dir, extensions_dir
extension_folders = [extensions_builtin_dir] if args.safe else [extensions_builtin_dir, extensions_dir]
disabled_extensions_all = opts.get('disable_all_extensions', 'none')
if disabled_extensions_all != 'none':
@ -981,7 +983,7 @@ def extensions_preload(parser):
log.info('Running in safe mode without user extensions')
try:
from modules.script_loading import preload_extensions
from modules.paths_internal import extensions_builtin_dir, extensions_dir
from modules.paths import extensions_builtin_dir, extensions_dir
extension_folders = [extensions_builtin_dir] if args.safe else [extensions_builtin_dir, extensions_dir]
preload_time = {}
for ext_dir in extension_folders:

View File

@ -119,9 +119,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
.gradio-button.tool { filter: hue-rotate(180deg) saturate(0.5); }
#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }
/* custom elements overrides */
#steps-animation, #controlnet { border-width: 0; }
/* based on gradio built-in dark theme */
:root, .light, .dark {
--body-background-fill: #000000; /* Black */

View File

@ -14,7 +14,6 @@
width: 22em; min-height: 1.3em; font-size: 0.8em; transition: opacity 0.2s ease-in; pointer-events: none; opacity: 0; z-index: 999; }
.tooltip-show { opacity: 0.9; }
.toolbutton-selected { background: var(--background-fill-primary) !important; }
.jobStatus { position: fixed; bottom: 1em; right: 1em; background: var(--input-background-fill); padding: 0.4em; font-size: 0.8em; color: var(--body-text-color-subdued); }
/* live preview */
.progressDiv{ position: relative; height: 20px; background: #b4c0cc; margin-bottom: -3px; }
@ -94,7 +93,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
.extra-network-cards { display: flex; flex-wrap: wrap; overflow-y: auto; overflow-x: hidden; align-content: flex-start; width: -moz-available; width: -webkit-fill-available; }
.extra-network-cards .card { height: fit-content; margin: 0 0 0.5em 0.5em; position: relative; scroll-snap-align: start; scroll-margin-top: 0; }
.extra-network-cards .card .overlay { position: absolute; bottom: 0; padding: 0.2em; z-index: 10; width: 100%; background: none; }
.extra-network-cards .card .overlay .name { font-size: 1.1em; font-weight: bold; text-shadow: 1px 1px black; color: white; overflow-wrap: break-word; }
.extra-network-cards .card .overlay .name { text-shadow: 1px 1px black; color: white; overflow-wrap: break-word; }
.extra-network-cards .card .preview { box-shadow: var(--button-shadow); min-height: 30px; }
.extra-network-cards .card:hover .overlay { background: rgba(0, 0, 0, 0.40); }
.extra-network-cards .card:hover .preview { box-shadow: none; filter: grayscale(100%); }

View File

@ -1,7 +1,8 @@
/* generic html tags */
@font-face { font-family: 'NotoSans'; font-display: swap; font-style: normal; font-weight: 100; src: local('NotoSans'), url('notosans-nerdfont-regular.ttf') }
:root, .light, .dark {
--font: "Source Sans Pro", 'ui-sans-serif', 'system-ui', "Roboto", sans-serif;
--font-mono: 'IBM Plex Mono', 'ui-monospace', 'Consolas', monospace;
--font: 'NotoSans';
--font-mono: 'ui-monospace', 'Consolas', monospace;
--font-size: 16px;
--left-column: 490px;
--highlight-color: #ce6400;
@ -18,15 +19,28 @@
--primary-800: #9a3412;
--primary-900: #7c2d12;
--primary-950: #6c2e12;
}
.light, .dark {
--highlight-color: var(--primary-200);
--inactive-color: var(--primary--800);
--body-text-color: var(--neutral-100);
--body-text-color-subdued: var(--neutral-300);
--background-color: #000000;
--background-fill-primary: var(--neutral-700);
--input-padding: 4px;
--radius-lg: 2px;
--radius-sm: 1px;
--input-background-fill: var(--neutral-800);
--input-shadow: 2px 2px 2px 2px var(--background-color);
--button-secondary-text-color: white;
--button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-400), var(--neutral-700));
--button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-700), var(--neutral-400));
--block-title-text-color: var(--neutral-300);
--radius-sm: 2px;
--radius-lg: 4px;
--spacing-md: 4px;
--spacing-xxl: 12px;
--line-sm: 1.3em;
--line-md: 1.3em;
--spacing-xxl: 6px;
--line-sm: 1.2em;
--line-md: 1.4em;
--text-sm: 12px;
--text-md: 13px;
--text-lg: 15px;
}
html { font-size: var(--font-size); }
@ -119,9 +133,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
.gradio-button.tool { filter: hue-rotate(180deg) saturate(0.5); }
#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }
/* custom elements overrides */
#steps-animation, #controlnet { border-width: 0; }
/* based on gradio built-in dark theme */
:root, .light, .dark {
--body-background-fill: var(--background-color);
@ -244,9 +255,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
--radius-xxl: 0;
--text-xxs: 9px;
--text-xs: 10px;
--text-sm: 12px;
--text-md: 14px;
--text-lg: 16px;
--text-xl: 22px;
--text-xxl: 26px;
--body-text-size: var(--text-md);

View File

@ -1,6 +1,7 @@
/* generic html tags */
@font-face { font-family: 'NotoSans'; font-display: swap; font-style: normal; font-weight: 100; src: local('NotoSans'), url('notosans-nerdfont-regular.ttf') }
:root, .light, .dark {
--font: 'system-ui', 'ui-sans-serif', 'system-ui', "Roboto", sans-serif;
--font: 'NotoSans';
--font-mono: 'ui-monospace', 'Consolas', monospace;
--font-size: 16px;
--left-column: 490px;
@ -34,10 +35,13 @@
--spacing-xxl: 6px;
--line-sm: 1.2em;
--line-md: 1.4em;
--text-sm: 12px;
--text-md: 13px;
--text-lg: 15px;
}
html { font-size: var(--font-size); }
body, button, input, select, textarea { font-family: var(--font);}
html { font-size: var(--font-size); font-family: var(--font); }
body, button, input, select, textarea { font-family: var(--font); }
button { font-size: 1.2rem; max-width: 400px; }
img { background-color: var(--background-color); }
input[type=range] { height: var(--line-sm) !important; appearance: none !important; margin-top: 0 !important; min-width: 160px !important;
@ -84,7 +88,8 @@ svg.feather.feather-image, .feather .feather-image { display: none }
.tab-nav { zoom: 120%; margin-top: 10px; margin-bottom: 10px; border-bottom: 2px solid var(--highlight-color) !important; padding-bottom: 2px; }
.label-wrap { margin: 8px 0px 4px 0px; }
.gradio-button.tool { border: none; background: none; box-shadow: none; filter: hue-rotate(340deg) saturate(0.5); }
#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: none; padding: 0.5em; }
#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: none; }
#tab_extensions table tr:hover, #tab_config table tr:hover { background-color: var(--neutral-500) !important; }
#tab_extensions table, #tab_config table { width: 96vw }
#tab_extensions table thead, #tab_config table thead { background-color: var(--neutral-700); }
#tab_extensions table, #tab_config table { background-color: #222222; }
@ -130,9 +135,6 @@ textarea[rows="1"] { height: 33px !important; width: 99% !important; padding: 8p
#pnginfo_html2_info { margin-top: -18px; background-color: var(--input-background-fill); padding: var(--input-padding) }
#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; }
/* custom elements overrides */
#steps-animation, #controlnet { border-width: 0; }
/* based on gradio built-in dark theme */
:root, .light, .dark {
--body-background-fill: var(--background-color);
@ -246,9 +248,6 @@ textarea[rows="1"] { height: 33px !important; width: 99% !important; padding: 8p
--radius-xxl: 0;
--text-xxs: 9px;
--text-xs: 10px;
--text-sm: 12px;
--text-md: 14px;
--text-lg: 16px;
--text-xl: 22px;
--text-xxl: 26px;
--body-text-size: var(--text-md);

View File

@ -58,7 +58,8 @@ function readCardTags(el, tags) {
e.preventDefault();
e.stopPropagation();
const textarea = activePromptTextarea[getENActiveTab()];
if (textarea.value.indexOf(tag) !== -1) textarea.value = textarea.value.replace(tag, '');
if (textarea.value.indexOf(` ${tag}`) !== -1) textarea.value = textarea.value.replace(` ${tag}`, '');
else if (textarea.value.indexOf(`${tag} `) !== -1) textarea.value = textarea.value.replace(` ${tag} `, '');
else textarea.value += ` ${tag}`;
updateInput(textarea);
};
@ -144,6 +145,38 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
return false;
}
let sortVal = 0;
function sortExtraNetworks() {
const sortDesc = ['Name [A-Z]', 'Name [Z-A]', 'Date [Newest]', 'Date [Oldest]', 'Size [Largest]', 'Size [Smallest]'];
const pagename = getENActivePage();
if (!pagename) return 'sort error: unknown page';
const allPages = Array.from(gradioApp().querySelectorAll('.extra-network-cards'));
const pages = allPages.filter((el) => el.id.includes(pagename.toLowerCase()));
let num = 0;
for (const pg of pages) {
const cards = Array.from(pg.querySelectorAll('.card') || []);
num = cards.length;
if (num === 0) return 'sort: no cards';
cards.sort((a, b) => { // eslint-disable-line no-loop-func
switch (sortVal) {
case 0: return a.dataset.name ? a.dataset.name.localeCompare(b.dataset.name) : 0;
case 1: return b.dataset.name ? b.dataset.name.localeCompare(a.dataset.name) : 0;
case 2: return a.dataset.mtime && !isNaN(a.dataset.mtime) ? parseFloat(b.dataset.mtime) - parseFloat(a.dataset.mtime) : 0;
case 3: return b.dataset.mtime && !isNaN(b.dataset.mtime) ? parseFloat(a.dataset.mtime) - parseFloat(b.dataset.mtime) : 0;
case 4: return a.dataset.size && !isNaN(a.dataset.size) ? parseFloat(b.dataset.size) - parseFloat(a.dataset.size) : 0;
case 5: return b.dataset.size && !isNaN(b.dataset.size) ? parseFloat(a.dataset.size) - parseFloat(b.dataset.size) : 0;
}
return 0;
});
for (const card of cards) pg.appendChild(card);
}
const desc = sortDesc[sortVal];
sortVal = (sortVal + 1) % sortDesc.length;
log('sortExtraNetworks', pagename, num, desc);
return `sort page ${pagename} cards ${num} by ${desc}`;
}
function refreshExtraNetworks(tabname) {
log('refreshExtraNetworks', tabname, gradioApp().querySelector(`#${tabname}_extra_networks textarea`)?.value);
gradioApp().querySelector(`#${tabname}_extra_networks textarea`)?.dispatchEvent(new Event('input'));
@ -195,6 +228,7 @@ function setupExtraNetworksForTab(tabname) {
const btnScan = gradioApp().getElementById(`${tabname}_extra_scan`);
const btnSave = gradioApp().getElementById(`${tabname}_extra_save`);
const btnClose = gradioApp().getElementById(`${tabname}_extra_close`);
const btnSort = gradioApp().getElementById(`${tabname}_extra_sort`);
const btnModel = gradioApp().getElementById(`${tabname}_extra_model`);
const btnApply = gradioApp().getElementById(`${tabname}_extra_apply`);
const buttons = document.createElement('span');
@ -204,6 +238,7 @@ function setupExtraNetworksForTab(tabname) {
if (btnApply) buttons.appendChild(btnApply);
if (btnScan) buttons.appendChild(btnScan);
if (btnSave) buttons.appendChild(btnSave);
if (btnSort) buttons.appendChild(btnSort);
if (btnClose) buttons.appendChild(btnClose);
btnModel.onclick = () => btnModel.classList.toggle('toolbutton-selected');
tabs.appendChild(buttons);

View File

@ -126,9 +126,6 @@ button.selected {background: var(--button-primary-background-fill);}
#txt2img_tools, #img2img_tools { margin-top: -4px; margin-bottom: -4px; }
#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }
/* custom elements overrides */
#steps-animation, #controlnet { border-width: 0; }
/* based on gradio built-in dark theme */
:root, .light, .dark {
--body-background-fill: var(--background-color);

View File

@ -1,6 +1,7 @@
/* generic html tags */
@font-face { font-family: 'NotoSans'; font-display: swap; font-style: normal; font-weight: 100; src: local('NotoSans'), url('notosans-nerdfont-regular.ttf') }
:root, .light, .dark {
--font: 'system-ui', 'ui-sans-serif', 'system-ui', "Roboto", sans-serif;
--font: 'NotoSans';
--font-mono: 'ui-monospace', 'Consolas', monospace;
--font-size: 16px;
--left-column: 490px;
@ -34,6 +35,9 @@
--spacing-xxl: 8px;
--line-sm: 1.2em;
--line-md: 1.4em;
--text-sm: 12px;
--text-md: 13px;
--text-lg: 15px;
}
html { font-size: var(--font-size); }
@ -126,9 +130,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
#pnginfo_html2_info { margin-top: -18px; background-color: var(--input-background-fill); padding: var(--input-padding) }
#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }
/* custom elements overrides */
#steps-animation, #controlnet { border-width: 0; }
/* based on gradio built-in dark theme */
:root, .light, .dark {
--background-fill-secondary: none;
@ -308,9 +309,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
--table-radius: var(--radius-lg);
--table-row-focus: var(--color-accent-soft);
--text-lg: 16px;
--text-md: 14px;
--text-sm: 12px;
--text-xl: 22px;
--text-xs: 10px;
--text-xxl: 26px;
--text-xxs: 9px;

View File

@ -1,6 +1,5 @@
let logMonitorEl = null;
let logMonitorStatus = true;
let jobStatusEl = null;
async function logMonitor() {
if (logMonitorStatus) setTimeout(logMonitor, opts.logmonitor_refresh_period);
@ -52,10 +51,6 @@ async function initLogMonitor() {
</table>
`;
el.style.display = 'none';
jobStatusEl = document.createElement('div');
jobStatusEl.className = 'jobStatus';
jobStatusEl.style.display = 'none';
gradioApp().appendChild(jobStatusEl);
fetch(`/sdapi/v1/start?agent=${encodeURI(navigator.userAgent)}`);
logMonitor();
log('initLogMonitor');

View File

@ -119,9 +119,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
.gradio-button.tool { filter: hue-rotate(120deg) saturate(0.5); }
#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }
/* custom elements overrides */
#steps-animation, #controlnet { border-width: 0; }
/* based on gradio built-in dark theme */
:root, .light, .dark {
--body-background-fill: var(--background-color);

Binary file not shown.

View File

@ -42,24 +42,23 @@ function checkPaused(state) {
function setProgress(res) {
const elements = ['txt2img_generate', 'img2img_generate', 'extras_generate'];
const progress = (res?.progress || 0);
const job = res?.job || '';
const perc = res && (progress > 0) ? `${Math.round(100.0 * progress)}%` : '';
let sec = res?.eta || 0;
let eta = '';
if (res?.paused) eta = 'Paused';
else if (res?.completed || (progress > 0.99)) eta = 'Finishing';
else if (sec === 0) eta = `Init${res?.job?.length > 0 ? `: ${res.job}` : ''}`;
else if (sec === 0) eta = 'Starting';
else {
const min = Math.floor(sec / 60);
sec %= 60;
eta = min > 0 ? `ETA: ${Math.round(min)}m ${Math.round(sec)}s` : `ETA: ${Math.round(sec)}s`;
eta = min > 0 ? `${Math.round(min)}m ${Math.round(sec)}s` : `${Math.round(sec)}s`;
}
document.title = `SD.Next ${perc}`;
for (const elId of elements) {
const el = document.getElementById(elId);
el.innerText = res
? `${perc} ${eta}`
: 'Generate';
el.style.background = res
el.innerText = (res ? `${job} ${perc} ${eta}` : 'Generate');
el.style.background = res && (progress > 0)
? `linear-gradient(to right, var(--primary-500) 0%, var(--primary-800) ${perc}, var(--neutral-700) ${perc})`
: 'var(--button-primary-background-fill)';
}
@ -106,7 +105,6 @@ function requestProgress(id_task, progressEl, galleryEl, atEnd = null, onProgres
debug('taskEnd:', id_task);
localStorage.removeItem('task');
setProgress();
if (jobStatusEl) jobStatusEl.style.display = 'none';
if (parentGallery && livePreview) parentGallery.removeChild(livePreview);
checkPaused(true);
if (atEnd) atEnd();
@ -114,8 +112,6 @@ function requestProgress(id_task, progressEl, galleryEl, atEnd = null, onProgres
const start = (id_task, id_live_preview) => { // eslint-disable-line no-shadow
request('./internal/progress', { id_task, id_live_preview }, (res) => {
if (jobStatusEl) jobStatusEl.innerText = (res?.job || '').trim().toUpperCase();
if (jobStatusEl) jobStatusEl.style.display = jobStatusEl.innerText.length > 0 ? 'block' : 'none';
lastState = res;
const elapsedFromStart = (new Date() - dateStart) / 1000;
hasStarted |= res.active;

View File

@ -1,3 +1,4 @@
@font-face { font-family: 'Roboto'; font-display: swap; font-style: normal; font-weight: 100; src: local('Roboto'), url('roboto.ttf') }
:root { --left-column: 490px; }
a { font-weight: bold; cursor: pointer; }
h2 { margin-top: 1em !important; font-size: 1.4em !important; }
@ -25,7 +26,7 @@ tr { border-bottom: none !important; padding: 0.1em 0.5em !important; }
.gradio-button.secondary-down:hover { background: var(--button-secondary-background-fill-hover); color: var(--button-secondary-text-color-hover); }
.gradio-button.tool { max-width: min-content; min-width: min-content !important; align-self: end; font-size: 1.4em; color: var(--body-text-color) !important; margin-bottom: var(--spacing-md); }
.gradio-checkbox { margin: 0.75em 1.5em 0 0; align-self: center; }
.gradio-column { min-width: unset !important; }
.gradio-column { min-width: unset; }
.gradio-container { max-width: unset !important; padding: var(--block-label-padding) !important; }
.gradio-container .prose a, .gradio-container .prose a:visited{ color: unset; text-decoration: none; }
@ -45,12 +46,12 @@ tr { border-bottom: none !important; padding: 0.1em 0.5em !important; }
/* custom gradio elements */
.accordion-compact { padding: 8px 0px 4px 0px !important; }
.settings-accordion .gap { padding-right: 1000px; }
.settings-accordion >div { flex-flow: wrap; }
.small-accordion .form { min-width: var(--left-column) !important; }
.settings-accordion > div { flex-flow: wrap; }
.small-accordion .form { min-width: var(--left-column) !important; width: max-content; }
.small-accordion .label-wrap .icon { margin-right: 1.6em; margin-left: 0.6em; color: var(--button-primary-border-color); }
.small-accordion .label-wrap { padding: 16px 0px 8px 0px; margin: 0; border-top: 2px solid var(--button-secondary-border-color); }
.small-accordion { width: fit-content !important; padding-left: 0 !important; }
.extension-script { max-width: 50%; }
button.custom-button{ border-radius: var(--button-large-radius); padding: var(--button-large-padding); font-weight: var(--button-large-text-weight); border: var(--button-border-width) solid var(--button-secondary-border-color);
background: var(--button-secondary-background-fill); color: var(--button-secondary-text-color); font-size: var(--button-large-text-size);
display: inline-flex; justify-content: center; align-items: center; transition: var(--button-transition); box-shadow: var(--button-shadow); text-align: center; }
@ -72,7 +73,7 @@ button.custom-button{ border-radius: var(--button-large-radius); padding: var(--
#txt2img_footer, #img2img_footer { height: fit-content; display: none; }
#txt2img_generate_box, #img2img_generate_box { gap: 0.5em; flex-wrap: wrap-reverse; height: fit-content; }
#txt2img_actions_column, #img2img_actions_column { gap: 0.5em; height: fit-content; }
#txt2img_generate_box > button, #img2img_generate_box > button { min-height: 42px; max-height: 42px; }
#txt2img_generate_box > button, #img2img_generate_box > button, #txt2img_enqueue, #img2img_enqueue { min-height: 42px; max-height: 42px; line-height: 1em; }
#txt2img_generate_line2, #img2img_generate_line2, #txt2img_tools, #img2img_tools { display: flex; }
#txt2img_generate_line2 > button, #img2img_generate_line2 > button, #extras_generate_box > button, #txt2img_tools > button, #img2img_tools > button { height: 2em; line-height: 0; font-size: var(--input-text-size);
min-width: unset; display: block !important; margin-left: 0.4em; margin-right: 0.4em; }
@ -96,7 +97,6 @@ div#extras_scale_to_tab div.form{ flex-direction: row; }
width: 22em; min-height: 1.3em; font-size: 0.8em; transition: opacity 0.2s ease-in; pointer-events: none; opacity: 0; z-index: 999; }
.tooltip-show { opacity: 0.9; }
.toolbutton-selected { background: var(--background-fill-primary) !important; }
.jobStatus { position: fixed; bottom: 1em; right: 1em; background: var(--input-background-fill); padding: 0.4em; font-size: 0.8em; color: var(--body-text-color-subdued); }
/* settings */
#si-sparkline-memo, #si-sparkline-load { background-color: #111; }
@ -163,6 +163,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
/* extensions */
#tab_extensions table, #tab_config table{ border-collapse: collapse; }
#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: 1px solid #ccc; padding: 0.25em 0.5em; }
#tab_extensions table tr:hover, #tab_config table tr:hover { background-color: var(--neutral-500) !important; }
#tab_extensions table input[type="checkbox"] { margin-right: 0.5em; appearance: checkbox; }
#tab_extensions button{ max-width: 16em; }
#tab_extensions input[disabled="disabled"]{ opacity: 0.5; }
@ -195,9 +196,10 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
.extra-network-cards .card:hover .overlay { background: rgba(0, 0, 0, 0.40); }
.extra-network-cards .card:hover .preview { box-shadow: none; filter: grayscale(100%); }
.extra-network-cards .card:hover .overlay { background: rgba(0, 0, 0, 0.40); }
.extra-network-cards .card .overlay .tags { margin: 4px; display: none; overflow-wrap: break-word; }
.extra-network-cards .card .overlay .tag { padding: 2px; margin: 2px; background: var(--neutral-700); cursor: pointer; display: inline-block; }
.extra-network-cards .card .overlay .tags { display: none; overflow-wrap: break-word; }
.extra-network-cards .card .overlay .tag { padding: 3px; background: rgba(70, 70, 70, 0.60); font-size: var(--text-lg); cursor: pointer; display: inline-block; margin-bottom: 4px; }
.extra-network-cards .card .actions > span { padding: 4px; }
.extra-network-cards .card .actions > span:hover { color: var(--highlight-color); }
.extra-network-cards .card:hover .actions { display: block; }
.extra-network-cards .card:hover .overlay .tags { display: block; }
.extra-network-cards .card .actions { font-size: 3em; display: none; text-align-last: right; cursor: pointer; font-variant: unicase; position: absolute; z-index: 100; right: 0; height: 0.7em; width: 100%; background: rgba(0, 0, 0, 0.40); }

View File

@ -28,9 +28,9 @@ def init_modules():
parser = modules.cmd_args.parser
installer.add_args(parser)
args, _ = parser.parse_known_args()
import modules.paths_internal
script_path = modules.paths_internal.script_path
extensions_dir = modules.paths_internal.extensions_dir
import modules.paths
script_path = modules.paths.script_path
extensions_dir = modules.paths.extensions_dir
def get_custom_args():

View File

@ -356,68 +356,54 @@ class Api:
def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)
image_list = reqDict.pop('imageList', [])
image_folder = [decode_base64_to_image(x.data) for x in image_list]
with self.queue_lock:
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def pnginfoapi(self, req: models.PNGInfoRequest):
if not req.image.strip():
return models.PNGInfoResponse(info="")
image = decode_base64_to_image(req.image.strip())
if image is None:
return models.PNGInfoResponse(info="")
geninfo, items = images.read_info_from_image(image)
if geninfo is None:
geninfo = ""
items = {**{'parameters': geninfo}, **items}
return models.PNGInfoResponse(info=geninfo, items=items)
def progressapi(self, req: models.ProgressRequest = Depends()):
# copy from check_progress_call of ui.py
if shared.state.job_count == 0:
return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
# avoid dividing zero
progress = 0.01
if shared.state.job_count > 0:
progress += shared.state.job_no / shared.state.job_count
if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
time_since_start = time.time() - shared.state.time_start
eta = time_since_start / progress
eta_relative = eta-time_since_start
progress = min(progress, 1)
shared.state.set_current_image()
current_image = None
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
batch_x = max(shared.state.job_no, 0)
batch_y = max(shared.state.job_count, 1)
step_x = max(shared.state.sampling_step, 0)
step_y = max(shared.state.sampling_steps, 1)
current = step_y * batch_x + step_x
total = step_y * batch_y
progress = current / total if total > 0 else 0
time_since_start = time.time() - shared.state.time_start
eta_relative = (time_since_start / progress) - time_since_start
res = models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
return res
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
image_b64 = interrogatereq.image
if image_b64 is None:
raise HTTPException(status_code=404, detail="Image not found")
img = decode_base64_to_image(image_b64)
img = img.convert('RGB')
# Override object param
with self.queue_lock:
if interrogatereq.model == "clip":
processed = shared.interrogator.interrogate(img)
@ -425,7 +411,6 @@ class Api:
processed = deepbooru.model.tag(img)
else:
raise HTTPException(status_code=404, detail="Model not found")
return models.InterrogateResponse(caption=processed)
def interruptapi(self):
@ -473,18 +458,8 @@ class Api:
def get_sd_vaes(self):
return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]
def get_upscalers(self):
return [
{
"name": upscaler.name,
"model_name": upscaler.scaler.model_name,
"model_path": upscaler.data_path,
"model_url": None,
"scale": upscaler.scale,
}
for upscaler in shared.sd_upscalers
]
return [{"name": upscaler.name, "model_name": upscaler.scaler.model_name, "model_path": upscaler.data_path, "model_url": None, "scale": upscaler.scale} for upscaler in shared.sd_upscalers]
def get_sd_models(self):
return [{"title": x.title, "name": x.name, "filename": x.filename, "type": x.type, "hash": x.shorthash, "sha256": x.sha256, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]
@ -500,23 +475,13 @@ class Api:
def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db
def convert_embedding(embedding):
return {
"step": embedding.step,
"sd_checkpoint": embedding.sd_checkpoint,
"sd_checkpoint_name": embedding.sd_checkpoint_name,
"shape": embedding.shape,
"vectors": embedding.vectors,
}
return {"step": embedding.step, "sd_checkpoint": embedding.sd_checkpoint, "sd_checkpoint_name": embedding.sd_checkpoint_name, "shape": embedding.shape, "vectors": embedding.vectors}
def convert_embeddings(embeddings):
return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
return {
"loaded": convert_embeddings(db.word_embeddings),
"skipped": convert_embeddings(db.skipped_embeddings),
}
return {"loaded": convert_embeddings(db.word_embeddings), "skipped": convert_embeddings(db.skipped_embeddings)}
def get_extra_networks(self, page: Optional[str] = None, name: Optional[str] = None, filename: Optional[str] = None, title: Optional[str] = None, fullname: Optional[str] = None, hash: Optional[str] = None): # pylint: disable=redefined-builtin
res = []
@ -553,7 +518,7 @@ class Api:
def create_embedding(self, args: dict):
try:
shared.state.begin('api-create-embedding')
shared.state.begin('api-embedding')
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end()
@ -564,7 +529,7 @@ class Api:
def create_hypernetwork(self, args: dict):
try:
shared.state.begin('api-create-hypernetwork')
shared.state.begin('api-hypernetwork')
filename = create_hypernetwork(**args) # create empty embedding # pylint: disable=E1111
shared.state.end()
return models.CreateResponse(info = f"create hypernetwork filename: {filename}")
@ -590,7 +555,7 @@ class Api:
def train_embedding(self, args: dict):
try:
shared.state.begin('api-train-embedding')
shared.state.begin('api-embedding')
apply_optimizations = False
error = None
filename = ''
@ -611,7 +576,7 @@ class Api:
def train_hypernetwork(self, args: dict):
try:
shared.state.begin('api-train-hypernetwork')
shared.state.begin('api-hypernetwork')
shared.loaded_hypernetworks = []
apply_optimizations = False
error = None

View File

@ -1,6 +1,6 @@
import os
import argparse
from modules.paths_internal import data_path
from modules.paths import data_path
parser = argparse.ArgumentParser(description="SD.Next", conflict_handler='resolve', epilog='For other options see UI Settings page', prog='', add_help=True, formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=55, indent_increment=2, width=200))
parser._optionals = parser.add_argument_group('Other options') # pylint: disable=protected-access

View File

@ -2,7 +2,7 @@ import os
from datetime import datetime
import git
from modules import shared, errors
from modules.paths_internal import extensions_dir, extensions_builtin_dir
from modules.paths import extensions_dir, extensions_builtin_dir
extensions = []

View File

@ -54,7 +54,7 @@ def to_half(tensor, enable):
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata): # pylint: disable=unused-argument
shared.state.begin('model-merge')
shared.state.begin('merge')
save_as_half = save_as_half == 0
def fail(message):
@ -319,7 +319,7 @@ def run_modelconvert(model, checkpoint_formats, precision, conv_type, custom_nam
"vae": vae_conv,
"other": others_conv
}
shared.state.begin('model-convert')
shared.state.begin('convert')
model_info = sd_models.checkpoints_list[model]
shared.state.textinfo = f"Loading {model_info.filename}..."
shared.log.info(f"Model convert loading: {model_info.filename}")

View File

@ -69,7 +69,7 @@ def sha256(filename, title, use_addnet_hash=False):
if not os.path.isfile(filename):
return None
orig_state = copy.deepcopy(shared.state)
shared.state.begin("hashing")
shared.state.begin("hash")
if use_addnet_hash:
if progress_ok:
try:

View File

@ -460,7 +460,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
hypernetwork.load(path)
shared.loaded_hypernetworks = [hypernetwork]
shared.state.job = "train-hypernetwork"
shared.state.job = "train"
shared.state.textinfo = "Initializing hypernetwork training..."
shared.state.job_count = steps

View File

@ -135,9 +135,9 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
def get_font(fontsize):
try:
return ImageFont.truetype(shared.opts.font or 'html/roboto.ttf', fontsize)
return ImageFont.truetype(shared.opts.font or 'javascript/roboto.ttf', fontsize)
except Exception:
return ImageFont.truetype('html/roboto.ttf', fontsize)
return ImageFont.truetype('javascript/roboto.ttf', fontsize)
def draw_texts(drawing: ImageDraw, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
for line in lines:
@ -553,13 +553,11 @@ def save_image(image, path, basename = '', seed=None, prompt=None, extension=sha
return None, None
if not check_grid_size([image]):
return None, None
if path is None or len(path) == 0:
if path is None or len(path) == 0: # set default path to avoid errors when functions are triggered manually or via api and param is not set
path = shared.opts.outdir_save
# namegen
namegen = FilenameGenerator(p, seed, prompt, image, grid=grid)
if shared.opts.save_to_dirs:
dirname = namegen.apply(shared.opts.directories_filename_pattern or "[date]")
dirname = namegen.apply(shared.opts.directories_filename_pattern or "[prompt_words]")
path = os.path.join(path, dirname)
if forced_filename is None:
if short_filename or seed is None:
@ -567,11 +565,10 @@ def save_image(image, path, basename = '', seed=None, prompt=None, extension=sha
if shared.opts.samples_filename_pattern and len(shared.opts.samples_filename_pattern) > 0:
file_decoration = shared.opts.samples_filename_pattern
else:
file_decoration = "[seq]-[model_name]-[prompt_words]"
file_decoration = "[seq]-[prompt_words]"
file_decoration = namegen.apply(file_decoration)
filename = os.path.join(path, f"{file_decoration}{suffix}.{extension}") if basename is None or basename == '' else os.path.join(path, f"{basename}-{file_decoration}{suffix}.{extension}")
else:
filename = os.path.join(path, f"{forced_filename}.{extension}")
file_decoration += suffix
filename = os.path.join(path, f"{file_decoration}.{extension}") if basename == '' else os.path.join(path, f"{basename}-{file_decoration}.{extension}")
pnginfo = existing_info or {}
if info is not None:
pnginfo[pnginfo_section_name] = info
@ -579,7 +576,6 @@ def save_image(image, path, basename = '', seed=None, prompt=None, extension=sha
params.filename = namegen.sanitize(filename)
dirname = os.path.dirname(params.filename)
os.makedirs(dirname, exist_ok=True)
# sequence
if shared.opts.save_images_add_number or '[seq]' in params.filename:
if '[seq]' not in params.filename:
@ -592,7 +588,6 @@ def save_image(image, path, basename = '', seed=None, prompt=None, extension=sha
debug(f'Prompt sequence: input="{params.filename}" seq={seq} output="{filename}"')
params.filename = filename
break
# callbacks
script_callbacks.before_image_saved_callback(params)
exifinfo = params.pnginfo.get('UserComment', '')

View File

@ -40,7 +40,6 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args)
btcrept = p.batch_size
shared.log.info(f"Process batch: inputs={len(image_files)} outputs={p.n_iter * p.batch_size} per input")
for i in range(0, len(image_files), window_size):
shared.state.job = f"{i+1} to {min(i+window_size, len(image_files))} out of {len(image_files)}"
if shared.state.skipped:
shared.state.skipped = False
if shared.state.interrupted:

1
modules/k-diffusion Submodule

@ -0,0 +1 @@
Subproject commit 045515774882014cc14c1ba2668ab5bad9cbf7c0

View File

@ -85,7 +85,7 @@ def download_civit_preview(model_path: str, preview_url: str):
block_size = 16384 # 16KB blocks
written = 0
img = None
shared.state.begin('civitai-download-preview')
shared.state.begin('civitai')
try:
with open(preview_file, 'wb') as f:
with p.Progress(p.TextColumn('[cyan]{task.description}'), p.DownloadColumn(), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TransferSpeedColumn(), console=shared.console) as progress:
@ -142,7 +142,7 @@ def download_civit_model_thread(model_name, model_url, model_path, model_type, p
total_size = int(r.headers.get('content-length', 0))
res += f' size={round((starting_pos + total_size)/1024/1024)}Mb'
shared.log.info(res)
shared.state.begin('civitai-download-model')
shared.state.begin('civitai')
block_size = 16384 # 16KB blocks
written = starting_pos
global download_pbar # pylint: disable=global-statement
@ -188,7 +188,7 @@ def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config
return None
from diffusers import DiffusionPipeline
import huggingface_hub as hf
shared.state.begin('huggingface-download-model')
shared.state.begin('huggingface')
if download_config is None:
download_config = {
"force_download": False,

View File

@ -1,9 +1,43 @@
# this module must not have any dependencies as it first import
import os
import sys
from modules import paths_internal, errors
import json
import argparse
from modules.errors import log
# parse args, parse again after we have the data-dir and early-read the config file
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--ckpt", type=str, default=os.environ.get("SD_MODEL", None), help="Path to model checkpoint to load immediately, default: %(default)s")
parser.add_argument("--data-dir", type=str, default=os.environ.get("SD_DATADIR", ''), help="Base path where all user data is stored, default: %(default)s")
parser.add_argument("--models-dir", type=str, default=os.environ.get("SD_MODELSDIR", None), help="Base path where all models are stored, default: %(default)s",)
cli = parser.parse_known_args()[0]
parser.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(cli.data_dir, 'config.json')), help="Use specific server configuration file, default: %(default)s")
cli = parser.parse_known_args()[0]
config_path = cli.config if os.path.isabs(cli.config) else os.path.join(cli.data_dir, cli.config)
try:
with open(config_path, 'r', encoding='utf8') as f:
config = json.load(f)
except Exception as err:
print(f'Error loading config file: ${config_path} {err}')
config = {}
debug = errors.log.info if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None
modules_path = os.path.dirname(os.path.realpath(__file__))
script_path = os.path.dirname(modules_path)
data_path = cli.data_dir
models_config = cli.models_dir or config.get('models_dir') or 'models'
models_path = models_config if os.path.isabs(models_config) else os.path.join(data_path, models_config)
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = "extensions-builtin"
sd_configs_path = os.path.join(script_path, "configs")
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
sd_model_file = cli.ckpt or os.path.join(script_path, 'model.ckpt') # not used
default_sd_model_file = sd_model_file # not used
debug = log.info if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None
if os.environ.get('SD_PATH_DEBUG', None) is not None:
print(f'Paths: script-path="{script_path}" data-dir="{data_path}" models-dir="{models_path}" config="{config_path}"')
"""
data_path = paths_internal.data_path
script_path = paths_internal.script_path
models_path = paths_internal.models_path
@ -13,26 +47,17 @@ sd_model_file = paths_internal.sd_model_file
default_sd_model_file = paths_internal.default_sd_model_file
extensions_dir = paths_internal.extensions_dir
extensions_builtin_dir = paths_internal.extensions_builtin_dir
"""
# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path)
# search for directory of stable diffusion in following places
sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths:
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
sd_path = os.path.abspath(possible_sd_path)
break
assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
sd_path = os.path.join(script_path, 'repositories')
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
(sd_path, 'ldm', 'ldm', []),
(sd_path, 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, 'blip'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, 'codeformer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join('modules', 'k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]
paths = {}
@ -40,25 +65,26 @@ paths = {}
for d, must_exist, what, _options in path_dirs:
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
if not os.path.exists(must_exist_path):
errors.log.error(f'Required path not found: path={must_exist_path} item={what}')
log.error(f'Required path not found: path={must_exist_path} item={what}')
else:
d = os.path.abspath(d)
sys.path.append(d)
paths[what] = d
def create_paths(opts):
def create_path(folder):
if folder is None or folder == '':
return
if os.path.exists(folder):
return
try:
os.makedirs(folder, exist_ok=True)
errors.log.info(f'Create folder={folder}')
except Exception as e:
errors.log.error(f'Create Failed folder={folder} {e}')
def create_path(folder):
if folder is None or folder == '':
return
if os.path.exists(folder):
return
try:
os.makedirs(folder, exist_ok=True)
log.info(f'Create folder={folder}')
except Exception as e:
log.error(f'Create Failed folder={folder} {e}')
def create_paths(opts):
def fix_path(folder):
tgt = opts.data.get(folder, None) or opts.data_labels[folder].default
if tgt is None or tgt == '':

View File

@ -1,5 +1,8 @@
"""this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py"""
# no longer used, all paths are defined in paths.py
from modules.paths import modules_path, script_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, data_path, models_path, extensions_dir, extensions_builtin_dir # pylint: disable=unused-import
"""
import argparse
import os
@ -7,15 +10,21 @@ modules_path = os.path.dirname(os.path.realpath(__file__))
script_path = os.path.dirname(modules_path)
sd_configs_path = os.path.join(script_path, "configs")
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
parser_pre = argparse.ArgumentParser(add_help=False)
parser_pre.add_argument("--data-dir", type=str, default="", help="base path where all user data is stored", )
parser_pre.add_argument("--models-dir", type=str, default="models", help="base path where all models are stored",)
parser_pre.add_argument("--ckpt", type=str, default=os.environ.get("SD_MODEL", None), help="Path to model checkpoint to load immediately, default: %(default)s")
parser_pre.add_argument("--data-dir", type=str, default=os.environ.get("SD_DATADIR", ''), help="Base path where all user data is stored, default: %(default)s")
parser_pre.add_argument("--models-dir", type=str, default=os.environ.get("SD_MODELSDIR", 'models'), help="Base path where all models are stored, default: %(default)s",)
cmd_opts_pre = parser_pre.parse_known_args()[0]
# parser_pre.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(data_path, 'config.json')), help="Use specific server configuration file, default: %(default)s")
data_path = cmd_opts_pre.data_dir
models_path = cmd_opts_pre.models_dir if os.path.isabs(cmd_opts_pre.models_dir) else os.path.join(data_path, cmd_opts_pre.models_dir)
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = "extensions-builtin"
sd_model_file = cmd_opts_pre.ckpt or os.path.join(script_path, 'model.ckpt') # not used
default_sd_model_file = sd_model_file # not used
"""

View File

@ -80,6 +80,22 @@ def create_binary_mask(image):
return image
def images_tensor_to_samples(image, approximation=None, model=None):
if model is None:
model = shared.sd_model
model.first_stage_model.to(devices.dtype_vae)
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image * 2 - 1
if len(image) > 1:
x_latent = torch.stack([
model.get_first_stage_encoding(model.encode_first_stage(torch.unsqueeze(img, 0)))[0]
for img in image
])
else:
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
return x_latent
def txt2img_image_conditioning(sd_model, x, width, height):
if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
# The "masked-image" in this case will just be all zeros since the entire image is masked.
@ -450,6 +466,8 @@ def decode_first_stage(model, x, full_quality=True):
shared.log.debug(f'Decode VAE: skipped={shared.state.skipped} interrupted={shared.state.interrupted}')
x_sample = torch.zeros((len(x), 3, x.shape[2] * 8, x.shape[3] * 8), dtype=devices.dtype_vae, device=devices.device)
return x_sample
prev_job = shared.state.job
shared.state.job = 'vae'
with devices.autocast(disable = x.dtype==devices.dtype_vae):
try:
if full_quality:
@ -467,6 +485,7 @@ def decode_first_stage(model, x, full_quality=True):
except Exception as e:
x_sample = x
shared.log.error(f'Decode VAE: {e}')
shared.state.job = prev_job
return x_sample
@ -777,12 +796,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
return ''
ema_scope_context = p.sd_model.ema_scope if shared.backend == shared.Backend.ORIGINAL else nullcontext
shared.state.job_count = p.n_iter
with devices.inference_context(), ema_scope_context():
t0 = time.time()
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
if shared.state.job_count == -1:
shared.state.job_count = p.n_iter
extra_network_data = None
for n in range(p.n_iter):
p.iteration = n
@ -814,8 +832,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
step_multiplier = 1
sampler_config = modules.sd_samplers.find_sampler_config(p.sampler_name)
step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
if shared.backend == shared.Backend.ORIGINAL:
uc = get_conds_with_caching(modules.prompt_parser.get_learned_conditioning, p.negative_prompts, p.steps * step_multiplier, cached_uc)
@ -921,7 +937,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
output_images.append(image_mask_composite)
del x_samples_ddim
devices.torch_gc()
shared.state.nextjob()
t1 = time.time()
shared.log.info(f'Processed: images={len(output_images)} time={t1 - t0:.2f}s its={(p.steps * len(output_images)) / (t1 - t0):.2f} memory={modules.memstats.memory_stats()}')
@ -1044,12 +1059,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.is_hr_pass = False
return
self.is_hr_pass = True
if not shared.state.processing_has_refined_job_count:
if shared.state.job_count == -1:
shared.state.job_count = self.n_iter
shared.state.job_count = shared.state.job_count * 2
shared.state.processing_has_refined_job_count = True
hypertile_set(self, hr=True)
shared.state.job_count = 2 * self.n_iter
shared.log.debug(f'Init hires: upscaler="{self.hr_upscaler}" sampler="{self.latent_sampler}" resize={self.hr_resize_x}x{self.hr_resize_y} upscale={self.hr_upscale_to_x}x{self.hr_upscale_to_y}')
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
@ -1069,11 +1080,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler.initialize(self)
x = create_random_tensors([4, self.height // 8, self.width // 8], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
shared.state.nextjob()
if not self.enable_hr or shared.state.interrupted or shared.state.skipped:
return samples
self.init_hr()
if self.is_hr_pass:
prev_job = shared.state.job
target_width = self.hr_upscale_to_x
target_height = self.hr_upscale_to_y
decoded_samples = None
@ -1091,6 +1104,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.extra_generation_params, self.restore_faces = bak_extra_generation_params, bak_restore_faces
images.save_image(image, self.outpath_samples, "", seeds[i], prompts[i], shared.opts.samples_format, info=info, suffix="-before-hires")
if latent_scale_mode is None or self.hr_force: # non-latent upscaling
shared.state.job = 'upscale'
if decoded_samples is None:
decoded_samples = decode_first_stage(self.sd_model, samples.to(dtype=devices.dtype_vae), self.full_quality)
decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
@ -1120,6 +1134,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.latent_sampler == "PLMS":
self.latent_sampler = 'UniPC'
if self.hr_force or latent_scale_mode is not None:
shared.state.job = 'hires'
if self.denoising_strength > 0:
self.ops.append('hires')
devices.torch_gc() # GC now before running the next img2img to prevent running out of memory
@ -1135,8 +1150,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
self.ops.append('upscale')
x = None
shared.state.nextjob()
self.is_hr_pass = False
shared.state.job = prev_job
shared.state.nextjob()
return samples
@ -1301,6 +1317,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
samples = samples * self.nmask + self.init_latent * self.mask
del x
devices.torch_gc()
shared.state.nextjob()
return samples
def get_token_merging_ratio(self, for_hr=False):

View File

@ -63,14 +63,6 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
def diffusers_callback(step: int, _timestep: int, latents: torch.FloatTensor):
shared.state.sampling_step = step
if p.is_hr_pass:
shared.state.job = 'hires'
shared.state.sampling_steps = p.hr_second_pass_steps # add optional hires
elif p.is_refiner_pass:
shared.state.job = 'refine'
shared.state.sampling_steps = calculate_refiner_steps() # add optional refiner
else:
shared.state.sampling_steps = p.steps # base steps
shared.state.current_latent = latents
if shared.state.interrupted or shared.state.skipped:
raise AssertionError('Interrupted...')
@ -133,6 +125,8 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
return encoded
def vae_decode(latents, model, output_type='np', full_quality=True):
prev_job = shared.state.job
shared.state.job = 'vae'
if not torch.is_tensor(latents): # already decoded
return latents
if latents.shape[0] == 0:
@ -150,6 +144,7 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
else:
decoded = taesd_vae_decode(latents=latents)
imgs = model.image_processor.postprocess(decoded, output_type=output_type)
shared.state.job = prev_job
return imgs
def vae_encode(image, model, full_quality=True): # pylint: disable=unused-variable
@ -186,16 +181,17 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
def task_specific_kwargs(model):
task_args = {}
if sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.TEXT_2_IMAGE:
is_img2img_model = bool("Zero123" in shared.sd_model.__class__.__name__)
if sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.TEXT_2_IMAGE and not is_img2img_model:
p.ops.append('txt2img')
task_args = {"height": 8 * math.ceil(p.height / 8), "width": 8 * math.ceil(p.width / 8)}
elif sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.IMAGE_2_IMAGE and len(getattr(p, 'init_images' ,[])) > 0:
elif (sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.IMAGE_2_IMAGE or is_img2img_model) and len(getattr(p, 'init_images' ,[])) > 0:
p.ops.append('img2img')
task_args = {"image": p.init_images, "strength": p.denoising_strength}
elif sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.INSTRUCT and len(getattr(p, 'init_images' ,[])) > 0:
p.ops.append('instruct')
task_args = {"height": 8 * math.ceil(p.height / 8), "width": 8 * math.ceil(p.width / 8), "image": p.init_images, "strength": p.denoising_strength}
elif sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.INPAINTING and len(getattr(p, 'init_images' ,[])) > 0:
elif (sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.INPAINTING or is_img2img_model) and len(getattr(p, 'init_images' ,[])) > 0:
p.ops.append('inpaint')
if getattr(p, 'mask', None) is None:
p.mask = TF.to_pil_image(torch.ones_like(TF.to_tensor(p.init_images[0]))).convert("L")
@ -388,6 +384,7 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
clip_skip=p.clip_skip,
desc='Base',
)
shared.state.sampling_steps = base_args['num_inference_steps']
p.extra_generation_params['CFG rescale'] = p.diffusers_guidance_rescale
p.extra_generation_params["Sampler Eta"] = shared.opts.scheduler_eta if shared.opts.scheduler_eta is not None and shared.opts.scheduler_eta > 0 and shared.opts.scheduler_eta < 1 else None
try:
@ -403,6 +400,7 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
if hasattr(shared.sd_model, 'embedding_db') and len(shared.sd_model.embedding_db.embeddings_used) > 0:
p.extra_generation_params['Embeddings'] = ', '.join(shared.sd_model.embedding_db.embeddings_used)
shared.state.nextjob()
if shared.state.interrupted or shared.state.skipped:
return results
@ -412,10 +410,12 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
latent_scale_mode = shared.latent_upscale_modes.get(p.hr_upscaler, None) if (hasattr(p, "hr_upscaler") and p.hr_upscaler is not None) else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "None")
if p.is_hr_pass:
p.init_hr()
prev_job = shared.state.job
if p.width != p.hr_upscale_to_x or p.height != p.hr_upscale_to_y:
p.ops.append('upscale')
if shared.opts.save and not p.do_not_save_samples and shared.opts.save_images_before_highres_fix and hasattr(shared.sd_model, 'vae'):
save_intermediate(latents=output.images, suffix="-before-hires")
shared.state.job = 'upscale'
output.images = hires_resize(latents=output.images)
if latent_scale_mode is not None or p.hr_force:
p.ops.append('hires')
@ -438,15 +438,22 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
strength=p.denoising_strength,
desc='Hires',
)
shared.state.job = 'hires'
shared.state.sampling_steps = hires_args['num_inference_steps']
try:
output = shared.sd_model(**hires_args) # pylint: disable=not-callable
except AssertionError as e:
shared.log.info(e)
p.init_images = []
shared.state.job = prev_job
shared.state.nextjob()
p.is_hr_pass = False
# optional refiner pass or decode
if is_refiner_enabled:
prev_job = shared.state.job
shared.state.job = 'refine'
shared.state.job_count +=1
if shared.opts.save and not p.do_not_save_samples and shared.opts.save_images_before_refiner and hasattr(shared.sd_model, 'vae'):
save_intermediate(latents=output.images, suffix="-before-refiner")
if shared.opts.diffusers_move_base and not getattr(shared.sd_model, 'has_accelerate', False):
@ -491,6 +498,7 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
clip_skip=p.clip_skip,
desc='Refiner',
)
shared.state.sampling_steps = refiner_args['num_inference_steps']
try:
refiner_output = shared.sd_refiner(**refiner_args) # pylint: disable=not-callable
except AssertionError as e:
@ -505,7 +513,9 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
shared.log.debug('Moving to CPU: model=refiner')
shared.sd_refiner.to(devices.cpu)
devices.torch_gc()
p.is_refiner_pass = True
shared.state.job = prev_job
shared.state.nextjob()
p.is_refiner_pass = False
# final decode since there is no refiner
if not is_refiner_enabled:

View File

@ -66,15 +66,20 @@ def progressapi(req: ProgressRequest):
paused = shared.state.paused
if not active:
return InternalProgressResponse(job=shared.state.job, active=active, queued=queued, paused=paused, completed=completed, id_live_preview=-1, textinfo="Queued..." if queued else "Waiting...")
progress = 0
if shared.state.job_count > 0:
progress += shared.state.job_no / shared.state.job_count
if shared.state.sampling_steps > 0 and shared.state.job_count > 0:
progress += 1 / (shared.state.job_count / 2 if shared.state.processing_has_refined_job_count else 1) * shared.state.sampling_step / shared.state.sampling_steps
progress = min(progress, 1)
if shared.state.job_no > shared.state.job_count:
shared.state.job_count = shared.state.job_no
batch_x = max(shared.state.job_no, 0)
batch_y = max(shared.state.job_count, 1)
step_x = max(shared.state.sampling_step, 0)
step_y = max(shared.state.sampling_steps, 1)
current = step_y * batch_x + step_x
total = step_y * batch_y
progress = min(1, current / total if total > 0 else 0)
elapsed_since_start = time.time() - shared.state.time_start
predicted_duration = elapsed_since_start / progress if progress > 0 else None
eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
id_live_preview = req.id_live_preview
live_preview = None
shared.state.set_current_image()
@ -83,4 +88,6 @@ def progressapi(req: ProgressRequest):
shared.state.current_image.save(buffered, format='jpeg')
live_preview = f'data:image/jpeg;base64,{base64.b64encode(buffered.getvalue()).decode("ascii")}'
id_live_preview = shared.state.id_live_preview
return InternalProgressResponse(job=shared.state.job, active=active, queued=queued, paused=paused, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
res = InternalProgressResponse(job=shared.state.job, active=active, queued=queued, paused=paused, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
return res

View File

@ -321,6 +321,7 @@ class ScriptRunner:
self.paste_field_names = []
self.script_load_ctr = 0
self.is_img2img = False
self.inputs = [None]
def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing
@ -355,6 +356,31 @@ class ScriptRunner:
except Exception as e:
log.error(f'Script initialize: {path} {e}')
def create_script_ui(self, script):
import modules.api.models as api_models
script.args_from = len(self.inputs)
script.args_to = len(self.inputs)
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
if controls is None:
return
script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
api_args = []
for control in controls:
control.custom_script_source = os.path.basename(script.filename)
arg_info = api_models.ScriptArg(label=control.label or "")
for field in ("value", "minimum", "maximum", "step", "choices"):
v = getattr(control, field, None)
if v is not None:
setattr(arg_info, field, v)
api_args.append(arg_info)
script.api_info = api_models.ScriptInfo(name=script.name, is_img2img=script.is_img2img, is_alwayson=script.alwayson, args=api_args)
if script.infotext_fields is not None:
self.infotext_fields += script.infotext_fields
if script.paste_field_names is not None:
self.paste_field_names += script.paste_field_names
self.inputs += controls
script.args_to = len(self.inputs)
def setup_ui_for_section(self, section, scriptlist=None):
if scriptlist is None:
scriptlist = self.alwayson_scripts
@ -377,7 +403,7 @@ class ScriptRunner:
inputs = []
inputs_alwayson = [True]
def create_script_ui(script, inputs, inputs_alwayson):
def create_script_ui(script, inputs, inputs_alwayson): # TODO this is legacy implementation, see self.create_script_ui
script.args_from = len(inputs)
script.args_to = len(inputs)
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
@ -445,7 +471,7 @@ class ScriptRunner:
for script in self.alwayson_scripts:
t0 = time.time()
elem_id = f'script_{"txt2img" if script.is_txt2img else "img2img"}_{script.title().lower().replace(" ", "_")}'
with gr.Group(elem_id=elem_id) as group:
with gr.Group(elem_id=elem_id, elem_classes=['extension-script']) as group:
create_script_ui(script, inputs, inputs_alwayson)
script.group = group
time_setup[script.title()] = time_setup.get(script.title(), 0) + (time.time()-t0)

View File

@ -88,7 +88,7 @@ class ScriptPostprocessingRunner:
def setup_ui(self):
inputs = []
for script in self.scripts_in_preferred_order():
with gr.Accordion(label=script.name, open=False) as group:
with gr.Accordion(label=script.name, open=False, elem_classes=['postprocess']) as group:
self.create_script_ui(script, inputs)
script.group = group
self.ui_created = True

View File

@ -317,7 +317,7 @@ def get_xformers_flash_attention_op(q, k, v):
return None
try:
flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp # pylint: disable=used-before-assignment
fw, _bw = flash_attention_op
if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
return flash_attention_op

View File

@ -24,7 +24,7 @@ from modules import paths, shared, shared_items, shared_state, modelloader, devi
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
from modules.memstats import memory_stats
from modules.paths_internal import models_path, script_path
from modules.paths import models_path, script_path
try:
import diffusers
@ -848,6 +848,8 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No
vae = sd_vae.load_vae_diffusers(checkpoint_info.path, vae_file, vae_source)
if vae is not None:
diffusers_load_config["vae"] = vae
if 'LCM' in checkpoint_info.path:
diffusers_load_config['custom_pipeline'] = 'latent_consistency_txt2img'
if os.path.isdir(checkpoint_info.path):
err1 = None
@ -858,18 +860,21 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No
sd_model.model_type = sd_model.__class__.__name__
except Exception as e:
err1 = e
# shared.log.error(f'AutoPipeline: {e}')
try: # try diffusion pipeline next second-best choice, works for most non-linked pipelines
if err1 is not None:
sd_model = diffusers.DiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
sd_model.model_type = sd_model.__class__.__name__
except Exception as e:
err2 = e
# shared.log.error(f'DiffusionPipeline: {e}')
try: # try basic pipeline next just in case
if err2 is not None:
sd_model = diffusers.StableDiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
sd_model.model_type = sd_model.__class__.__name__
except Exception as e:
err3 = e # ignore last error
shared.log.error(f'StableDiffusionPipeline: {e}')
if err3 is not None:
shared.log.error(f'Failed loading {op}: {checkpoint_info.path} auto={err1} diffusion={err2}')
return
@ -1155,7 +1160,7 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model')
return None
orig_state = copy.deepcopy(shared.state)
shared.state = shared_state.State()
shared.state.begin(f'load-{op}')
shared.state.begin('load')
if load_dict:
shared.log.debug(f'Model dict: existing={sd_model is not None} target={checkpoint_info.filename} info={info}')
else:

View File

@ -4,10 +4,10 @@ import torch
from modules import paths, sd_disable_initialization, devices
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
sd_repo_configs_path = 'configs'
config_default = paths.sd_default_config
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference-512-base.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-768-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")

View File

@ -173,7 +173,12 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
if devices.backend == "directml":
self.init_latent = self.init_latent.float()
denoised = self.init_latent * self.mask + self.nmask * denoised
self.init_latent = self.init_latent.half()
else:
denoised = self.init_latent * self.mask + self.nmask * denoised
after_cfg_callback_params = AfterCFGCallbackParams(denoised, shared.state.sampling_step, shared.state.sampling_steps)
cfg_after_cfg_callback(after_cfg_callback_params)
denoised = after_cfg_callback_params.x
@ -333,6 +338,7 @@ class KDiffusionSampler:
's_min_uncond': self.s_min_uncond
}
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
samples = samples.type(devices.dtype)
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):

View File

@ -49,6 +49,7 @@ class CFGDenoiserTimesteps(CFGDenoiser):
self.alphas = shared.sd_model.alphas_cumprod
self.mask_before_denoising = True
self.model_wrap = None
def get_pred_x0(self, x_in, x_out, sigma):
ts = sigma.to(dtype=int)

View File

@ -3,7 +3,7 @@ import collections
import glob
from copy import deepcopy
import torch
from modules import shared, paths, paths_internal, devices, script_callbacks, sd_models
from modules import shared, paths, devices, script_callbacks, sd_models
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
@ -200,7 +200,7 @@ def load_vae_diffusers(model_file, vae_file=None, vae_source="unknown-source"):
import diffusers
if os.path.isfile(vae_file):
_pipeline, model_type = sd_models.detect_pipeline(model_file, 'vae')
diffusers_load_config = { "config_file": paths_internal.sd_default_config if model_type != 'Stable Diffusion XL' else os.path.join(paths_internal.sd_configs_path, 'sd_xl_base.yaml')}
diffusers_load_config = { "config_file": paths.sd_default_config if model_type != 'Stable Diffusion XL' else os.path.join(paths.sd_configs_path, 'sd_xl_base.yaml')}
vae = diffusers.AutoencoderKL.from_single_file(vae_file, **diffusers_load_config)
vae = vae.to(devices.dtype_vae)
else:

View File

@ -12,13 +12,13 @@ import gradio as gr
import fasteners
from rich.console import Console
from modules import errors, shared_items, shared_state, cmd_args, ui_components, theme
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611
from modules.paths import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611
from modules.dml import memory_providers, default_memory_provider, directml_do_hijack
import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices # pylint: disable=R0402
import modules.paths_internal as paths
import modules.paths as paths
from installer import print_dict
from installer import log as central_logger # pylint: disable=E0611
@ -337,7 +337,7 @@ options_templates.update(options_section(('diffusers', "Diffusers Settings"), {
"diffusers_attention_slicing": OptionInfo(False, "Enable attention slicing"),
"diffusers_model_load_variant": OptionInfo("default", "Diffusers model loading variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}),
"diffusers_vae_load_variant": OptionInfo("default", "Diffusers VAE loading variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}),
"custom_diffusers_pipeline": OptionInfo('hf-internal-testing/diffusers-dummy-pipeline', 'Custom Diffusers pipeline to use'),
"custom_diffusers_pipeline": OptionInfo('', 'Load custom Diffusers pipeline'),
"diffusers_lora_loader": OptionInfo("diffusers" if cmd_opts.use_openvino else "sequential apply", "Diffusers LoRA loading variant", gr.Radio, {"choices": ['diffusers', 'sequential apply', 'merge and apply']}),
"diffusers_force_zeros": OptionInfo(True, "Force zeros for prompts when empty"),
"diffusers_aesthetics_score": OptionInfo(False, "Require aesthetics score"),
@ -346,8 +346,8 @@ options_templates.update(options_section(('diffusers', "Diffusers Settings"), {
}))
options_templates.update(options_section(('system-paths', "System Paths"), {
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default", folder=True),
"clean_temp_dir_at_start": OptionInfo(True, "Cleanup non-default temporary directory when starting webui"),
"models_paths_sep_options": OptionInfo("<h2>Models paths</h2>", "", gr.HTML),
"models_dir": OptionInfo('models', "Base path where all models are stored", folder=True),
"ckpt_dir": OptionInfo(os.path.join(paths.models_path, 'Stable-diffusion'), "Folder with stable diffusion models", folder=True),
"diffusers_dir": OptionInfo(os.path.join(paths.models_path, 'Diffusers'), "Folder with Hugggingface models", folder=True),
"vae_dir": OptionInfo(os.path.join(paths.models_path, 'VAE'), "Folder with VAE files", folder=True),
@ -366,6 +366,10 @@ options_templates.update(options_section(('system-paths', "System Paths"), {
"swinir_models_path": OptionInfo(os.path.join(paths.models_path, 'SwinIR'), "Folder with SwinIR models", folder=True),
"ldsr_models_path": OptionInfo(os.path.join(paths.models_path, 'LDSR'), "Folder with LDSR models", folder=True),
"clip_models_path": OptionInfo(os.path.join(paths.models_path, 'CLIP'), "Folder with CLIP models", folder=True),
"other_paths_sep_options": OptionInfo("<h2>Other paths</h2>", "", gr.HTML),
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default", folder=True),
"clean_temp_dir_at_start": OptionInfo(True, "Cleanup non-default temporary directory when starting webui"),
}))
options_templates.update(options_section(('saving-images', "Image Options"), {
@ -474,7 +478,7 @@ options_templates.update(options_section(('sampler-params', "Sampler Settings"),
"schedulers_use_karras": OptionInfo(True, "Use Karras sigmas", gr.Checkbox, {"visible": False}),
"schedulers_use_thresholding": OptionInfo(False, "Use dynamic thresholding", gr.Checkbox, {"visible": False}),
"schedulers_use_loworder": OptionInfo(True, "Use simplified solvers in final steps", gr.Checkbox, {"visible": False}),
"schedulers_prediction_type": OptionInfo("default", "Override model prediction type", gr.Radio, {"choices": ['default', 'epsilon', 'sample', 'v_prediction'], "visible": False}),
"schedulers_prediction_type": OptionInfo("default", "Override model prediction type", gr.Radio, {"choices": ['default', 'epsilon', 'sample', 'v_prediction']}),
# managed from ui.py for backend diffusers
"schedulers_sep_diffusers": OptionInfo("<h2>Diffusers specific config</h2>", "", gr.HTML),

View File

@ -13,7 +13,6 @@ class State:
job_no = 0
job_count = 0
total_jobs = 0
processing_has_refined_job_count = False
job_timestamp = '0'
sampling_step = 0
sampling_steps = 0
@ -72,7 +71,6 @@ class State:
self.job_no = 0
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.paused = False
self.processing_has_refined_job_count = False
self.sampling_step = 0
self.skipped = False
self.textinfo = None

View File

@ -9,7 +9,7 @@ from modules import paths
class Style():
def __init__(self, name: str, desc: str = "", prompt: str = "", negative_prompt: str = "", extra: str = "", filename: str = "", preview: str = ""):
def __init__(self, name: str, desc: str = "", prompt: str = "", negative_prompt: str = "", extra: str = "", filename: str = "", preview: str = "", mtime: float = 0):
self.name = name
self.description = desc
self.prompt = prompt
@ -17,6 +17,7 @@ class Style():
self.extra = extra
self.filename = filename
self.preview = preview
self.mtime = mtime
def merge_prompts(style_prompt: str, prompt: str) -> str:
if "{prompt}" in style_prompt:
@ -105,7 +106,8 @@ class StyleDatabase:
negative_prompt=style.get("negative", ""),
extra=style.get("extra", ""),
preview=style.get("preview", None),
filename=fn
filename=fn,
mtime=os.path.getmtime(fn),
)
except Exception as e:
log.error(f'Failed to load style: file={fn} error={e}')

View File

@ -6,7 +6,7 @@ https://github.com/madebyollin/taesd
"""
import os
from PIL import Image
from modules import devices, paths_internal
from modules import devices, paths
from modules.taesd.taesd import TAESD
taesd_models = { 'sd-decoder': None, 'sd-encoder': None, 'sdxl-decoder': None, 'sdxl-encoder': None }
@ -25,7 +25,7 @@ def download_model(model_path):
def model(model_class = 'sd', model_type = 'decoder'):
vae = taesd_models[f'{model_class}-{model_type}']
if vae is None:
model_path = os.path.join(paths_internal.models_path, "TAESD", f"tae{model_class}_{model_type}.pth")
model_path = os.path.join(paths.models_path, "TAESD", f"tae{model_class}_{model_type}.pth")
download_model(model_path)
if os.path.exists(model_path):
from modules.shared import log
@ -52,7 +52,7 @@ def decode(latents):
return Image.new('RGB', (8, 8), color = (0, 0, 0))
vae = taesd_models[f'{model_class}-decoder']
if vae is None:
model_path = os.path.join(paths_internal.models_path, "TAESD", f"tae{model_class}_decoder.pth")
model_path = os.path.join(paths.models_path, "TAESD", f"tae{model_class}_decoder.pth")
download_model(model_path)
if os.path.exists(model_path):
taesd_models[f'{model_class}-decoder'] = TAESD(decoder_path=model_path, encoder_path=None)
@ -73,7 +73,7 @@ def encode(image):
return Image.new('RGB', (8, 8), color = (0, 0, 0))
vae = taesd_models[f'{model_class}-encoder']
if vae is None:
model_path = os.path.join(paths_internal.models_path, "TAESD", f"tae{model_class}_encoder.pth")
model_path = os.path.join(paths.models_path, "TAESD", f"tae{model_class}_encoder.pth")
download_model(model_path)
if os.path.exists(model_path):
taesd_models[f'{model_class}-encoder'] = TAESD(encoder_path=model_path, decoder_path=None)

View File

@ -133,7 +133,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
image = srcimage.copy()
fontsize = 32
if textfont is None:
textfont = opts.font or 'html/roboto.ttf'
textfont = opts.font or 'javascript/roboto.ttf'
factor = 1.5
gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))

View File

@ -425,7 +425,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
log_directory = f"{os.path.join(shared.cmd_opts.data_dir, 'train/log/embeddings')}"
template_file = template_file.path
shared.state.job = "train-embedding"
shared.state.job = "train"
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps

View File

@ -644,7 +644,7 @@ def create_ui(startup_timer = None):
steps, sampler_index = create_sampler_and_steps_selection(modules.sd_samplers.samplers_for_img2img, "img2img")
with gr.Accordion(open=False, label="Resize", elem_classes=["small-accordion"], elem_id="img2img_resize_group"):
with FormRow():
with gr.Row():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["None", "Resize fixed", "Crop and resize", "Resize and fill", "Latent upscale"], type="index", value="None")
with FormRow():

View File

@ -118,14 +118,14 @@ class ExtraNetworksPage:
self.list_time = 0
# class additional is to keep old extensions happy
self.card = '''
<div class='card' onclick={card_click} title='{name}' data-tab='{tabname}' data-page='{page}' data-name='{name}' data-filename='{filename}' data-tags='{tags}'>
<div class='card' onclick={card_click} title='{name}' data-tab='{tabname}' data-page='{page}' data-name='{name}' data-filename='{filename}' data-tags='{tags}' data-mtime='{mtime}' data-size='{size}'>
<div class='overlay'>
<span style="display:none" class='search_term'>{search_term}</span>
<div class='tags'></div>
<div class='name'>{title}</div>
</div>
<div class='actions'>
<span title="Get details" onclick="showCardDetails(event)">&#x1f6c8;</span>
<span class='details' title="Get details" onclick="showCardDetails(event)">&#x1f6c8;</span>
<div class='additional'><ul></ul></div>
</div>
<img class='preview' src='{preview}' style='width: {width}px; height: {height}px; object-fit: {fit}' loading='lazy'></img>
@ -282,6 +282,8 @@ class ExtraNetworksPage:
"search_term": item.get("search_term", ""),
"description": item.get("description") or "",
"card_click": item.get("onclick", '"' + html.escape(f'return cardClicked({item.get("prompt", None)}, {"true" if self.allow_negative_prompt else "false"})') + '"'),
"mtime": item.get("mtime", 0),
"size": item.get("size", 0),
}
alias = item.get("alias", None)
if alias is not None:
@ -392,6 +394,7 @@ class ExtraNetworksUi:
self.button_refresh: gr.Button = None
self.button_scan: gr.Button = None
self.button_save: gr.Button = None
self.button_sort: gr.Button = None
self.button_apply: gr.Button = None
self.button_close: gr.Button = None
self.button_model: gr.Checkbox = None
@ -485,7 +488,8 @@ def create_ui(container, button_parent, tabname, skip_indexing = False):
ui.button_refresh = ToolButton(symbols.refresh, elem_id=tabname+"_extra_refresh")
ui.button_scan = ToolButton(symbols.scan, elem_id=tabname+"_extra_scan", visible=True)
ui.button_save = ToolButton(symbols.book, elem_id=tabname+"_extra_save", visible=False)
ui.button_close = ToolButton(symbols.close, elem_id=tabname+"_extra_close")
ui.button_sort = ToolButton(symbols.sort, elem_id=tabname+"_extra_sort", visible=True)
ui.button_close = ToolButton(symbols.close, elem_id=tabname+"_extra_close", visible=True)
ui.button_model = ToolButton(symbols.refine, elem_id=tabname+"_extra_model", visible=True)
ui.search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", elem_classes="textbox", lines=2, container=False)
ui.description = gr.Textbox('', show_label=False, elem_id=tabname+"_description", elem_classes="textbox", lines=2, interactive=False, container=False)
@ -700,9 +704,14 @@ def create_ui(container, button_parent, tabname, skip_indexing = False):
res = show_details(text=None, img=None, desc=None, info=None, meta=None, params=params)
return res
def ui_sort_cards(msg):
shared.log.debug(f'Extra networks: {msg}')
return msg
dummy_state = gr.State(value=False) # pylint: disable=abstract-class-instantiated
button_parent.click(fn=toggle_visibility, inputs=[ui.visible], outputs=[ui.visible, container, button_parent])
ui.button_close.click(fn=toggle_visibility, inputs=[ui.visible], outputs=[ui.visible, container])
ui.button_sort.click(fn=ui_sort_cards, _js='sortExtraNetworks', inputs=[ui.search], outputs=[ui.description])
ui.button_refresh.click(fn=ui_refresh_click, _js='getENActivePage', inputs=[ui.search], outputs=ui.pages)
ui.button_scan.click(fn=ui_scan_click, _js='getENActivePage', inputs=[ui.search], outputs=ui.pages)
ui.button_save.click(fn=ui_save_click, inputs=[], outputs=ui.details_components + [ui.details])

View File

@ -30,6 +30,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
"info": self.find_info(fn),
"metadata": checkpoint.metadata,
"onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
"mtime": os.path.getmtime(checkpoint.filename),
"size": os.path.getsize(checkpoint.filename),
}
yield record
except Exception as e:

View File

@ -25,6 +25,8 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
"search_term": self.search_terms_from_path(name),
"prompt": json.dumps(f"<hypernet:{name}:{shared.opts.extra_networks_default_multiplier}>"),
"local_preview": f"{fn}.{shared.opts.samples_format}",
"mtime": os.path.getmtime(path),
"size": os.path.getsize(path),
}
except Exception as e:
shared.log.debug(f"Extra networks error: type=hypernetwork file={path} {e}")

View File

@ -85,6 +85,8 @@ class ExtraNetworksPageStyles(ui_extra_networks.ExtraNetworksPage):
"extra": getattr(style, 'extra', ''),
"local_preview": f"{fn}.{shared.opts.samples_format}",
"onclick": '"' + html.escape(f"""return selectStyle({json.dumps(name)})""") + '"',
"mtime": getattr(style, 'mtime', 0),
"size": os.path.getsize(style.filename),
}
except Exception as e:
shared.log.debug(f"Extra networks error: type=style file={k} {e}")

View File

@ -56,6 +56,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
"prompt": json.dumps(os.path.splitext(embedding.name)[0]),
"local_preview": f"{path}.{shared.opts.samples_format}",
"tags": tags,
"mtime": os.path.getmtime(embedding.filename),
"size": os.path.getsize(embedding.filename),
}
except Exception as e:
shared.log.debug(f"Extra networks error: type=embedding file={embedding.filename} {e}")

View File

@ -28,6 +28,8 @@ class ExtraNetworksPageVAEs(ui_extra_networks.ExtraNetworksPage):
"info": self.find_info(fn),
"metadata": {},
"onclick": '"' + html.escape(f"""return selectVAE({json.dumps(name)})""") + '"',
"mtime": os.path.getmtime(filename),
"size": os.path.getsize(filename),
}
yield record
except Exception as e:

View File

@ -1,8 +1,7 @@
# TODO: a1111 compatibility item, not used
import gradio as gr
from modules import shared, ui_common, ui_components, styles
from modules import shared, styles
styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
styles_materialize_symbol = '\U0001f4cb' # 📋
@ -34,7 +33,7 @@ def delete_style(name):
return '', '', ''
def materialize_styles(prompt, negative_prompt, styles):
def materialize_styles(prompt, negative_prompt, styles): # pylint: disable=redefined-outer-name
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
@ -45,7 +44,7 @@ def refresh_styles():
class UiPromptStyles:
def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): # pylint: disable=unused-argument
self.dropdown = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles", choices=[style.name for style in shared.prompt_styles.styles.values()], value=[], multiselect=True)
"""

View File

@ -11,6 +11,7 @@ networks = '🌐'
paste = ''
refine = ''
switch = ''
sort = ''
detect = '📐'
folder = '📂'
random = '🎲️'

View File

@ -34,7 +34,11 @@ exclude = [
"extensions-builtin",
"modules/lora",
"modules/dml",
"modules/models/diffusion",
"modules/k-diffusion",
"repositories/ldm",
"repositories/taming",
"repositories/blip",
"repositories/codeformer",
]
ignore = [
"A003", # Class attirbute shadowing builtin

View File

@ -0,0 +1,2 @@
# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
#ECCN:Open Source

View File

@ -0,0 +1,105 @@
# Salesforce Open Source Community Code of Conduct
## About the Code of Conduct
Equality is a core value at Salesforce. We believe a diverse and inclusive
community fosters innovation and creativity, and are committed to building a
culture where everyone feels included.
Salesforce open-source projects are committed to providing a friendly, safe, and
welcoming environment for all, regardless of gender identity and expression,
sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
race, age, religion, level of experience, education, socioeconomic status, or
other similar personal characteristics.
The goal of this code of conduct is to specify a baseline standard of behavior so
that people with different social values and communication styles can work
together effectively, productively, and respectfully in our open source community.
It also establishes a mechanism for reporting issues and resolving conflicts.
All questions and reports of abusive, harassing, or otherwise unacceptable behavior
in a Salesforce open-source project may be reported by contacting the Salesforce
Open Source Conduct Committee at ossconduct@salesforce.com.
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of gender
identity and expression, sexual orientation, disability, physical appearance,
body size, ethnicity, nationality, race, age, religion, level of experience, education,
socioeconomic status, or other similar personal characteristics.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy toward other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Personal attacks, insulting/derogatory comments, or trolling
* Public or private harassment
* Publishing, or threatening to publish, others' private information—such as
a physical or electronic address—without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
* Advocating for or encouraging any of the above behaviors
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned with this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project email
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the Salesforce Open Source Conduct Committee
at ossconduct@salesforce.com. All complaints will be reviewed and investigated
and will result in a response that is deemed necessary and appropriate to the
circumstances. The committee is obligated to maintain confidentiality with
regard to the reporter of an incident. Further details of specific enforcement
policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership and the Salesforce Open Source Conduct
Committee.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
It includes adaptions and additions from [Go Community Code of Conduct][golang-coc],
[CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
[contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
[golang-coc]: https://golang.org/conduct
[cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
[microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
[cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/

View File

@ -0,0 +1,12 @@
Copyright (c) 2022, Salesforce.com, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

116
repositories/blip/README.md Normal file
View File

@ -0,0 +1,116 @@
## BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
## Announcement: BLIP is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS) - a one-stop library for language-and-vision research and applications!
<img src="BLIP.gif" width="700">
This is the PyTorch code of the <a href="https://arxiv.org/abs/2201.12086">BLIP paper</a> [[blog](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)]. The code has been tested on PyTorch 1.10.
To install the dependencies, run <pre/>pip install -r requirements.txt</pre>
Catalog:
- [x] Inference demo
- [x] Pre-trained and finetuned checkpoints
- [x] Finetuning code for Image-Text Retrieval, Image Captioning, VQA, and NLVR2
- [x] Pre-training code
- [x] Zero-shot video-text retrieval
- [x] Download of bootstrapped pre-training datasets
### Inference demo:
Run our interactive demo using [Colab notebook](https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb) (no GPU needed).
The demo includes code for:
1. Image captioning
2. Open-ended visual question answering
3. Multimodal / unimodal feature extraction
4. Image-text matching
Try out the [Web demo](https://huggingface.co/spaces/Salesforce/BLIP), integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio).
Replicate web demo and Docker image is also available at [![Replicate](https://replicate.com/salesforce/blip/badge)](https://replicate.com/salesforce/blip)
### Pre-trained checkpoints:
Num. pre-train images | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
--- | :---: | :---: | :---:
14M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth">Download</a>| - | -
129M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth">Download</a> | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth">Download</a>
### Finetuned checkpoints:
Task | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
--- | :---: | :---: | :---:
Image-Text Retrieval (COCO) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth">Download</a>
Image-Text Retrieval (Flickr30k) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_flickr.pth">Download</a>
Image Captioning (COCO) | - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth">Download</a> |
VQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth">Download</a> | -
NLVR2 | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth">Download</a>| - | -
### Image-Text Retrieval:
1. Download COCO and Flickr30k datasets from the original websites, and set 'image_root' in configs/retrieval_{dataset}.yaml accordingly.
2. To evaluate the finetuned BLIP model on COCO, run:
<pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
--config ./configs/retrieval_coco.yaml \
--output_dir output/retrieval_coco \
--evaluate</pre>
3. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/retrieval_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
<pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
--config ./configs/retrieval_coco.yaml \
--output_dir output/retrieval_coco </pre>
### Image-Text Captioning:
1. Download COCO and NoCaps datasets from the original websites, and set 'image_root' in configs/caption_coco.yaml and configs/nocaps.yaml accordingly.
2. To evaluate the finetuned BLIP model on COCO, run:
<pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py --evaluate</pre>
3. To evaluate the finetuned BLIP model on NoCaps, generate results with: (evaluation needs to be performed on official server)
<pre>python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py </pre>
4. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/caption_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
<pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py </pre>
### VQA:
1. Download VQA v2 dataset and Visual Genome dataset from the original websites, and set 'vqa_root' and 'vg_root' in configs/vqa.yaml.
2. To evaluate the finetuned BLIP model, generate results with: (evaluation needs to be performed on official server)
<pre>python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --evaluate</pre>
3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/vqa.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
<pre>python -m torch.distributed.run --nproc_per_node=16 train_vqa.py </pre>
### NLVR2:
1. Download NLVR2 dataset from the original websites, and set 'image_root' in configs/nlvr.yaml.
2. To evaluate the finetuned BLIP model, run
<pre>python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --evaluate</pre>
3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/nlvr.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
<pre>python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py </pre>
### Finetune with ViT-L:
In order to finetune a model with ViT-L, simply change the config file to set 'vit' as large. Batch size and learning rate may also need to be adjusted accordingly (please see the paper's appendix for hyper-parameter details). <a href="https://github.com/facebookresearch/fairscale">Gradient checkpoint</a> can also be activated in the config file to reduce GPU memory usage.
### Pre-train:
1. Prepare training json files where each json file contains a list. Each item in the list is a dictonary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}.
2. In configs/pretrain.yaml, set 'train_file' as the paths for the json files .
3. Pre-train the model using 8 A100 GPUs:
<pre>python -m torch.distributed.run --nproc_per_node=8 pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain </pre>
### Zero-shot video-text retrieval:
1. Download MSRVTT dataset following the instructions from https://github.com/salesforce/ALPRO, and set 'video_root' accordingly in configs/retrieval_msrvtt.yaml.
2. Install [decord](https://github.com/dmlc/decord) with <pre>pip install decord</pre>
3. To perform zero-shot evaluation, run
<pre>python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py</pre>
### Pre-training datasets download:
We provide bootstrapped pre-training datasets as json files. Each json file contains a list. Each item in the list is a dictonary with two key-value pairs: {'url': url_of_image, 'caption': text_of_image}.
Image source | Filtered web caption | Filtered synthetic caption by ViT-B | Filtered synthetic caption by ViT-L
--- | :---: | :---: | :---:
CC3M+CC12M+SBU | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered_large.json">Download</a>
LAION115M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered.json">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered_large.json">Download</a>
### Citation
If you find this code to be useful for your research, please consider citing.
<pre>
@inproceedings{li2022blip,
title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation},
author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
year={2022},
booktitle={ICML},
}</pre>
### Acknowledgement
The implementation of BLIP relies on resources from <a href="https://github.com/salesforce/ALBEF">ALBEF</a>, <a href="https://github.com/huggingface/transformers">Huggingface Transformers</a>, and <a href="https://github.com/rwightman/pytorch-image-models/tree/master/timm">timm</a>. We thank the original authors for their open-sourcing.

View File

@ -0,0 +1,7 @@
## Security
Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
as soon as it is discovered. This library limits its runtime dependencies in
order to reduce the total cost of ownership as much as can be, but all consumers
should remain vigilant and have their security stakeholders review all third-party
products (3PP) like this one and their dependencies.

View File

@ -0,0 +1,17 @@
build:
gpu: true
cuda: "11.1"
python_version: "3.8"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_packages:
- "ipython==7.30.1"
- "torchvision==0.11.1"
- "torch==1.10.0"
- "timm==0.4.12"
- "transformers==4.15.0"
- "fairscale==0.4.4"
- "pycocoevalcap==1.2"
predict: "predict.py:Predictor"

View File

@ -0,0 +1,21 @@
{
"architectures": [
"BertModel"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30522,
"encoder_width": 768,
"add_cross_attention": true
}

View File

@ -0,0 +1,33 @@
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
coco_gt_root: 'annotation/coco_gt'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0
batch_size: 32
init_lr: 1e-5
# vit: 'large'
# vit_grad_ckpt: True
# vit_ckpt_layer: 5
# batch_size: 16
# init_lr: 2e-6
image_size: 384
# generation configs
max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '
# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 5

View File

@ -0,0 +1,21 @@
{
"architectures": [
"BertModel"
],
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30524,
"encoder_width": 768,
"add_cross_attention": true
}

View File

@ -0,0 +1,21 @@
image_root: '/export/share/datasets/vision/NLVR2/'
ann_root: 'annotation'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
#size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 64
vit_grad_ckpt: False
vit_ckpt_layer: 0
max_epoch: 15
image_size: 384
# optimizer
weight_decay: 0.05
init_lr: 3e-5
min_lr: 0

View File

@ -0,0 +1,15 @@
image_root: '/export/share/datasets/vision/nocaps/'
ann_root: 'annotation'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
vit: 'base'
batch_size: 32
image_size: 384
max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '

View File

@ -0,0 +1,27 @@
train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
'/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
]
laion_path: ''
# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0
image_size: 224
batch_size: 75
queue_size: 57600
alpha: 0.4
# optimizer
weight_decay: 0.05
init_lr: 3e-4
min_lr: 1e-6
warmup_lr: 1e-6
lr_decay_rate: 0.9
max_epoch: 20
warmup_steps: 3000

View File

@ -0,0 +1,34 @@
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
dataset: 'coco'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
# size of vit model; base or large
vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5
# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 12
# init_lr: 5e-6
image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 256
negative_all_rank: True
# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6

View File

@ -0,0 +1,34 @@
image_root: '/export/share/datasets/vision/flickr30k/'
ann_root: 'annotation'
dataset: 'flickr'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
# size of vit model; base or large
vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5
# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 10
# init_lr: 5e-6
image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 128
negative_all_rank: False
# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6

View File

@ -0,0 +1,12 @@
video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
ann_root: 'annotation'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
# size of vit model; base or large
vit: 'base'
batch_size: 64
k_test: 128
image_size: 384
num_frm_test: 8

View File

@ -0,0 +1,25 @@
vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
train_files: ['vqa_train','vqa_val','vg_qa']
ann_root: 'annotation'
# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
# size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 32
vit_grad_ckpt: False
vit_ckpt_layer: 0
init_lr: 2e-5
image_size: 480
k_test: 128
inference: 'rank'
# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 10

View File

@ -0,0 +1,101 @@
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
from data.nocaps_dataset import nocaps_eval
from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
from data.vqa_dataset import vqa_dataset
from data.nlvr_dataset import nlvr_dataset
from data.pretrain_dataset import pretrain_dataset
from transform.randaugment import RandomAugment
def create_dataset(dataset, config, min_scale=0.5):
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
transform_train = transforms.Compose([
transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
transforms.ToTensor(),
normalize,
])
transform_test = transforms.Compose([
transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
normalize,
])
if dataset=='pretrain':
dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
return dataset
elif dataset=='caption_coco':
train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return train_dataset, val_dataset, test_dataset
elif dataset=='nocaps':
val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return val_dataset, test_dataset
elif dataset=='retrieval_coco':
train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return train_dataset, val_dataset, test_dataset
elif dataset=='retrieval_flickr':
train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
return train_dataset, val_dataset, test_dataset
elif dataset=='vqa':
train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
train_files = config['train_files'], split='train')
test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
return train_dataset, test_dataset
elif dataset=='nlvr':
train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
return train_dataset, val_dataset, test_dataset
def create_sampler(datasets, shuffles, num_tasks, global_rank):
samplers = []
for dataset,shuffle in zip(datasets,shuffles):
sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
samplers.append(sampler)
return samplers
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
loaders = []
for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
if is_train:
shuffle = (sampler is None)
drop_last = True
else:
shuffle = False
drop_last = False
loader = DataLoader(
dataset,
batch_size=bs,
num_workers=n_worker,
pin_memory=True,
sampler=sampler,
shuffle=shuffle,
collate_fn=collate_fn,
drop_last=drop_last,
)
loaders.append(loader)
return loaders

View File

@ -0,0 +1,126 @@
import os
import json
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
from data.utils import pre_caption
class coco_karpathy_train(Dataset):
def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
'''
image_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
'''
url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
filename = 'coco_karpathy_train.json'
download_url(url,ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
self.transform = transform
self.image_root = image_root
self.max_words = max_words
self.prompt = prompt
self.img_ids = {}
n = 0
for ann in self.annotation:
img_id = ann['image_id']
if img_id not in self.img_ids.keys():
self.img_ids[img_id] = n
n += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
caption = self.prompt+pre_caption(ann['caption'], self.max_words)
return image, caption, self.img_ids[ann['image_id']]
class coco_karpathy_caption_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split):
'''
image_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
split (string): val or test
'''
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]
return image, int(img_id)
class coco_karpathy_retrieval_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split, max_words=30):
'''
image_root (string): Root directory of images (e.g. coco/images/)
ann_root (string): directory to store the annotation file
split (string): val or test
'''
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
self.text = []
self.image = []
self.txt2img = {}
self.img2txt = {}
txt_id = 0
for img_id, ann in enumerate(self.annotation):
self.image.append(ann['image'])
self.img2txt[img_id] = []
for i, caption in enumerate(ann['caption']):
self.text.append(pre_caption(caption,max_words))
self.img2txt[img_id].append(txt_id)
self.txt2img[txt_id] = img_id
txt_id += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
image_path = os.path.join(self.image_root, self.annotation[index]['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
return image, index

View File

@ -0,0 +1,93 @@
import os
import json
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
from data.utils import pre_caption
class flickr30k_train(Dataset):
def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
'''
image_root (string): Root directory of images (e.g. flickr30k/)
ann_root (string): directory to store the annotation file
'''
url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
filename = 'flickr30k_train.json'
download_url(url,ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
self.transform = transform
self.image_root = image_root
self.max_words = max_words
self.prompt = prompt
self.img_ids = {}
n = 0
for ann in self.annotation:
img_id = ann['image_id']
if img_id not in self.img_ids.keys():
self.img_ids[img_id] = n
n += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
caption = self.prompt+pre_caption(ann['caption'], self.max_words)
return image, caption, self.img_ids[ann['image_id']]
class flickr30k_retrieval_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split, max_words=30):
'''
image_root (string): Root directory of images (e.g. flickr30k/)
ann_root (string): directory to store the annotation file
split (string): val or test
'''
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
self.text = []
self.image = []
self.txt2img = {}
self.img2txt = {}
txt_id = 0
for img_id, ann in enumerate(self.annotation):
self.image.append(ann['image'])
self.img2txt[img_id] = []
for i, caption in enumerate(ann['caption']):
self.text.append(pre_caption(caption,max_words))
self.img2txt[img_id].append(txt_id)
self.txt2img[txt_id] = img_id
txt_id += 1
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
image_path = os.path.join(self.image_root, self.annotation[index]['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
return image, index

View File

@ -0,0 +1,78 @@
import os
import json
import random
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
from data.utils import pre_caption
class nlvr_dataset(Dataset):
def __init__(self, transform, image_root, ann_root, split):
'''
image_root (string): Root directory of images
ann_root (string): directory to store the annotation file
split (string): train, val or test
'''
urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image0_path = os.path.join(self.image_root,ann['images'][0])
image0 = Image.open(image0_path).convert('RGB')
image0 = self.transform(image0)
image1_path = os.path.join(self.image_root,ann['images'][1])
image1 = Image.open(image1_path).convert('RGB')
image1 = self.transform(image1)
sentence = pre_caption(ann['sentence'], 40)
if ann['label']=='True':
label = 1
else:
label = 0
words = sentence.split(' ')
if 'left' not in words and 'right' not in words:
if random.random()<0.5:
return image0, image1, sentence, label
else:
return image1, image0, sentence, label
else:
if random.random()<0.5:
return image0, image1, sentence, label
else:
new_words = []
for word in words:
if word=='left':
new_words.append('right')
elif word=='right':
new_words.append('left')
else:
new_words.append(word)
sentence = ' '.join(new_words)
return image1, image0, sentence, label

View File

@ -0,0 +1,32 @@
import os
import json
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
class nocaps_eval(Dataset):
def __init__(self, transform, image_root, ann_root, split):
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'}
download_url(urls[split],ann_root)
self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
self.transform = transform
self.image_root = image_root
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image_path = os.path.join(self.image_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
return image, int(ann['img_id'])

View File

@ -0,0 +1,59 @@
import json
import os
import random
from torch.utils.data import Dataset
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
from data.utils import pre_caption
import os,glob
class pretrain_dataset(Dataset):
def __init__(self, ann_file, laion_path, transform):
self.ann_pretrain = []
for f in ann_file:
print('loading '+f)
ann = json.load(open(f,'r'))
self.ann_pretrain += ann
self.laion_path = laion_path
if self.laion_path:
self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))
print('loading '+self.laion_files[0])
with open(self.laion_files[0],'r') as f:
self.ann_laion = json.load(f)
self.annotation = self.ann_pretrain + self.ann_laion
else:
self.annotation = self.ann_pretrain
self.transform = transform
def reload_laion(self, epoch):
n = epoch%len(self.laion_files)
print('loading '+self.laion_files[n])
with open(self.laion_files[n],'r') as f:
self.ann_laion = json.load(f)
self.annotation = self.ann_pretrain + self.ann_laion
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
image = Image.open(ann['image']).convert('RGB')
image = self.transform(image)
caption = pre_caption(ann['caption'],30)
return image, caption

View File

@ -0,0 +1,112 @@
import re
import json
import os
import torch
import torch.distributed as dist
import utils
def pre_caption(caption,max_words=50):
caption = re.sub(
r"([.!\"()*#:;~])",
' ',
caption.lower(),
)
caption = re.sub(
r"\s{2,}",
' ',
caption,
)
caption = caption.rstrip('\n')
caption = caption.strip(' ')
#truncate caption
caption_words = caption.split(' ')
if len(caption_words)>max_words:
caption = ' '.join(caption_words[:max_words])
return caption
def pre_question(question,max_ques_words=50):
question = re.sub(
r"([.!\"()*#:;~])",
'',
question.lower(),
)
question = question.rstrip(' ')
#truncate question
question_words = question.split(' ')
if len(question_words)>max_ques_words:
question = ' '.join(question_words[:max_ques_words])
return question
def save_result(result, result_dir, filename, remove_duplicate=''):
result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
final_result_file = os.path.join(result_dir, '%s.json'%filename)
json.dump(result,open(result_file,'w'))
dist.barrier()
if utils.is_main_process():
# combine results from all processes
result = []
for rank in range(utils.get_world_size()):
result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
res = json.load(open(result_file,'r'))
result += res
if remove_duplicate:
result_new = []
id_list = []
for res in result:
if res[remove_duplicate] not in id_list:
id_list.append(res[remove_duplicate])
result_new.append(res)
result = result_new
json.dump(result,open(final_result_file,'w'))
print('result file saved to %s'%final_result_file)
return final_result_file
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
from torchvision.datasets.utils import download_url
def coco_caption_eval(coco_gt_root, results_file, split):
urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}
download_url(urls[split],coco_gt_root)
annotation_file = os.path.join(coco_gt_root,filenames[split])
# create coco object and coco_result object
coco = COCO(annotation_file)
coco_result = coco.loadRes(results_file)
# create coco_eval object by taking coco and coco_result
coco_eval = COCOEvalCap(coco, coco_result)
# evaluate on a subset of images by setting
# coco_eval.params['image_id'] = coco_result.getImgIds()
# please remove this line when evaluating the full validation set
# coco_eval.params['image_id'] = coco_result.getImgIds()
# evaluate results
# SPICE will take a few minutes the first time, but speeds up due to caching
coco_eval.evaluate()
# print output evaluation scores
for metric, score in coco_eval.eval.items():
print(f'{metric}: {score:.3f}')
return coco_eval

View File

@ -0,0 +1,110 @@
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
import torch
import numpy as np
import random
import decord
from decord import VideoReader
import json
import os
from data.utils import pre_caption
decord.bridge.set_bridge("torch")
class ImageNorm(object):
"""Apply Normalization to Image Pixels on GPU
"""
def __init__(self, mean, std):
self.mean = torch.tensor(mean).view(1, 3, 1, 1)
self.std = torch.tensor(std).view(1, 3, 1, 1)
def __call__(self, img):
if torch.max(img) > 1 and self.mean.max() <= 1:
img.div_(255.)
return img.sub_(self.mean).div_(self.std)
def load_jsonl(filename):
with open(filename, "r") as f:
return [json.loads(l.strip("\n")) for l in f.readlines()]
class VideoDataset(Dataset):
def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'):
'''
image_root (string): Root directory of video
ann_root (string): directory to store the annotation file
'''
url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl'
filename = 'msrvtt_test.jsonl'
download_url(url,ann_root)
self.annotation = load_jsonl(os.path.join(ann_root,filename))
self.num_frm = num_frm
self.frm_sampling_strategy = frm_sampling_strategy
self.max_img_size = max_img_size
self.video_root = video_root
self.video_fmt = video_fmt
self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
self.text = [pre_caption(ann['caption'],40) for ann in self.annotation]
self.txt2video = [i for i in range(len(self.annotation))]
self.video2txt = self.txt2video
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt)
vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
video = self.img_norm(vid_frm_array.float())
return video, ann['clip_name']
def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
try:
if not height or not width:
vr = VideoReader(video_path)
else:
vr = VideoReader(video_path, width=width, height=height)
vlen = len(vr)
if start_time or end_time:
assert fps > 0, 'must provide video fps if specifying start and end time.'
start_idx = min(int(start_time * fps), vlen)
end_idx = min(int(end_time * fps), vlen)
else:
start_idx, end_idx = 0, vlen
if self.frm_sampling_strategy == 'uniform':
frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)
elif self.frm_sampling_strategy == 'rand':
frame_indices = sorted(random.sample(range(vlen), self.num_frm))
elif self.frm_sampling_strategy == 'headtail':
frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
frame_indices = frame_indices_head + frame_indices_tail
else:
raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))
raw_sample_frms = vr.get_batch(frame_indices)
except Exception as e:
return None
raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)
return raw_sample_frms

View File

@ -0,0 +1,88 @@
import os
import json
import random
from PIL import Image
import torch
from torch.utils.data import Dataset
from data.utils import pre_question
from torchvision.datasets.utils import download_url
class vqa_dataset(Dataset):
def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
self.split = split
self.transform = transform
self.vqa_root = vqa_root
self.vg_root = vg_root
if split=='train':
urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}
self.annotation = []
for f in train_files:
download_url(urls[f],ann_root)
self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r'))
else:
download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root)
self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r'))
download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root)
self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r'))
def __len__(self):
return len(self.annotation)
def __getitem__(self, index):
ann = self.annotation[index]
if ann['dataset']=='vqa':
image_path = os.path.join(self.vqa_root,ann['image'])
elif ann['dataset']=='vg':
image_path = os.path.join(self.vg_root,ann['image'])
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
if self.split == 'test':
question = pre_question(ann['question'])
question_id = ann['question_id']
return image, question, question_id
elif self.split=='train':
question = pre_question(ann['question'])
if ann['dataset']=='vqa':
answer_weight = {}
for answer in ann['answer']:
if answer in answer_weight.keys():
answer_weight[answer] += 1/len(ann['answer'])
else:
answer_weight[answer] = 1/len(ann['answer'])
answers = list(answer_weight.keys())
weights = list(answer_weight.values())
elif ann['dataset']=='vg':
answers = [ann['answer']]
weights = [0.2]
return image, question, answers, weights
def vqa_collate_fn(batch):
image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
for image, question, answer, weights in batch:
image_list.append(image)
question_list.append(question)
weight_list += weights
answer_list += answer
n.append(len(answer))
return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,118 @@
'''
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader
from models.blip import blip_decoder
import utils
from data import create_dataset, create_sampler, create_loader
from data.utils import save_result
@torch.no_grad()
def evaluate(model, data_loader, device, config):
# evaluate
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Evaluation:'
print_freq = 10
result = []
for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
image = image.to(device)
captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
min_length=config['min_length'], repetition_penalty=1.1)
for caption, img_id in zip(captions, image_id):
result.append({"image_id": img_id.item(), "caption": caption})
return result
def main(args, config):
utils.init_distributed_mode(args)
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.benchmark = True
#### Dataset ####
print("Creating captioning dataset")
val_dataset, test_dataset = create_dataset('nocaps', config)
if args.distributed:
num_tasks = utils.get_world_size()
global_rank = utils.get_rank()
samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
else:
samplers = [None,None]
val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers,
batch_size=[config['batch_size']]*2,num_workers=[4,4],
is_trains=[False, False], collate_fns=[None,None])
#### Model ####
print("Creating model")
model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
prompt=config['prompt'])
model = model.to(device)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
val_result = evaluate(model_without_ddp, val_loader, device, config)
val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
test_result = evaluate(model_without_ddp, test_loader, device, config)
test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='./configs/nocaps.yaml')
parser.add_argument('--output_dir', default='output/NoCaps')
parser.add_argument('--device', default='cuda')
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
parser.add_argument('--distributed', default=True, type=bool)
args = parser.parse_args()
config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
args.result_dir = os.path.join(args.output_dir, 'result')
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
Path(args.result_dir).mkdir(parents=True, exist_ok=True)
yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
main(args, config)

View File

@ -0,0 +1,250 @@
'''
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader
from models.blip_retrieval import blip_retrieval
import utils
from data.video_dataset import VideoDataset
@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
# test
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Evaluation:'
print('Computing features for evaluation...')
start_time = time.time()
texts = data_loader.dataset.text
num_text = len(texts)
text_bs = 256
text_ids = []
text_embeds = []
text_atts = []
for i in range(0, num_text, text_bs):
text = texts[i: min(num_text, i+text_bs)]
text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
text_embeds.append(text_embed)
text_ids.append(text_input.input_ids)
text_atts.append(text_input.attention_mask)
text_embeds = torch.cat(text_embeds,dim=0)
text_ids = torch.cat(text_ids,dim=0)
text_atts = torch.cat(text_atts,dim=0)
text_ids[:,0] = tokenizer.additional_special_tokens_ids[0]
video_feats = []
video_embeds = []
for video, video_id in data_loader:
B,N,C,W,H = video.size()
video = video.view(-1,C,W,H)
video = video.to(device,non_blocking=True)
video_feat = model.visual_encoder(video)
video_embed = model.vision_proj(video_feat[:,0,:])
video_embed = video_embed.view(B,N,-1).mean(dim=1)
video_embed = F.normalize(video_embed,dim=-1)
video_feat = video_feat.view(B,-1,video_feat.shape[-1])
video_feats.append(video_feat.cpu())
video_embeds.append(video_embed)
video_feats = torch.cat(video_feats,dim=0)
video_embeds = torch.cat(video_embeds,dim=0)
sims_matrix = video_embeds @ text_embeds.t()
score_matrix_v2t = torch.full((len(texts),len(texts)),-100.0).to(device)
num_tasks = utils.get_world_size()
rank = utils.get_rank()
step = sims_matrix.size(0)//num_tasks + 1
start = rank*step
end = min(sims_matrix.size(0),start+step)
for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
encoder_output = video_feats[start+i].repeat(config['k_test'],1,1).to(device,non_blocking=True)
encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
output = model.text_encoder(text_ids[topk_idx],
attention_mask = text_atts[topk_idx],
encoder_hidden_states = encoder_output,
encoder_attention_mask = encoder_att,
return_dict = True,
)
score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
score_matrix_v2t[start+i,topk_idx] = score + topk_sim
sims_matrix = sims_matrix.t()
score_matrix_t2v = torch.full((len(texts),len(texts)),-100.0).to(device)
step = sims_matrix.size(0)//num_tasks + 1
start = rank*step
end = min(sims_matrix.size(0),start+step)
for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
encoder_output = video_feats[topk_idx].to(device,non_blocking=True)
encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
attention_mask = text_atts[start+i].repeat(config['k_test'],1),
encoder_hidden_states = encoder_output,
encoder_attention_mask = encoder_att,
return_dict = True,
)
score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
score_matrix_t2v[start+i,topk_idx] = score + topk_sim
if args.distributed:
dist.barrier()
torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Evaluation time {}'.format(total_time_str))
return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()
@torch.no_grad()
def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt):
#Video->Text
ranks = np.zeros(scores_v2t.shape[0])
for index,score in enumerate(scores_v2t):
inds = np.argsort(score)[::-1]
ranks[index] = np.where(inds == vid2txt[index])[0][0]
# Compute metrics
tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
#Text->Video
ranks = np.zeros(scores_t2v.shape[0])
for index,score in enumerate(scores_t2v):
inds = np.argsort(score)[::-1]
ranks[index] = np.where(inds == txt2vmg[index])[0][0]
mdR = np.median(ranks+1)
# Compute metrics
vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
vr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
tr_mean = (tr1 + tr5 + tr10) / 3
vr_mean = (vr1 + vr5 + vr10) / 3
r_mean = (tr_mean + vr_mean) / 2
eval_result = {'txt_r1': tr1,
'txt_r5': tr5,
'txt_r10': tr10,
'txt_r_mean': tr_mean,
'vid_r1': vr1,
'vid_r5': vr5,
'vid_r10': vr10,
'vid_r_mean': vr_mean,
'vid_mdR': mdR,
'r_mean': r_mean}
return eval_result
def main(args, config):
utils.init_distributed_mode(args)
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.benchmark = True
#### Dataset ####
print("Creating retrieval dataset")
test_dataset = VideoDataset(config['video_root'],config['ann_root'],num_frm=config['num_frm_test'],
max_img_size=config['image_size'], frm_sampling_strategy='uniform')
test_loader = DataLoader(
test_dataset,
batch_size=config['batch_size'],
num_workers=4,
pin_memory=True,
drop_last=False,
shuffle=False,
)
#### Model ####
print("Creating model")
model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'])
model = model.to(device)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
score_v2t, score_t2v, = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config)
if utils.is_main_process():
test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt)
print(test_result)
log_stats = {**{f'{k}': v for k, v in test_result.items()},}
with open(os.path.join(args.output_dir, "test_result.txt"),"a") as f:
f.write(json.dumps(log_stats) + "\n")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml')
parser.add_argument('--output_dir', default='output/Retrieval_msrvtt')
parser.add_argument('--device', default='cuda')
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
parser.add_argument('--distributed', default=True, type=bool)
args = parser.parse_args()
config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
main(args, config)

View File

@ -0,0 +1,238 @@
'''
* Copyright (c) 2022, salesforce.com, inc.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
* By Junnan Li
'''
import warnings
warnings.filterwarnings("ignore")
from models.vit import VisionTransformer, interpolate_pos_embed
from models.med import BertConfig, BertModel, BertLMHeadModel
from transformers import BertTokenizer
import torch
from torch import nn
import torch.nn.functional as F
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
class BLIP_Base(nn.Module):
def __init__(self,
med_config = 'configs/med_config.json',
image_size = 224,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
def forward(self, image, caption, mode):
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
if mode=='image':
# return image features
image_embeds = self.visual_encoder(image)
return image_embeds
elif mode=='text':
# return text features
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text')
return text_output.last_hidden_state
elif mode=='multimodal':
# return multimodel features
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
text.input_ids[:,0] = self.tokenizer.enc_token_id
output = self.text_encoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
)
return output.last_hidden_state
class BLIP_Decoder(nn.Module):
def __init__(self,
med_config = 'configs/med_config.json',
image_size = 384,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
prompt = 'a picture of ',
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_decoder = BertLMHeadModel(config=med_config)
self.prompt = prompt
self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
def forward(self, image, caption):
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
text.input_ids[:,0] = self.tokenizer.bos_token_id
decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
decoder_targets[:,:self.prompt_length] = -100
decoder_output = self.text_decoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
labels = decoder_targets,
return_dict = True,
)
loss_lm = decoder_output.loss
return loss_lm
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
image_embeds = self.visual_encoder(image)
if not sample:
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
prompt = [self.prompt] * image.size(0)
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
input_ids[:,0] = self.tokenizer.bos_token_id
input_ids = input_ids[:, :-1]
if sample:
#nucleus sampling
outputs = self.text_decoder.generate(input_ids=input_ids,
max_length=max_length,
min_length=min_length,
do_sample=True,
top_p=top_p,
num_return_sequences=1,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=1.1,
**model_kwargs)
else:
#beam search
outputs = self.text_decoder.generate(input_ids=input_ids,
max_length=max_length,
min_length=min_length,
num_beams=num_beams,
eos_token_id=self.tokenizer.sep_token_id,
pad_token_id=self.tokenizer.pad_token_id,
repetition_penalty=repetition_penalty,
**model_kwargs)
captions = []
for output in outputs:
caption = self.tokenizer.decode(output, skip_special_tokens=True)
captions.append(caption[len(self.prompt):])
return captions
def blip_decoder(pretrained='',**kwargs):
model = BLIP_Decoder(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
assert(len(msg.missing_keys)==0)
return model
def blip_feature_extractor(pretrained='',**kwargs):
model = BLIP_Base(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
assert(len(msg.missing_keys)==0)
return model
def init_tokenizer():
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
return tokenizer
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
assert vit in ['base', 'large'], "vit parameter must be base or large"
if vit=='base':
vision_width = 768
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0 or drop_path_rate
)
elif vit=='large':
vision_width = 1024
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
drop_path_rate=0.1 or drop_path_rate
)
return visual_encoder, vision_width
def is_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")
def load_checkpoint(model,url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
checkpoint = torch.load(cached_file, map_location='cpu')
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location='cpu')
else:
raise RuntimeError('checkpoint url or path is invalid')
state_dict = checkpoint['model']
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
model.visual_encoder_m)
for key in model.state_dict().keys():
if key in state_dict.keys():
if state_dict[key].shape!=model.state_dict()[key].shape:
del state_dict[key]
msg = model.load_state_dict(state_dict,strict=False)
print('load checkpoint from %s'%url_or_filename)
return model,msg

View File

@ -0,0 +1,76 @@
from models.med import BertConfig, BertModel
from transformers import BertTokenizer
import torch
from torch import nn
import torch.nn.functional as F
from models.blip import create_vit, init_tokenizer, load_checkpoint
class BLIP_ITM(nn.Module):
def __init__(self,
med_config = 'configs/med_config.json',
image_size = 384,
vit = 'base',
vit_grad_ckpt = False,
vit_ckpt_layer = 0,
embed_dim = 256,
):
"""
Args:
med_config (str): path for the mixture of encoder-decoder model's configuration file
image_size (int): input image size
vit (str): model size of vision transformer
"""
super().__init__()
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
self.tokenizer = init_tokenizer()
med_config = BertConfig.from_json_file(med_config)
med_config.encoder_width = vision_width
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
text_width = self.text_encoder.config.hidden_size
self.vision_proj = nn.Linear(vision_width, embed_dim)
self.text_proj = nn.Linear(text_width, embed_dim)
self.itm_head = nn.Linear(text_width, 2)
def forward(self, image, caption, match_head='itm'):
image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
return_tensors="pt").to(image.device)
if match_head=='itm':
output = self.text_encoder(text.input_ids,
attention_mask = text.attention_mask,
encoder_hidden_states = image_embeds,
encoder_attention_mask = image_atts,
return_dict = True,
)
itm_output = self.itm_head(output.last_hidden_state[:,0,:])
return itm_output
elif match_head=='itc':
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
return_dict = True, mode = 'text')
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
sim = image_feat @ text_feat.t()
return sim
def blip_itm(pretrained='',**kwargs):
model = BLIP_ITM(**kwargs)
if pretrained:
model,msg = load_checkpoint(model,pretrained)
assert(len(msg.missing_keys)==0)
return model

Some files were not shown because too many files have changed in this diff Show More