mirror of https://github.com/vladmandic/automatic
commit 0d7807acd3: Merge branch 'dev' into temp

@@ -43,7 +43,6 @@ cache
 !package.json

 # all dynamic stuff
 /repositories/**/*
 /extensions/**/*
 /outputs/**/*
 /embeddings/**/*

@@ -60,6 +59,5 @@ cache
 /localizations

 # unexcluded so folders get created
 !/repositories/.placeholder
 !/models/VAE-approx
 !/models/VAE-approx/model.pt

@@ -32,3 +32,7 @@
 	path = extensions-builtin/sd-extension-chainner
 	url = https://github.com/vladmandic/sd-extension-chainner
 	ignore = dirty
+[submodule "modules/k-diffusion"]
+	path = modules/k-diffusion
+	url = https://github.com/crowsonkb/k-diffusion
+	ignore = dirty

@@ -151,6 +151,7 @@ disable=bad-inline-option,
    missing-function-docstring,
    missing-module-docstring,
    no-else-return,
    not-callable,
    pointless-string-statement,
    raw-checker-failed,
    simplifiable-if-expression,

CHANGELOG.md (82 changes)

@@ -1,67 +1,79 @@
 # Change Log for SD.Next

-## Update for 2023-10-25
+## Update for 2023-10-30

+*Note*: Pending release of `diffusers==0.22.0`

-Mostly service release with support for several new models and additional optimizations...
+Another pretty big release, this time with focus on new models, new backends and optimizations, and tons of fixes

 Also, [Wiki](https://github.com/vladmandic/automatic/wiki) has been updated with new content, so check it out!
+Some highlights: [OpenVINO](https://github.com/vladmandic/automatic/wiki/OpenVINO), [IntelArc](https://github.com/vladmandic/automatic/wiki/Intel-ARC), [DirectML](https://github.com/vladmandic/automatic/wiki/DirectML), [ONNX/Olive](https://github.com/vladmandic/automatic/wiki/ONNX-Runtime)

 - **Diffusers**
-  - new model type: [SegMind SSD-1B](https://huggingface.co/segmind/SSD-1B)
-    its a distilled model, this time 50% smaller and faster version of SD-XL!
+  - new model type: [Segmind SSD-1B](https://huggingface.co/segmind/SSD-1B)
+    it's a *distilled* model, this time a 50% smaller and faster version of SD-XL!
+    (and quality does not suffer, it's just more optimized)
+    a test run with batch-size 4 over 1k images used less than 6.5GB of VRAM
     download using the built-in **Huggingface** downloader: `segmind/SSD-1B`
   - new model type: [LCM: Latent Consistency Models](https://github.com/openai/consistency_models)
     near-instant generation in as little as 3 steps!
     combined with OpenVINO, generation on CPU takes less than 10 seconds: <https://www.youtube.com/watch?v=b90ESUTLsRo>
     download using the built-in **Huggingface** downloader: `SimianLuo/LCM_Dreamshaper_v7`
   - support for **Custom pipelines**, thanks @disty0
     download using the built-in **Huggingface** downloader
     think of them as plugins for diffusers, not unlike original extensions that modify the behavior of the `ldm` backend
-    list of community pipelines: <https://github.com/huggingface/diffusers/tree/main/examples/community>
-    and make sure to check our reference one: `Disty0/zero123plus-pipeline`
-    which generates 4 output images with different camera positions: front, side, top, back!
+    list of community pipelines: <https://github.com/huggingface/diffusers/blob/main/examples/community/README.md>
+  - new custom pipeline: `Disty0/zero123plus-pipeline`
+    generates 4 output images with different camera positions: front, side, top, back!
+    for more details, see <https://github.com/vladmandic/automatic/discussions/2421>
+  - new backend: **ONNX/Olive** (experimental)
+    for details, see the Wiki
   - extended support for [Free-U](https://github.com/ChenyangSi/FreeU)
     improves generation quality at no cost (other than finding the params that work for you)
 - **General**
   - add **Lora OFT** support, thanks @antis0007 and @ai-casanova
   - **Upscalers**
-    - **compile compile** option, thanks @disty0
+    - **compile** option, thanks @disty0
     - **chaiNNer**: add high-quality models from [Helaman](https://openmodeldb.info/users/helaman)
-  - redesigned **progress bar** with full details on current operation
+  - redesigned **Progress bar** with full details on the current operation
   - **Extra networks**: sort by name, size, date, etc.
   - new option: *settings -> images -> keep incomplete*
     can be used to skip vae decode on aborted/skipped/interrupted image generations
   - new option: *settings -> system paths -> models*
     can be used to set a custom base path for *all* models (previously only available as a cli option)
   - remove external clone of items in `/repositories`
   - switch core font in the default theme to **noto-sans**
     previously the default font was simply *system-ui*, which led to too much variation between browsers and platforms
 - **Fixes**
   - fix **freeu** for backend original and add it to the xyz grid
   - fix loading diffuser models in huggingface format from non-standard locations
   - fix default styles looking in the wrong location
   - fix missing upscaler folder on initial startup
   - fix handling of relative paths for models
   - fix simple live preview device mismatch
   - fix batch img2img
   - fix diffusers samplers: dpm++ 2m, dpm++ 1s, deis
   - fix new style filename template
   - fix image name template using model name
   - fix image name sequence
   - fix model path using relative path
   - fix `torch-rocm` and `tensorflow-rocm` version detection, thanks @xangelix
   - fix **chainner** upscalers color clipping
+  - fix base+refiner workflow in diffusers mode: number of steps, diffuser pipe mode
   - fix prompt encoder with refiner in diffusers mode
   - fix prompts-from-file saving incorrect metadata
   - fix before-hires step
   - fix diffusers switch from invalid model
   - **directml** and **ipex** updates
   - force second requirements check on startup
-  - remove lyco, multiple_tqdm
+  - remove **lyco**, multiple_tqdm
+  - enhance extension compatibility for extensions directly importing codeformers
+  - enhance extension compatibility for extensions directly accessing processing params
-  - css fixes
+  - **css** fixes
   - clearly mark external themes in the ui
   - update `openvino`, thanks @disty0
   - update `typing-extensions`

 ## Update for 2023-10-17

README.md (22 changes)

@@ -56,21 +56,23 @@ Additional models will be added as they become available and there is public int
 - [Stable Diffusion](https://github.com/Stability-AI/stablediffusion/) 1.x and 2.x *(all variants)*
 - [Stable Diffusion XL](https://github.com/Stability-AI/generative-models)
 - [Kandinsky](https://github.com/ai-forever/Kandinsky-2) 2.1 and 2.2
-- [DeepFloyd IF](https://github.com/deep-floyd/IF)
-- [UniDiffusion](https://github.com/thu-ml/unidiffuser)
 - [SD-Distilled](https://huggingface.co/blog/sd_distillation) *(all variants)*
 - [Wuerstchen](https://huggingface.co/blog/wuertschen)
+- [Segmind SSD-1B](https://huggingface.co/segmind/SSD-1B)
+- [UniDiffusion](https://github.com/thu-ml/unidiffuser)
+- [DeepFloyd IF](https://github.com/deep-floyd/IF)

 ## Platform support

 - *nVidia* GPUs using **CUDA** libraries on both *Windows and Linux*
-- *AMD* GPUs using **ROCm** libraries on *Linux*.
+- *AMD* GPUs using **ROCm** libraries on *Linux*
   Support will be extended to *Windows* once AMD releases ROCm for Windows
 - *Intel Arc* GPUs using **OneAPI** with *IPEX XPU* libraries on both *Windows and Linux*
-- Any GPU compatible with *DirectX* on *Windows* using **DirectML** libraries.
+- Any GPU compatible with *DirectX* on *Windows* using **DirectML** libraries
   This includes support for AMD GPUs that are not supported by native ROCm libraries
 - Any GPU or device compatible with **OpenVINO** libraries on both *Windows and Linux*
 - *Apple M1/M2* on *OSX* using built-in support in Torch with **MPS** optimizations
+- *ONNX/Olive* (experimental)

 ## Install & Run

@@ -0,0 +1,80 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
  params:
    embedding_dropout: 0.25
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 96
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn-adm
    scale_factor: 0.18215
    monitor: val/loss_simple_ema
    use_ema: False

    embedder_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder

    noise_aug_config:
      target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
      params:
        timestep_dim: 1024
        noise_schedule_config:
          timesteps: 1000
          beta_schedule: squaredcos_cap_v2

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: "sequential"
        adm_in_channels: 2048
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

@@ -0,0 +1,83 @@
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
  params:
    embedding_dropout: 0.25
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 96
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn-adm
    scale_factor: 0.18215
    monitor: val/loss_simple_ema
    use_ema: False

    embedder_config:
      target: ldm.modules.encoders.modules.ClipImageEmbedder
      params:
        model: "ViT-L/14"

    noise_aug_config:
      target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
      params:
        clip_stats_path: "checkpoints/karlo_models/ViT-L-14_stats.th"
        timestep_dim: 768
        noise_schedule_config:
          timesteps: 1000
          beta_schedule: squaredcos_cap_v2

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        num_classes: "sequential"
        adm_in_channels: 1536
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

@@ -0,0 +1,74 @@
model:
  base_learning_rate: 5.0e-07
  target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: hybrid
    scale_factor: 0.18215
    monitor: val/loss_simple_ema
    finetune_keys: null
    use_ema: False

    depth_stage_config:
      target: ldm.modules.midas.api.MiDaSInference
      params:
        model_type: "dpt_hybrid"

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 5
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
            - 1
            - 2
            - 4
            - 4
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

@@ -112,11 +112,12 @@ class KeyConvert:
         self.converter = self.diffusers
         self.is_sdxl = True if shared.sd_model_type == "sdxl" else False
         self.UNET_CONVERSION_MAP = make_unet_conversion_map() if self.is_sdxl else None
-        self.LORA_PREFIX_UNET = "lora_unet"
-        self.LORA_PREFIX_TEXT_ENCODER = "lora_te"
+        self.LORA_PREFIX_UNET = "lora_unet_"
+        self.LORA_PREFIX_TEXT_ENCODER = "lora_te_"
+        self.OFT_PREFIX_UNET = "oft_unet_"
         # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
-        self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
-        self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
+        self.LORA_PREFIX_TEXT_ENCODER1 = "lora_te1_"
+        self.LORA_PREFIX_TEXT_ENCODER2 = "lora_te2_"

     def original(self, key):
         key = convert_diffusers_name_to_compvis(key, self.is_sd2)

@@ -142,13 +143,12 @@ class KeyConvert:
         if self.is_sdxl:
             map_keys = list(self.UNET_CONVERSION_MAP.keys())  # prefixes of U-Net modules
             map_keys.sort()
-            search_key = key.replace(self.LORA_PREFIX_UNET + "_", "").replace(self.LORA_PREFIX_TEXT_ENCODER1 + "_", "").replace(self.LORA_PREFIX_TEXT_ENCODER2 + "_", "")
+            search_key = key.replace(self.LORA_PREFIX_UNET, "").replace(self.OFT_PREFIX_UNET, "").replace(self.LORA_PREFIX_TEXT_ENCODER1, "").replace(self.LORA_PREFIX_TEXT_ENCODER2, "")

             position = bisect.bisect_right(map_keys, search_key)
             map_key = map_keys[position - 1]
             if search_key.startswith(map_key):
-                key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key])  # pylint: disable=unsubscriptable-object
+                key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key]).replace("oft", "lora")  # pylint: disable=unsubscriptable-object
             sd_module = shared.sd_model.network_layer_mapping.get(key, None)
             return key, sd_module

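The SDXL branch above finds which U-Net conversion prefix applies to a key with a single `bisect_right` on the sorted prefix list plus a `startswith` check, instead of scanning every prefix. A minimal standalone sketch of that lookup pattern (the key names here are illustrative, not taken from the repo's conversion map):

```python
import bisect

def prefix_match(sorted_keys, query):
    """Find a candidate that is a prefix of `query` using one bisect lookup.

    Mirrors the diff's approach: only the closest lexicographic predecessor
    is checked, which is sufficient for prefix lists without near-collisions.
    """
    pos = bisect.bisect_right(sorted_keys, query)
    if pos == 0:
        return None  # query sorts before every candidate
    candidate = sorted_keys[pos - 1]
    return candidate if query.startswith(candidate) else None

keys = sorted(["down_blocks_0", "down_blocks_1", "mid_block", "up_blocks_0"])
print(prefix_match(keys, "mid_block_attentions_0_proj"))  # -> mid_block
```

This keeps the per-key cost at O(log n) over the sorted map, which matters because the conversion runs once per tensor in the network file.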
@@ -0,0 +1,49 @@
import torch
import diffusers.models.lora as diffusers_lora
import network
from modules import devices


class ModuleTypeOFT(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        """
        weights.w.items()

        alpha : tensor(0.0010, dtype=torch.bfloat16)
        oft_blocks : tensor([[[ 0.0000e+00,  1.4400e-04,  1.7319e-03, ..., -8.8882e-04,
                                5.7373e-03, -4.4250e-03],
                              [-1.4400e-04,  0.0000e+00,  8.6594e-04, ...,  1.5945e-03,
                               -8.5449e-04,  1.9684e-03], ...etc...
                            , dtype=torch.bfloat16)
        """
        if "oft_blocks" in weights.w.keys():
            module = NetworkModuleOFT(net, weights)
            return module
        else:
            return None


class NetworkModuleOFT(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.weights = weights.w.get("oft_blocks").to(device=devices.device)
        self.dim = self.weights.shape[0]  # num blocks
        self.alpha = self.multiplier()
        self.block_size = self.weights.shape[-1]

    def get_weight(self):
        block_Q = self.weights - self.weights.transpose(1, 2)
        I = torch.eye(self.block_size, device=devices.device).unsqueeze(0).repeat(self.dim, 1, 1)
        block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
        block_R_weighted = self.alpha * block_R + (1 - self.alpha) * I
        R = torch.block_diag(*block_R_weighted)
        return R

    def calc_updown(self, orig_weight):
        R = self.get_weight().to(device=devices.device, dtype=orig_weight.dtype)
        if orig_weight.dim() == 4:
            updown = torch.einsum("oihw, op -> pihw", orig_weight, R) * self.calc_scale()
        else:
            updown = torch.einsum("oi, op -> pi", orig_weight, R) * self.calc_scale()

        return self.finalize_updown(updown, orig_weight, orig_weight.shape)

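The `get_weight` method above builds one orthogonal rotation per block via a Cayley transform: the raw blocks are made skew-symmetric (`Q = W - W^T`), then `R = (I + Q)(I - Q)^-1` is orthogonal by construction, and `alpha` blends `R` toward the identity before the blocks are assembled block-diagonally. A small NumPy sketch of the same construction (shapes and the blending follow the torch code above; this is an illustration, not a drop-in replacement):

```python
import numpy as np

def oft_rotation(blocks: np.ndarray, alpha: float) -> np.ndarray:
    """Build a block-diagonal rotation from raw OFT blocks of shape (n, b, b)."""
    Q = blocks - blocks.transpose(0, 2, 1)        # skew-symmetric: Q^T = -Q
    I = np.eye(blocks.shape[-1])
    R = np.matmul(I + Q, np.linalg.inv(I - Q))    # Cayley transform -> orthogonal
    R = alpha * R + (1.0 - alpha) * I             # blend toward identity
    n, b = blocks.shape[0], blocks.shape[-1]
    out = np.zeros((n * b, n * b))                # equivalent of torch.block_diag
    for i in range(n):
        out[i * b:(i + 1) * b, i * b:(i + 1) * b] = R[i]
    return out

# with alpha=1 the result is exactly orthogonal: R^T R = I
blocks = np.random.default_rng(0).normal(size=(2, 4, 4)) * 0.01
R = oft_rotation(blocks, 1.0)
```

Because the update is a rotation of the original weight rather than an additive low-rank delta, OFT preserves the norm structure of the pretrained layer, which is the design motivation behind the technique.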
@@ -7,6 +7,7 @@ import network
 import network_lora
 import network_hada
 import network_ia3
+import network_oft
 import network_lokr
 import network_full
 import network_norm

@@ -32,6 +33,7 @@ module_types = [
     network_lora.ModuleTypeLora(),
     network_hada.ModuleTypeHada(),
     network_ia3.ModuleTypeIa3(),
+    network_oft.ModuleTypeOFT(),
     network_lokr.ModuleTypeLokr(),
     network_full.ModuleTypeFull(),
     network_norm.ModuleTypeNorm(),

@@ -53,6 +53,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
                 "local_preview": f"{path}.{shared.opts.samples_format}",
                 "metadata": json.dumps(l.metadata, indent=4) if l.metadata else None,
                 "tags": tags,
+                "mtime": os.path.getmtime(l.filename),
+                "size": os.path.getsize(l.filename),
             }
             return item
         except Exception as e:

@@ -1 +1 @@
-Subproject commit e382d1618593ae05a7115006de8680d3ddbd9777
+Subproject commit f2aafcf2beb99a03cbdf7db73852228ccd6bd1d6

@@ -1 +1 @@
-Subproject commit 7fb6a263d66c4167f386bf6f66b12f187d44505e
+Subproject commit 7f57729626503837a70ad9eed92313bc36db7bf3

@@ -16,6 +16,7 @@
 {"id":"","label":"⇰","localized":"","hint":"Apply selected styles to current prompt"},
-{"id":"","label":"⇩","localized":"","hint":"Save parameters from last generated image as style template"},
+{"id":"","label":"🕮","localized":"","hint":"Save parameters from last generated image as style template"},
+{"id":"","label":"⇕","localized":"","hint":"Sort by: Name asc/desc, Size largest/smallest, Time newest/oldest"},
 {"id":"","label":"⟲","localized":"","hint":"Refresh"},
 {"id":"","label":"✕","localized":"","hint":"Close"},
 {"id":"","label":"⊜","localized":"","hint":"Fill"},

@@ -591,6 +591,7 @@ def install_packages():

 # clone required repositories
 def install_repositories():
+    """
     if args.profile:
         pr = cProfile.Profile()
         pr.enable()

@@ -615,6 +616,7 @@ def install_repositories():
     clone(blip_repo, d('BLIP'), blip_commit)
     if args.profile:
         print_profile(pr, 'Repositories')
+    """


 # run extension installer

@@ -659,7 +661,7 @@ def install_extensions():
     pkg_resources._initialize_master_working_set()  # pylint: disable=protected-access
     pkgs = [f'{p.project_name}=={p._version}' for p in pkg_resources.working_set]  # pylint: disable=protected-access,not-an-iterable
     log.debug(f'Installed packages: {len(pkgs)}')
-    from modules.paths_internal import extensions_builtin_dir, extensions_dir
+    from modules.paths import extensions_builtin_dir, extensions_dir
     extensions_duplicates = []
     extensions_enabled = []
     extension_folders = [extensions_builtin_dir] if args.safe else [extensions_builtin_dir, extensions_dir]

@@ -793,7 +795,7 @@ def set_environment():

 def check_extensions():
     newest_all = os.path.getmtime('requirements.txt')
-    from modules.paths_internal import extensions_builtin_dir, extensions_dir
+    from modules.paths import extensions_builtin_dir, extensions_dir
     extension_folders = [extensions_builtin_dir] if args.safe else [extensions_builtin_dir, extensions_dir]
     disabled_extensions_all = opts.get('disable_all_extensions', 'none')
     if disabled_extensions_all != 'none':

@@ -981,7 +983,7 @@ def extensions_preload(parser):
     log.info('Running in safe mode without user extensions')
     try:
         from modules.script_loading import preload_extensions
-        from modules.paths_internal import extensions_builtin_dir, extensions_dir
+        from modules.paths import extensions_builtin_dir, extensions_dir
         extension_folders = [extensions_builtin_dir] if args.safe else [extensions_builtin_dir, extensions_dir]
         preload_time = {}
         for ext_dir in extension_folders:

@@ -119,9 +119,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
 .gradio-button.tool { filter: hue-rotate(180deg) saturate(0.5); }
-#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }

-/* custom elements overrides */
-#steps-animation, #controlnet { border-width: 0; }

 /* based on gradio built-in dark theme */
 :root, .light, .dark {
   --body-background-fill: #000000; /* Black */

@@ -14,7 +14,6 @@
   width: 22em; min-height: 1.3em; font-size: 0.8em; transition: opacity 0.2s ease-in; pointer-events: none; opacity: 0; z-index: 999; }
 .tooltip-show { opacity: 0.9; }
 .toolbutton-selected { background: var(--background-fill-primary) !important; }
 .jobStatus { position: fixed; bottom: 1em; right: 1em; background: var(--input-background-fill); padding: 0.4em; font-size: 0.8em; color: var(--body-text-color-subdued); }

 /* live preview */
 .progressDiv{ position: relative; height: 20px; background: #b4c0cc; margin-bottom: -3px; }

@@ -94,7 +93,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
 .extra-network-cards { display: flex; flex-wrap: wrap; overflow-y: auto; overflow-x: hidden; align-content: flex-start; width: -moz-available; width: -webkit-fill-available; }
 .extra-network-cards .card { height: fit-content; margin: 0 0 0.5em 0.5em; position: relative; scroll-snap-align: start; scroll-margin-top: 0; }
 .extra-network-cards .card .overlay { position: absolute; bottom: 0; padding: 0.2em; z-index: 10; width: 100%; background: none; }
-.extra-network-cards .card .overlay .name { font-size: 1.1em; font-weight: bold; text-shadow: 1px 1px black; color: white; overflow-wrap: break-word; }
+.extra-network-cards .card .overlay .name { text-shadow: 1px 1px black; color: white; overflow-wrap: break-word; }
 .extra-network-cards .card .preview { box-shadow: var(--button-shadow); min-height: 30px; }
 .extra-network-cards .card:hover .overlay { background: rgba(0, 0, 0, 0.40); }
 .extra-network-cards .card:hover .preview { box-shadow: none; filter: grayscale(100%); }

@@ -1,7 +1,8 @@
 /* generic html tags */
+@font-face { font-family: 'NotoSans'; font-display: swap; font-style: normal; font-weight: 100; src: local('NotoSans'), url('notosans-nerdfont-regular.ttf') }
 :root, .light, .dark {
-  --font: "Source Sans Pro", 'ui-sans-serif', 'system-ui', "Roboto", sans-serif;
-  --font-mono: 'IBM Plex Mono', 'ui-monospace', 'Consolas', monospace;
+  --font: 'NotoSans';
+  --font-mono: 'ui-monospace', 'Consolas', monospace;
   --font-size: 16px;
   --left-column: 490px;
   --highlight-color: #ce6400;

@@ -18,15 +19,28 @@
   --primary-800: #9a3412;
   --primary-900: #7c2d12;
   --primary-950: #6c2e12;
 }
 .light, .dark {
   --highlight-color: var(--primary-200);
   --inactive-color: var(--primary--800);
   --body-text-color: var(--neutral-100);
   --body-text-color-subdued: var(--neutral-300);
   --background-color: #000000;
   --background-fill-primary: var(--neutral-700);
   --input-padding: 4px;
-  --radius-lg: 2px;
-  --radius-sm: 1px;
   --input-background-fill: var(--neutral-800);
   --input-shadow: 2px 2px 2px 2px var(--background-color);
   --button-secondary-text-color: white;
   --button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-400), var(--neutral-700));
   --button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-700), var(--neutral-400));
   --block-title-text-color: var(--neutral-300);
+  --radius-sm: 2px;
+  --radius-lg: 4px;
   --spacing-md: 4px;
-  --spacing-xxl: 12px;
-  --line-sm: 1.3em;
-  --line-md: 1.3em;
+  --spacing-xxl: 6px;
+  --line-sm: 1.2em;
+  --line-md: 1.4em;
+  --text-sm: 12px;
+  --text-md: 13px;
+  --text-lg: 15px;
 }

 html { font-size: var(--font-size); }

@@ -119,9 +133,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
 .gradio-button.tool { filter: hue-rotate(180deg) saturate(0.5); }
-#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }

-/* custom elements overrides */
-#steps-animation, #controlnet { border-width: 0; }

 /* based on gradio built-in dark theme */
 :root, .light, .dark {
   --body-background-fill: var(--background-color);

@@ -244,9 +255,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
   --radius-xxl: 0;
   --text-xxs: 9px;
   --text-xs: 10px;
-  --text-sm: 12px;
-  --text-md: 14px;
-  --text-lg: 16px;
   --text-xl: 22px;
   --text-xxl: 26px;
   --body-text-size: var(--text-md);

@@ -1,6 +1,7 @@
 /* generic html tags */
+@font-face { font-family: 'NotoSans'; font-display: swap; font-style: normal; font-weight: 100; src: local('NotoSans'), url('notosans-nerdfont-regular.ttf') }
 :root, .light, .dark {
-  --font: 'system-ui', 'ui-sans-serif', 'system-ui', "Roboto", sans-serif;
+  --font: 'NotoSans';
   --font-mono: 'ui-monospace', 'Consolas', monospace;
   --font-size: 16px;
   --left-column: 490px;

@@ -34,10 +35,13 @@
   --spacing-xxl: 6px;
   --line-sm: 1.2em;
   --line-md: 1.4em;
+  --text-sm: 12px;
+  --text-md: 13px;
+  --text-lg: 15px;
 }

-html { font-size: var(--font-size); }
-body, button, input, select, textarea { font-family: var(--font);}
+html { font-size: var(--font-size); font-family: var(--font); }
+body, button, input, select, textarea { font-family: var(--font); }
 button { font-size: 1.2rem; max-width: 400px; }
 img { background-color: var(--background-color); }
 input[type=range] { height: var(--line-sm) !important; appearance: none !important; margin-top: 0 !important; min-width: 160px !important;

@@ -84,7 +88,8 @@ svg.feather.feather-image, .feather .feather-image { display: none }
 .tab-nav { zoom: 120%; margin-top: 10px; margin-bottom: 10px; border-bottom: 2px solid var(--highlight-color) !important; padding-bottom: 2px; }
 .label-wrap { margin: 8px 0px 4px 0px; }
 .gradio-button.tool { border: none; background: none; box-shadow: none; filter: hue-rotate(340deg) saturate(0.5); }
-#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: none; padding: 0.5em; }
+#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: none; }
 #tab_extensions table tr:hover, #tab_config table tr:hover { background-color: var(--neutral-500) !important; }
+#tab_extensions table, #tab_config table { width: 96vw }
 #tab_extensions table thead, #tab_config table thead { background-color: var(--neutral-700); }
 #tab_extensions table, #tab_config table { background-color: #222222; }

@@ -130,9 +135,6 @@ textarea[rows="1"] { height: 33px !important; width: 99% !important; padding: 8p
 #pnginfo_html2_info { margin-top: -18px; background-color: var(--input-background-fill); padding: var(--input-padding) }
-#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; }

-/* custom elements overrides */
-#steps-animation, #controlnet { border-width: 0; }

 /* based on gradio built-in dark theme */
 :root, .light, .dark {
   --body-background-fill: var(--background-color);

@@ -246,9 +248,6 @@ textarea[rows="1"] { height: 33px !important; width: 99% !important; padding: 8p
   --radius-xxl: 0;
   --text-xxs: 9px;
   --text-xs: 10px;
-  --text-sm: 12px;
-  --text-md: 14px;
-  --text-lg: 16px;
   --text-xl: 22px;
   --text-xxl: 26px;
   --body-text-size: var(--text-md);

|||
|
|
@@ -58,7 +58,8 @@ function readCardTags(el, tags) {
    e.preventDefault();
    e.stopPropagation();
    const textarea = activePromptTextarea[getENActiveTab()];
    if (textarea.value.indexOf(tag) !== -1) textarea.value = textarea.value.replace(tag, '');
    if (textarea.value.indexOf(` ${tag}`) !== -1) textarea.value = textarea.value.replace(` ${tag}`, '');
    else if (textarea.value.indexOf(`${tag} `) !== -1) textarea.value = textarea.value.replace(` ${tag} `, '');
    else textarea.value += ` ${tag}`;
    updateInput(textarea);
  };

@@ -144,6 +145,38 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) {
  return false;
}

let sortVal = 0;

function sortExtraNetworks() {
  const sortDesc = ['Name [A-Z]', 'Name [Z-A]', 'Date [Newest]', 'Date [Oldest]', 'Size [Largest]', 'Size [Smallest]'];
  const pagename = getENActivePage();
  if (!pagename) return 'sort error: unknown page';
  const allPages = Array.from(gradioApp().querySelectorAll('.extra-network-cards'));
  const pages = allPages.filter((el) => el.id.includes(pagename.toLowerCase()));
  let num = 0;
  for (const pg of pages) {
    const cards = Array.from(pg.querySelectorAll('.card') || []);
    num = cards.length;
    if (num === 0) return 'sort: no cards';
    cards.sort((a, b) => { // eslint-disable-line no-loop-func
      switch (sortVal) {
        case 0: return a.dataset.name ? a.dataset.name.localeCompare(b.dataset.name) : 0;
        case 1: return b.dataset.name ? b.dataset.name.localeCompare(a.dataset.name) : 0;
        case 2: return a.dataset.mtime && !isNaN(a.dataset.mtime) ? parseFloat(b.dataset.mtime) - parseFloat(a.dataset.mtime) : 0;
        case 3: return b.dataset.mtime && !isNaN(b.dataset.mtime) ? parseFloat(a.dataset.mtime) - parseFloat(b.dataset.mtime) : 0;
        case 4: return a.dataset.size && !isNaN(a.dataset.size) ? parseFloat(b.dataset.size) - parseFloat(a.dataset.size) : 0;
        case 5: return b.dataset.size && !isNaN(b.dataset.size) ? parseFloat(a.dataset.size) - parseFloat(b.dataset.size) : 0;
      }
      return 0;
    });
    for (const card of cards) pg.appendChild(card);
  }
  const desc = sortDesc[sortVal];
  sortVal = (sortVal + 1) % sortDesc.length;
  log('sortExtraNetworks', pagename, num, desc);
  return `sort page ${pagename} cards ${num} by ${desc}`;
}

function refreshExtraNetworks(tabname) {
  log('refreshExtraNetworks', tabname, gradioApp().querySelector(`#${tabname}_extra_networks textarea`)?.value);
  gradioApp().querySelector(`#${tabname}_extra_networks textarea`)?.dispatchEvent(new Event('input'));

@@ -195,6 +228,7 @@ function setupExtraNetworksForTab(tabname) {
  const btnScan = gradioApp().getElementById(`${tabname}_extra_scan`);
  const btnSave = gradioApp().getElementById(`${tabname}_extra_save`);
  const btnClose = gradioApp().getElementById(`${tabname}_extra_close`);
  const btnSort = gradioApp().getElementById(`${tabname}_extra_sort`);
  const btnModel = gradioApp().getElementById(`${tabname}_extra_model`);
  const btnApply = gradioApp().getElementById(`${tabname}_extra_apply`);
  const buttons = document.createElement('span');

@@ -204,6 +238,7 @@ function setupExtraNetworksForTab(tabname) {
  if (btnApply) buttons.appendChild(btnApply);
  if (btnScan) buttons.appendChild(btnScan);
  if (btnSave) buttons.appendChild(btnSave);
  if (btnSort) buttons.appendChild(btnSort);
  if (btnClose) buttons.appendChild(btnClose);
  btnModel.onclick = () => btnModel.classList.toggle('toolbutton-selected');
  tabs.appendChild(buttons);
@@ -126,9 +126,6 @@ button.selected {background: var(--button-primary-background-fill);}
#txt2img_tools, #img2img_tools { margin-top: -4px; margin-bottom: -4px; }
#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }

/* custom elements overrides */
#steps-animation, #controlnet { border-width: 0; }

/* based on gradio built-in dark theme */
:root, .light, .dark {
  --body-background-fill: var(--background-color);
@@ -1,6 +1,7 @@
/* generic html tags */
@font-face { font-family: 'NotoSans'; font-display: swap; font-style: normal; font-weight: 100; src: local('NotoSans'), url('notosans-nerdfont-regular.ttf') }
:root, .light, .dark {
  --font: 'system-ui', 'ui-sans-serif', 'system-ui', "Roboto", sans-serif;
  --font: 'NotoSans';
  --font-mono: 'ui-monospace', 'Consolas', monospace;
  --font-size: 16px;
  --left-column: 490px;

@@ -34,6 +35,9 @@
  --spacing-xxl: 8px;
  --line-sm: 1.2em;
  --line-md: 1.4em;
  --text-sm: 12px;
  --text-md: 13px;
  --text-lg: 15px;
}

html { font-size: var(--font-size); }

@@ -126,9 +130,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
#pnginfo_html2_info { margin-top: -18px; background-color: var(--input-background-fill); padding: var(--input-padding) }
#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }

/* custom elements overrides */
#steps-animation, #controlnet { border-width: 0; }

/* based on gradio built-in dark theme */
:root, .light, .dark {
  --background-fill-secondary: none;

@@ -308,9 +309,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
  --table-radius: var(--radius-lg);
  --table-row-focus: var(--color-accent-soft);
  --text-lg: 16px;
  --text-md: 14px;
  --text-sm: 12px;
  --text-xl: 22px;
  --text-xs: 10px;
  --text-xxl: 26px;
  --text-xxs: 9px;
@@ -1,6 +1,5 @@
let logMonitorEl = null;
let logMonitorStatus = true;
let jobStatusEl = null;

async function logMonitor() {
  if (logMonitorStatus) setTimeout(logMonitor, opts.logmonitor_refresh_period);

@@ -52,10 +51,6 @@ async function initLogMonitor() {
    </table>
  `;
  el.style.display = 'none';
  jobStatusEl = document.createElement('div');
  jobStatusEl.className = 'jobStatus';
  jobStatusEl.style.display = 'none';
  gradioApp().appendChild(jobStatusEl);
  fetch(`/sdapi/v1/start?agent=${encodeURI(navigator.userAgent)}`);
  logMonitor();
  log('initLogMonitor');
@@ -119,9 +119,6 @@ svg.feather.feather-image, .feather .feather-image { display: none }
.gradio-button.tool { filter: hue-rotate(120deg) saturate(0.5); }
#txt2img_styles_row, #img2img_styles_row { margin-top: -6px; z-index: 200; }

/* custom elements overrides */
#steps-animation, #controlnet { border-width: 0; }

/* based on gradio built-in dark theme */
:root, .light, .dark {
  --body-background-fill: var(--background-color);

Binary file not shown.
@@ -42,24 +42,23 @@ function checkPaused(state) {
function setProgress(res) {
  const elements = ['txt2img_generate', 'img2img_generate', 'extras_generate'];
  const progress = (res?.progress || 0);
  const job = res?.job || '';
  const perc = res && (progress > 0) ? `${Math.round(100.0 * progress)}%` : '';
  let sec = res?.eta || 0;
  let eta = '';
  if (res?.paused) eta = 'Paused';
  else if (res?.completed || (progress > 0.99)) eta = 'Finishing';
  else if (sec === 0) eta = `Init${res?.job?.length > 0 ? `: ${res.job}` : ''}`;
  else if (sec === 0) eta = 'Starting';
  else {
    const min = Math.floor(sec / 60);
    sec %= 60;
    eta = min > 0 ? `ETA: ${Math.round(min)}m ${Math.round(sec)}s` : `ETA: ${Math.round(sec)}s`;
    eta = min > 0 ? `${Math.round(min)}m ${Math.round(sec)}s` : `${Math.round(sec)}s`;
  }
  document.title = `SD.Next ${perc}`;
  for (const elId of elements) {
    const el = document.getElementById(elId);
    el.innerText = res
      ? `${perc} ${eta}`
      : 'Generate';
    el.style.background = res
    el.innerText = (res ? `${job} ${perc} ${eta}` : 'Generate');
    el.style.background = res && (progress > 0)
      ? `linear-gradient(to right, var(--primary-500) 0%, var(--primary-800) ${perc}, var(--neutral-700) ${perc})`
      : 'var(--button-primary-background-fill)';
  }

@@ -106,7 +105,6 @@ function requestProgress(id_task, progressEl, galleryEl, atEnd = null, onProgres
    debug('taskEnd:', id_task);
    localStorage.removeItem('task');
    setProgress();
    if (jobStatusEl) jobStatusEl.style.display = 'none';
    if (parentGallery && livePreview) parentGallery.removeChild(livePreview);
    checkPaused(true);
    if (atEnd) atEnd();

@@ -114,8 +112,6 @@ function requestProgress(id_task, progressEl, galleryEl, atEnd = null, onProgres

  const start = (id_task, id_live_preview) => { // eslint-disable-line no-shadow
    request('./internal/progress', { id_task, id_live_preview }, (res) => {
      if (jobStatusEl) jobStatusEl.innerText = (res?.job || '').trim().toUpperCase();
      if (jobStatusEl) jobStatusEl.style.display = jobStatusEl.innerText.length > 0 ? 'block' : 'none';
      lastState = res;
      const elapsedFromStart = (new Date() - dateStart) / 1000;
      hasStarted |= res.active;
@@ -1,3 +1,4 @@
@font-face { font-family: 'Roboto'; font-display: swap; font-style: normal; font-weight: 100; src: local('Roboto'), url('roboto.ttf') }
:root { --left-column: 490px; }
a { font-weight: bold; cursor: pointer; }
h2 { margin-top: 1em !important; font-size: 1.4em !important; }

@@ -25,7 +26,7 @@ tr { border-bottom: none !important; padding: 0.1em 0.5em !important; }
.gradio-button.secondary-down:hover { background: var(--button-secondary-background-fill-hover); color: var(--button-secondary-text-color-hover); }
.gradio-button.tool { max-width: min-content; min-width: min-content !important; align-self: end; font-size: 1.4em; color: var(--body-text-color) !important; margin-bottom: var(--spacing-md); }
.gradio-checkbox { margin: 0.75em 1.5em 0 0; align-self: center; }
.gradio-column { min-width: unset !important; }
.gradio-column { min-width: unset; }
.gradio-container { max-width: unset !important; padding: var(--block-label-padding) !important; }
.gradio-container .prose a, .gradio-container .prose a:visited{ color: unset; text-decoration: none; }

@@ -45,12 +46,12 @@ tr { border-bottom: none !important; padding: 0.1em 0.5em !important; }

/* custom gradio elements */
.accordion-compact { padding: 8px 0px 4px 0px !important; }
.settings-accordion .gap { padding-right: 1000px; }
.settings-accordion >div { flex-flow: wrap; }
.small-accordion .form { min-width: var(--left-column) !important; }
.settings-accordion > div { flex-flow: wrap; }
.small-accordion .form { min-width: var(--left-column) !important; width: max-content; }
.small-accordion .label-wrap .icon { margin-right: 1.6em; margin-left: 0.6em; color: var(--button-primary-border-color); }
.small-accordion .label-wrap { padding: 16px 0px 8px 0px; margin: 0; border-top: 2px solid var(--button-secondary-border-color); }
.small-accordion { width: fit-content !important; padding-left: 0 !important; }
.extension-script { max-width: 50%; }
button.custom-button{ border-radius: var(--button-large-radius); padding: var(--button-large-padding); font-weight: var(--button-large-text-weight); border: var(--button-border-width) solid var(--button-secondary-border-color);
  background: var(--button-secondary-background-fill); color: var(--button-secondary-text-color); font-size: var(--button-large-text-size);
  display: inline-flex; justify-content: center; align-items: center; transition: var(--button-transition); box-shadow: var(--button-shadow); text-align: center; }

@@ -72,7 +73,7 @@ button.custom-button{ border-radius: var(--button-large-radius); padding: var(--
#txt2img_footer, #img2img_footer { height: fit-content; display: none; }
#txt2img_generate_box, #img2img_generate_box { gap: 0.5em; flex-wrap: wrap-reverse; height: fit-content; }
#txt2img_actions_column, #img2img_actions_column { gap: 0.5em; height: fit-content; }
#txt2img_generate_box > button, #img2img_generate_box > button { min-height: 42px; max-height: 42px; }
#txt2img_generate_box > button, #img2img_generate_box > button, #txt2img_enqueue, #img2img_enqueue { min-height: 42px; max-height: 42px; line-height: 1em; }
#txt2img_generate_line2, #img2img_generate_line2, #txt2img_tools, #img2img_tools { display: flex; }
#txt2img_generate_line2 > button, #img2img_generate_line2 > button, #extras_generate_box > button, #txt2img_tools > button, #img2img_tools > button { height: 2em; line-height: 0; font-size: var(--input-text-size);
  min-width: unset; display: block !important; margin-left: 0.4em; margin-right: 0.4em; }

@@ -96,7 +97,6 @@ div#extras_scale_to_tab div.form{ flex-direction: row; }
  width: 22em; min-height: 1.3em; font-size: 0.8em; transition: opacity 0.2s ease-in; pointer-events: none; opacity: 0; z-index: 999; }
.tooltip-show { opacity: 0.9; }
.toolbutton-selected { background: var(--background-fill-primary) !important; }
.jobStatus { position: fixed; bottom: 1em; right: 1em; background: var(--input-background-fill); padding: 0.4em; font-size: 0.8em; color: var(--body-text-color-subdued); }

/* settings */
#si-sparkline-memo, #si-sparkline-load { background-color: #111; }

@@ -163,6 +163,7 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
/* extensions */
#tab_extensions table, #tab_config table{ border-collapse: collapse; }
#tab_extensions table td, #tab_extensions table th, #tab_config table td, #tab_config table th { border: 1px solid #ccc; padding: 0.25em 0.5em; }
#tab_extensions table tr:hover, #tab_config table tr:hover { background-color: var(--neutral-500) !important; }
#tab_extensions table input[type="checkbox"] { margin-right: 0.5em; appearance: checkbox; }
#tab_extensions button{ max-width: 16em; }
#tab_extensions input[disabled="disabled"]{ opacity: 0.5; }

@@ -195,9 +196,10 @@ table.settings-value-table td { padding: 0.4em; border: 1px solid #ccc; max-widt
.extra-network-cards .card:hover .overlay { background: rgba(0, 0, 0, 0.40); }
.extra-network-cards .card:hover .preview { box-shadow: none; filter: grayscale(100%); }
.extra-network-cards .card:hover .overlay { background: rgba(0, 0, 0, 0.40); }
.extra-network-cards .card .overlay .tags { margin: 4px; display: none; overflow-wrap: break-word; }
.extra-network-cards .card .overlay .tag { padding: 2px; margin: 2px; background: var(--neutral-700); cursor: pointer; display: inline-block; }
.extra-network-cards .card .overlay .tags { display: none; overflow-wrap: break-word; }
.extra-network-cards .card .overlay .tag { padding: 3px; background: rgba(70, 70, 70, 0.60); font-size: var(--text-lg); cursor: pointer; display: inline-block; margin-bottom: 4px; }
.extra-network-cards .card .actions > span { padding: 4px; }
.extra-network-cards .card .actions > span:hover { color: var(--highlight-color); }
.extra-network-cards .card:hover .actions { display: block; }
.extra-network-cards .card:hover .overlay .tags { display: block; }
.extra-network-cards .card .actions { font-size: 3em; display: none; text-align-last: right; cursor: pointer; font-variant: unicase; position: absolute; z-index: 100; right: 0; height: 0.7em; width: 100%; background: rgba(0, 0, 0, 0.40); }
@@ -28,9 +28,9 @@ def init_modules():
    parser = modules.cmd_args.parser
    installer.add_args(parser)
    args, _ = parser.parse_known_args()
    import modules.paths_internal
    script_path = modules.paths_internal.script_path
    extensions_dir = modules.paths_internal.extensions_dir
    import modules.paths
    script_path = modules.paths.script_path
    extensions_dir = modules.paths.extensions_dir


def get_custom_args():
@@ -356,68 +356,54 @@ class Api:

    def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
        reqDict = setUpscalers(req)

        image_list = reqDict.pop('imageList', [])
        image_folder = [decode_base64_to_image(x.data) for x in image_list]

        with self.queue_lock:
            result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)

        return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])

    def pnginfoapi(self, req: models.PNGInfoRequest):
        if not req.image.strip():
            return models.PNGInfoResponse(info="")

        image = decode_base64_to_image(req.image.strip())
        if image is None:
            return models.PNGInfoResponse(info="")

        geninfo, items = images.read_info_from_image(image)
        if geninfo is None:
            geninfo = ""

        items = {**{'parameters': geninfo}, **items}

        return models.PNGInfoResponse(info=geninfo, items=items)

    def progressapi(self, req: models.ProgressRequest = Depends()):
        # copy from check_progress_call of ui.py

        if shared.state.job_count == 0:
            return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)

        # avoid dividing zero
        progress = 0.01

        if shared.state.job_count > 0:
            progress += shared.state.job_no / shared.state.job_count
        if shared.state.sampling_steps > 0:
            progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

        time_since_start = time.time() - shared.state.time_start
        eta = time_since_start / progress
        eta_relative = eta-time_since_start

        progress = min(progress, 1)

        shared.state.set_current_image()

        current_image = None
        if shared.state.current_image and not req.skip_current_image:
            current_image = encode_pil_to_base64(shared.state.current_image)

        return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
        batch_x = max(shared.state.job_no, 0)
        batch_y = max(shared.state.job_count, 1)
        step_x = max(shared.state.sampling_step, 0)
        step_y = max(shared.state.sampling_steps, 1)
        current = step_y * batch_x + step_x
        total = step_y * batch_y
        progress = current / total if total > 0 else 0

        time_since_start = time.time() - shared.state.time_start
        eta_relative = (time_since_start / progress) - time_since_start

        res = models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
        return res


    def interrogateapi(self, interrogatereq: models.InterrogateRequest):
        image_b64 = interrogatereq.image
        if image_b64 is None:
            raise HTTPException(status_code=404, detail="Image not found")

        img = decode_base64_to_image(image_b64)
        img = img.convert('RGB')

        # Override object param
        with self.queue_lock:
            if interrogatereq.model == "clip":
                processed = shared.interrogator.interrogate(img)

@@ -425,7 +411,6 @@ class Api:
                processed = deepbooru.model.tag(img)
            else:
                raise HTTPException(status_code=404, detail="Model not found")

        return models.InterrogateResponse(caption=processed)

    def interruptapi(self):

@@ -473,18 +458,8 @@ class Api:
    def get_sd_vaes(self):
        return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()]


    def get_upscalers(self):
        return [
            {
                "name": upscaler.name,
                "model_name": upscaler.scaler.model_name,
                "model_path": upscaler.data_path,
                "model_url": None,
                "scale": upscaler.scale,
            }
            for upscaler in shared.sd_upscalers
        ]
        return [{"name": upscaler.name, "model_name": upscaler.scaler.model_name, "model_path": upscaler.data_path, "model_url": None, "scale": upscaler.scale} for upscaler in shared.sd_upscalers]

    def get_sd_models(self):
        return [{"title": x.title, "name": x.name, "filename": x.filename, "type": x.type, "hash": x.shorthash, "sha256": x.sha256, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]

@@ -500,23 +475,13 @@ class Api:

    def get_embeddings(self):
        db = sd_hijack.model_hijack.embedding_db

        def convert_embedding(embedding):
            return {
                "step": embedding.step,
                "sd_checkpoint": embedding.sd_checkpoint,
                "sd_checkpoint_name": embedding.sd_checkpoint_name,
                "shape": embedding.shape,
                "vectors": embedding.vectors,
            }
            return {"step": embedding.step, "sd_checkpoint": embedding.sd_checkpoint, "sd_checkpoint_name": embedding.sd_checkpoint_name, "shape": embedding.shape, "vectors": embedding.vectors}

        def convert_embeddings(embeddings):
            return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}

        return {
            "loaded": convert_embeddings(db.word_embeddings),
            "skipped": convert_embeddings(db.skipped_embeddings),
        }
        return {"loaded": convert_embeddings(db.word_embeddings), "skipped": convert_embeddings(db.skipped_embeddings)}

    def get_extra_networks(self, page: Optional[str] = None, name: Optional[str] = None, filename: Optional[str] = None, title: Optional[str] = None, fullname: Optional[str] = None, hash: Optional[str] = None): # pylint: disable=redefined-builtin
        res = []

@@ -553,7 +518,7 @@ class Api:

    def create_embedding(self, args: dict):
        try:
            shared.state.begin('api-create-embedding')
            shared.state.begin('api-embedding')
            filename = create_embedding(**args) # create empty embedding
            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
            shared.state.end()

@@ -564,7 +529,7 @@ class Api:

    def create_hypernetwork(self, args: dict):
        try:
            shared.state.begin('api-create-hypernetwork')
            shared.state.begin('api-hypernetwork')
            filename = create_hypernetwork(**args) # create empty embedding # pylint: disable=E1111
            shared.state.end()
            return models.CreateResponse(info = f"create hypernetwork filename: (unknown)")

@@ -590,7 +555,7 @@ class Api:

    def train_embedding(self, args: dict):
        try:
            shared.state.begin('api-train-embedding')
            shared.state.begin('api-embedding')
            apply_optimizations = False
            error = None
            filename = ''

@@ -611,7 +576,7 @@ class Api:

    def train_hypernetwork(self, args: dict):
        try:
            shared.state.begin('api-train-hypernetwork')
            shared.state.begin('api-hypernetwork')
            shared.loaded_hypernetworks = []
            apply_optimizations = False
            error = None
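The `progressapi` hunk above replaces the old additive estimate with a step-counting formula: overall progress is steps completed in finished batches plus steps in the current batch, divided by total steps across all batches. A minimal standalone sketch of that arithmetic (the helper function name is mine; the variable names mirror the diff):

```python
def calc_progress(job_no: int, job_count: int, sampling_step: int, sampling_steps: int) -> float:
    # clamp inputs the same way the diff does, so totals are never zero
    batch_x = max(job_no, 0)        # batches already completed
    batch_y = max(job_count, 1)     # total batches, at least 1
    step_x = max(sampling_step, 0)  # steps done in the current batch
    step_y = max(sampling_steps, 1) # steps per batch, at least 1
    current = step_y * batch_x + step_x
    total = step_y * batch_y
    return current / total if total > 0 else 0

# e.g. second batch of four, halfway through 20 steps
print(calc_progress(job_no=1, job_count=4, sampling_step=10, sampling_steps=20))  # 0.375
```

Note the `eta_relative` line in the diff still divides by `progress`, so a zero progress value would need guarding by the caller.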
@@ -1,6 +1,6 @@
import os
import argparse
from modules.paths_internal import data_path
from modules.paths import data_path

parser = argparse.ArgumentParser(description="SD.Next", conflict_handler='resolve', epilog='For other options see UI Settings page', prog='', add_help=True, formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=55, indent_increment=2, width=200))
parser._optionals = parser.add_argument_group('Other options') # pylint: disable=protected-access
@@ -2,7 +2,7 @@ import os
from datetime import datetime
import git
from modules import shared, errors
from modules.paths_internal import extensions_dir, extensions_builtin_dir
from modules.paths import extensions_dir, extensions_builtin_dir


extensions = []
@@ -54,7 +54,7 @@ def to_half(tensor, enable):


def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata): # pylint: disable=unused-argument
    shared.state.begin('model-merge')
    shared.state.begin('merge')
    save_as_half = save_as_half == 0

    def fail(message):

@@ -319,7 +319,7 @@ def run_modelconvert(model, checkpoint_formats, precision, conv_type, custom_nam
        "vae": vae_conv,
        "other": others_conv
    }
    shared.state.begin('model-convert')
    shared.state.begin('convert')
    model_info = sd_models.checkpoints_list[model]
    shared.state.textinfo = f"Loading {model_info.filename}..."
    shared.log.info(f"Model convert loading: {model_info.filename}")
@@ -69,7 +69,7 @@ def sha256(filename, title, use_addnet_hash=False):
    if not os.path.isfile(filename):
        return None
    orig_state = copy.deepcopy(shared.state)
    shared.state.begin("hashing")
    shared.state.begin("hash")
    if use_addnet_hash:
        if progress_ok:
            try:
@@ -460,7 +460,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
    hypernetwork.load(path)
    shared.loaded_hypernetworks = [hypernetwork]

    shared.state.job = "train-hypernetwork"
    shared.state.job = "train"
    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps
@@ -135,9 +135,9 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):

    def get_font(fontsize):
        try:
            return ImageFont.truetype(shared.opts.font or 'html/roboto.ttf', fontsize)
            return ImageFont.truetype(shared.opts.font or 'javascript/roboto.ttf', fontsize)
        except Exception:
            return ImageFont.truetype('html/roboto.ttf', fontsize)
            return ImageFont.truetype('javascript/roboto.ttf', fontsize)

    def draw_texts(drawing: ImageDraw, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
        for line in lines:

@@ -553,13 +553,11 @@ def save_image(image, path, basename = '', seed=None, prompt=None, extension=sha
        return None, None
    if not check_grid_size([image]):
        return None, None
    if path is None or len(path) == 0:
    if path is None or len(path) == 0: # set default path to avoid errors when functions are triggered manually or via api and param is not set
        path = shared.opts.outdir_save

    # namegen
    namegen = FilenameGenerator(p, seed, prompt, image, grid=grid)
    if shared.opts.save_to_dirs:
        dirname = namegen.apply(shared.opts.directories_filename_pattern or "[date]")
        dirname = namegen.apply(shared.opts.directories_filename_pattern or "[prompt_words]")
        path = os.path.join(path, dirname)
    if forced_filename is None:
        if short_filename or seed is None:

@@ -567,11 +565,10 @@ def save_image(image, path, basename = '', seed=None, prompt=None, extension=sha
        if shared.opts.samples_filename_pattern and len(shared.opts.samples_filename_pattern) > 0:
            file_decoration = shared.opts.samples_filename_pattern
        else:
            file_decoration = "[seq]-[model_name]-[prompt_words]"
            file_decoration = "[seq]-[prompt_words]"
        file_decoration = namegen.apply(file_decoration)
        filename = os.path.join(path, f"{file_decoration}{suffix}.{extension}") if basename is None or basename == '' else os.path.join(path, f"{basename}-{file_decoration}{suffix}.{extension}")
    else:
        filename = os.path.join(path, f"{forced_filename}.{extension}")
        file_decoration += suffix
        filename = os.path.join(path, f"{file_decoration}.{extension}") if basename == '' else os.path.join(path, f"{basename}-{file_decoration}.{extension}")
    pnginfo = existing_info or {}
    if info is not None:
        pnginfo[pnginfo_section_name] = info

@@ -579,7 +576,6 @@ def save_image(image, path, basename = '', seed=None, prompt=None, extension=sha
    params.filename = namegen.sanitize(filename)
    dirname = os.path.dirname(params.filename)
    os.makedirs(dirname, exist_ok=True)

    # sequence
    if shared.opts.save_images_add_number or '[seq]' in params.filename:
        if '[seq]' not in params.filename:

@@ -592,7 +588,6 @@ def save_image(image, path, basename = '', seed=None, prompt=None, extension=sha
            debug(f'Prompt sequence: input="{params.filename}" seq={seq} output="(unknown)"')
            params.filename = filename
            break

    # callbacks
    script_callbacks.before_image_saved_callback(params)
    exifinfo = params.pnginfo.get('UserComment', '')
@@ -40,7 +40,6 @@ def process_batch(p, input_files, input_dir, output_dir, inpaint_mask_dir, args)
    btcrept = p.batch_size
    shared.log.info(f"Process batch: inputs={len(image_files)} outputs={p.n_iter * p.batch_size} per input")
    for i in range(0, len(image_files), window_size):
        shared.state.job = f"{i+1} to {min(i+window_size, len(image_files))} out of {len(image_files)}"
        if shared.state.skipped:
            shared.state.skipped = False
        if shared.state.interrupted:
@@ -0,0 +1 @@
Subproject commit 045515774882014cc14c1ba2668ab5bad9cbf7c0
@@ -85,7 +85,7 @@ def download_civit_preview(model_path: str, preview_url: str):
block_size = 16384 # 16KB blocks
written = 0
img = None
-shared.state.begin('civitai-download-preview')
+shared.state.begin('civitai')
try:
with open(preview_file, 'wb') as f:
with p.Progress(p.TextColumn('[cyan]{task.description}'), p.DownloadColumn(), p.BarColumn(), p.TaskProgressColumn(), p.TimeRemainingColumn(), p.TimeElapsedColumn(), p.TransferSpeedColumn(), console=shared.console) as progress:

@@ -142,7 +142,7 @@ def download_civit_model_thread(model_name, model_url, model_path, model_type, p
total_size = int(r.headers.get('content-length', 0))
res += f' size={round((starting_pos + total_size)/1024/1024)}Mb'
shared.log.info(res)
-shared.state.begin('civitai-download-model')
+shared.state.begin('civitai')
block_size = 16384 # 16KB blocks
written = starting_pos
global download_pbar # pylint: disable=global-statement

@@ -188,7 +188,7 @@ def download_diffusers_model(hub_id: str, cache_dir: str = None, download_config
return None
from diffusers import DiffusionPipeline
import huggingface_hub as hf
-shared.state.begin('huggingface-download-model')
+shared.state.begin('huggingface')
if download_config is None:
download_config = {
"force_download": False,
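The download helpers above stream the remote file in 16 KB blocks while a `rich` progress bar tracks bytes written. A library-agnostic sketch of the same block-copy loop, reporting progress through a plain callback instead of the `rich` columns (names are ours):

```python
def copy_in_blocks(reader, writer, block_size=16384, on_progress=None):
    # copy reader -> writer in fixed-size blocks, reporting bytes written so far
    written = 0
    while True:
        block = reader.read(block_size)
        if not block:
            break
        writer.write(block)
        written += len(block)
        if on_progress is not None:
            on_progress(written)
    return written
```

In the real code the callback role is played by updating a `rich.progress` task by `len(block)` per iteration.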
@@ -1,9 +1,43 @@
# this module must not have any dependencies as it first import
import os
import sys
-from modules import paths_internal, errors
+import json
+import argparse
+from modules.errors import log

# parse args, parse again after we have the data-dir and early-read the config file
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--ckpt", type=str, default=os.environ.get("SD_MODEL", None), help="Path to model checkpoint to load immediately, default: %(default)s")
parser.add_argument("--data-dir", type=str, default=os.environ.get("SD_DATADIR", ''), help="Base path where all user data is stored, default: %(default)s")
parser.add_argument("--models-dir", type=str, default=os.environ.get("SD_MODELSDIR", None), help="Base path where all models are stored, default: %(default)s",)
cli = parser.parse_known_args()[0]
parser.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(cli.data_dir, 'config.json')), help="Use specific server configuration file, default: %(default)s")
cli = parser.parse_known_args()[0]
config_path = cli.config if os.path.isabs(cli.config) else os.path.join(cli.data_dir, cli.config)
try:
with open(config_path, 'r', encoding='utf8') as f:
config = json.load(f)
except Exception as err:
print(f'Error loading config file: ${config_path} {err}')
config = {}

-debug = errors.log.info if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None
modules_path = os.path.dirname(os.path.realpath(__file__))
script_path = os.path.dirname(modules_path)
data_path = cli.data_dir
models_config = cli.models_dir or config.get('models_dir') or 'models'
models_path = models_config if os.path.isabs(models_config) else os.path.join(data_path, models_config)
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = "extensions-builtin"
sd_configs_path = os.path.join(script_path, "configs")
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
sd_model_file = cli.ckpt or os.path.join(script_path, 'model.ckpt') # not used
default_sd_model_file = sd_model_file # not used
+debug = log.info if os.environ.get('SD_PATH_DEBUG', None) is not None else lambda *args, **kwargs: None

if os.environ.get('SD_PATH_DEBUG', None) is not None:
print(f'Paths: script-path="{script_path}" data-dir="{data_path}" models-dir="{models_path}" config="{config_path}"')

"""
data_path = paths_internal.data_path
script_path = paths_internal.script_path
models_path = paths_internal.models_path

@@ -13,26 +47,17 @@ sd_model_file = paths_internal.sd_model_file
default_sd_model_file = paths_internal.default_sd_model_file
extensions_dir = paths_internal.extensions_dir
extensions_builtin_dir = paths_internal.extensions_builtin_dir
"""

# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path)

# search for directory of stable diffusion in following places
sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths:
if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
sd_path = os.path.abspath(possible_sd_path)
break

assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"

sd_path = os.path.join(script_path, 'repositories')
path_dirs = [
-(sd_path, 'ldm', 'Stable Diffusion', []),
-(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
-(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
-(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
-(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
+(sd_path, 'ldm', 'ldm', []),
+(sd_path, 'taming', 'Taming Transformers', []),
+(os.path.join(sd_path, 'blip'), 'models/blip.py', 'BLIP', []),
+(os.path.join(sd_path, 'codeformer'), 'inference_codeformer.py', 'CodeFormer', []),
+(os.path.join('modules', 'k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]

paths = {}

@@ -40,25 +65,26 @@ paths = {}
for d, must_exist, what, _options in path_dirs:
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
if not os.path.exists(must_exist_path):
-errors.log.error(f'Required path not found: path={must_exist_path} item={what}')
+log.error(f'Required path not found: path={must_exist_path} item={what}')
else:
d = os.path.abspath(d)
sys.path.append(d)
paths[what] = d

-def create_paths(opts):
-def create_path(folder):
-if folder is None or folder == '':
-return
-if os.path.exists(folder):
-return
-try:
-os.makedirs(folder, exist_ok=True)
-errors.log.info(f'Create folder={folder}')
-except Exception as e:
-errors.log.error(f'Create Failed folder={folder} {e}')
+def create_path(folder):
+if folder is None or folder == '':
+return
+if os.path.exists(folder):
+return
+try:
+os.makedirs(folder, exist_ok=True)
+log.info(f'Create folder={folder}')
+except Exception as e:
+log.error(f'Create Failed folder={folder} {e}')

def create_paths(opts):
def fix_path(folder):
tgt = opts.data.get(folder, None) or opts.data_labels[folder].default
if tgt is None or tgt == '':
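The rewritten `paths.py` above relies on a two-pass `argparse` pattern: parse once with `parse_known_args` to learn `--data-dir`, then register `--config` with a default derived from it and parse again. A minimal sketch of that pattern (the `parse_paths_cli` name is ours):

```python
import argparse
import os

def parse_paths_cli(argv):
    parser = argparse.ArgumentParser(add_help=False)
    # first pass: only the base directory
    parser.add_argument('--data-dir', type=str, default='')
    cli = parser.parse_known_args(argv)[0]
    # second pass: defaults that depend on --data-dir
    parser.add_argument('--config', type=str, default=os.path.join(cli.data_dir, 'config.json'))
    return parser.parse_known_args(argv)[0]
```

Using `parse_known_args` rather than `parse_args` is what lets the first pass silently ignore flags that are only registered for the second pass.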
@@ -1,5 +1,8 @@
"""this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py"""
# no longer used, all paths are defined in paths.py

from modules.paths import modules_path, script_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, data_path, models_path, extensions_dir, extensions_builtin_dir # pylint: disable=unused-import

"""
import argparse
import os

@@ -7,15 +10,21 @@ modules_path = os.path.dirname(os.path.realpath(__file__))
script_path = os.path.dirname(modules_path)
sd_configs_path = os.path.join(script_path, "configs")
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file

# Parse the --data-dir flag first so we can use it as a base for our other argument default values
parser_pre = argparse.ArgumentParser(add_help=False)
-parser_pre.add_argument("--data-dir", type=str, default="", help="base path where all user data is stored", )
-parser_pre.add_argument("--models-dir", type=str, default="models", help="base path where all models are stored",)
+parser_pre.add_argument("--ckpt", type=str, default=os.environ.get("SD_MODEL", None), help="Path to model checkpoint to load immediately, default: %(default)s")
+parser_pre.add_argument("--data-dir", type=str, default=os.environ.get("SD_DATADIR", ''), help="Base path where all user data is stored, default: %(default)s")
+parser_pre.add_argument("--models-dir", type=str, default=os.environ.get("SD_MODELSDIR", 'models'), help="Base path where all models are stored, default: %(default)s",)
cmd_opts_pre = parser_pre.parse_known_args()[0]

# parser_pre.add_argument("--config", type=str, default=os.environ.get("SD_CONFIG", os.path.join(data_path, 'config.json')), help="Use specific server configuration file, default: %(default)s")

data_path = cmd_opts_pre.data_dir
models_path = cmd_opts_pre.models_dir if os.path.isabs(cmd_opts_pre.models_dir) else os.path.join(data_path, cmd_opts_pre.models_dir)
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = "extensions-builtin"

sd_model_file = cmd_opts_pre.ckpt or os.path.join(script_path, 'model.ckpt') # not used
default_sd_model_file = sd_model_file # not used
"""
@@ -80,6 +80,22 @@ def create_binary_mask(image):
return image

+def images_tensor_to_samples(image, approximation=None, model=None):
+if model is None:
+model = shared.sd_model
+model.first_stage_model.to(devices.dtype_vae)
+image = image.to(shared.device, dtype=devices.dtype_vae)
+image = image * 2 - 1
+if len(image) > 1:
+x_latent = torch.stack([
+model.get_first_stage_encoding(model.encode_first_stage(torch.unsqueeze(img, 0)))[0]
+for img in image
+])
+else:
+x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
+return x_latent

def txt2img_image_conditioning(sd_model, x, width, height):
if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models
# The "masked-image" in this case will just be all zeros since the entire image is masked.

@@ -450,6 +466,8 @@ def decode_first_stage(model, x, full_quality=True):
shared.log.debug(f'Decode VAE: skipped={shared.state.skipped} interrupted={shared.state.interrupted}')
x_sample = torch.zeros((len(x), 3, x.shape[2] * 8, x.shape[3] * 8), dtype=devices.dtype_vae, device=devices.device)
return x_sample
+prev_job = shared.state.job
+shared.state.job = 'vae'
with devices.autocast(disable = x.dtype==devices.dtype_vae):
try:
if full_quality:

@@ -467,6 +485,7 @@ def decode_first_stage(model, x, full_quality=True):
except Exception as e:
x_sample = x
shared.log.error(f'Decode VAE: {e}')
+shared.state.job = prev_job
return x_sample

@@ -777,12 +796,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
return ''

ema_scope_context = p.sd_model.ema_scope if shared.backend == shared.Backend.ORIGINAL else nullcontext
+shared.state.job_count = p.n_iter
with devices.inference_context(), ema_scope_context():
t0 = time.time()
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
-if shared.state.job_count == -1:
-shared.state.job_count = p.n_iter
extra_network_data = None
for n in range(p.n_iter):
p.iteration = n

@@ -814,8 +832,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
step_multiplier = 1
-sampler_config = modules.sd_samplers.find_sampler_config(p.sampler_name)
-step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"

if shared.backend == shared.Backend.ORIGINAL:
uc = get_conds_with_caching(modules.prompt_parser.get_learned_conditioning, p.negative_prompts, p.steps * step_multiplier, cached_uc)

@@ -921,7 +937,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
output_images.append(image_mask_composite)
del x_samples_ddim
devices.torch_gc()
-shared.state.nextjob()

t1 = time.time()
shared.log.info(f'Processed: images={len(output_images)} time={t1 - t0:.2f}s its={(p.steps * len(output_images)) / (t1 - t0):.2f} memory={modules.memstats.memory_stats()}')

@@ -1044,12 +1059,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.is_hr_pass = False
return
self.is_hr_pass = True
-if not shared.state.processing_has_refined_job_count:
-if shared.state.job_count == -1:
-shared.state.job_count = self.n_iter
-shared.state.job_count = shared.state.job_count * 2
-shared.state.processing_has_refined_job_count = True
hypertile_set(self, hr=True)
+shared.state.job_count = 2 * self.n_iter
shared.log.debug(f'Init hires: upscaler="{self.hr_upscaler}" sampler="{self.latent_sampler}" resize={self.hr_resize_x}x{self.hr_resize_y} upscale={self.hr_upscale_to_x}x{self.hr_upscale_to_y}')

def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):

@@ -1069,11 +1080,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.sampler.initialize(self)
x = create_random_tensors([4, self.height // 8, self.width // 8], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+shared.state.nextjob()
if not self.enable_hr or shared.state.interrupted or shared.state.skipped:
return samples

self.init_hr()
if self.is_hr_pass:
+prev_job = shared.state.job
target_width = self.hr_upscale_to_x
target_height = self.hr_upscale_to_y
decoded_samples = None

@@ -1091,6 +1104,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.extra_generation_params, self.restore_faces = bak_extra_generation_params, bak_restore_faces
images.save_image(image, self.outpath_samples, "", seeds[i], prompts[i], shared.opts.samples_format, info=info, suffix="-before-hires")
if latent_scale_mode is None or self.hr_force: # non-latent upscaling
+shared.state.job = 'upscale'
if decoded_samples is None:
decoded_samples = decode_first_stage(self.sd_model, samples.to(dtype=devices.dtype_vae), self.full_quality)
decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)

@@ -1120,6 +1134,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.latent_sampler == "PLMS":
self.latent_sampler = 'UniPC'
if self.hr_force or latent_scale_mode is not None:
+shared.state.job = 'hires'
if self.denoising_strength > 0:
self.ops.append('hires')
devices.torch_gc() # GC now before running the next img2img to prevent running out of memory

@@ -1135,8 +1150,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
self.ops.append('upscale')
x = None
-shared.state.nextjob()
self.is_hr_pass = False
+shared.state.job = prev_job
+shared.state.nextjob()

return samples

@@ -1301,6 +1317,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
samples = samples * self.nmask + self.init_latent * self.mask
del x
devices.torch_gc()
+shared.state.nextjob()
return samples

def get_token_merging_ratio(self, for_hr=False):
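Several hunks above follow the same bookkeeping pattern: save `shared.state.job`, set a temporary job name ('vae', 'upscale', 'hires'), and restore the previous name when done. A small context-manager sketch of that save/restore discipline (the `State` and `named_job` names are ours, a simplification of the repo's shared state object):

```python
from contextlib import contextmanager

class State:
    def __init__(self):
        self.job = ''
        self.job_no = 0
        self.job_count = 0

    def nextjob(self):
        self.job_no += 1

@contextmanager
def named_job(state, name):
    # mirrors: prev_job = state.job; state.job = name; ...; state.job = prev_job
    prev = state.job
    state.job = name
    try:
        yield state
    finally:
        state.job = prev
```

The `try`/`finally` guarantees the previous job name is restored even if the wrapped step raises, which the manual prev_job assignments in the diff do not.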
@@ -63,14 +63,6 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro

def diffusers_callback(step: int, _timestep: int, latents: torch.FloatTensor):
shared.state.sampling_step = step
-if p.is_hr_pass:
-shared.state.job = 'hires'
-shared.state.sampling_steps = p.hr_second_pass_steps # add optional hires
-elif p.is_refiner_pass:
-shared.state.job = 'refine'
-shared.state.sampling_steps = calculate_refiner_steps() # add optional refiner
-else:
-shared.state.sampling_steps = p.steps # base steps
shared.state.current_latent = latents
if shared.state.interrupted or shared.state.skipped:
raise AssertionError('Interrupted...')

@@ -133,6 +125,8 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
return encoded

def vae_decode(latents, model, output_type='np', full_quality=True):
+prev_job = shared.state.job
+shared.state.job = 'vae'
if not torch.is_tensor(latents): # already decoded
return latents
if latents.shape[0] == 0:

@@ -150,6 +144,7 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
else:
decoded = taesd_vae_decode(latents=latents)
imgs = model.image_processor.postprocess(decoded, output_type=output_type)
+shared.state.job = prev_job
return imgs

def vae_encode(image, model, full_quality=True): # pylint: disable=unused-variable

@@ -186,16 +181,17 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro

def task_specific_kwargs(model):
task_args = {}
-if sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.TEXT_2_IMAGE:
+is_img2img_model = bool("Zero123" in shared.sd_model.__class__.__name__)
+if sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.TEXT_2_IMAGE and not is_img2img_model:
p.ops.append('txt2img')
task_args = {"height": 8 * math.ceil(p.height / 8), "width": 8 * math.ceil(p.width / 8)}
-elif sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.IMAGE_2_IMAGE and len(getattr(p, 'init_images' ,[])) > 0:
+elif (sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.IMAGE_2_IMAGE or is_img2img_model) and len(getattr(p, 'init_images' ,[])) > 0:
p.ops.append('img2img')
task_args = {"image": p.init_images, "strength": p.denoising_strength}
elif sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.INSTRUCT and len(getattr(p, 'init_images' ,[])) > 0:
p.ops.append('instruct')
task_args = {"height": 8 * math.ceil(p.height / 8), "width": 8 * math.ceil(p.width / 8), "image": p.init_images, "strength": p.denoising_strength}
-elif sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.INPAINTING and len(getattr(p, 'init_images' ,[])) > 0:
+elif (sd_models.get_diffusers_task(model) == sd_models.DiffusersTaskType.INPAINTING or is_img2img_model) and len(getattr(p, 'init_images' ,[])) > 0:
p.ops.append('inpaint')
if getattr(p, 'mask', None) is None:
p.mask = TF.to_pil_image(torch.ones_like(TF.to_tensor(p.init_images[0]))).convert("L")

@@ -388,6 +384,7 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
clip_skip=p.clip_skip,
desc='Base',
)
+shared.state.sampling_steps = base_args['num_inference_steps']
p.extra_generation_params['CFG rescale'] = p.diffusers_guidance_rescale
p.extra_generation_params["Sampler Eta"] = shared.opts.scheduler_eta if shared.opts.scheduler_eta is not None and shared.opts.scheduler_eta > 0 and shared.opts.scheduler_eta < 1 else None
try:

@@ -403,6 +400,7 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
if hasattr(shared.sd_model, 'embedding_db') and len(shared.sd_model.embedding_db.embeddings_used) > 0:
p.extra_generation_params['Embeddings'] = ', '.join(shared.sd_model.embedding_db.embeddings_used)

+shared.state.nextjob()
if shared.state.interrupted or shared.state.skipped:
return results

@@ -412,10 +410,12 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
latent_scale_mode = shared.latent_upscale_modes.get(p.hr_upscaler, None) if (hasattr(p, "hr_upscaler") and p.hr_upscaler is not None) else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "None")
if p.is_hr_pass:
p.init_hr()
+prev_job = shared.state.job
if p.width != p.hr_upscale_to_x or p.height != p.hr_upscale_to_y:
p.ops.append('upscale')
if shared.opts.save and not p.do_not_save_samples and shared.opts.save_images_before_highres_fix and hasattr(shared.sd_model, 'vae'):
save_intermediate(latents=output.images, suffix="-before-hires")
+shared.state.job = 'upscale'
output.images = hires_resize(latents=output.images)
if latent_scale_mode is not None or p.hr_force:
p.ops.append('hires')

@@ -438,15 +438,22 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
strength=p.denoising_strength,
desc='Hires',
)
+shared.state.job = 'hires'
+shared.state.sampling_steps = hires_args['num_inference_steps']
try:
output = shared.sd_model(**hires_args) # pylint: disable=not-callable
except AssertionError as e:
shared.log.info(e)
p.init_images = []
+shared.state.job = prev_job
+shared.state.nextjob()
p.is_hr_pass = False

# optional refiner pass or decode
if is_refiner_enabled:
+prev_job = shared.state.job
+shared.state.job = 'refine'
+shared.state.job_count +=1
if shared.opts.save and not p.do_not_save_samples and shared.opts.save_images_before_refiner and hasattr(shared.sd_model, 'vae'):
save_intermediate(latents=output.images, suffix="-before-refiner")
if shared.opts.diffusers_move_base and not getattr(shared.sd_model, 'has_accelerate', False):

@@ -491,6 +498,7 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
clip_skip=p.clip_skip,
desc='Refiner',
)
+shared.state.sampling_steps = refiner_args['num_inference_steps']
try:
refiner_output = shared.sd_refiner(**refiner_args) # pylint: disable=not-callable
except AssertionError as e:

@@ -505,7 +513,9 @@ def process_diffusers(p: StableDiffusionProcessing, seeds, prompts, negative_pro
shared.log.debug('Moving to CPU: model=refiner')
shared.sd_refiner.to(devices.cpu)
devices.torch_gc()
-p.is_refiner_pass = True
+shared.state.job = prev_job
+shared.state.nextjob()
+p.is_refiner_pass = False

# final decode since there is no refiner
if not is_refiner_enabled:
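The `task_specific_kwargs` branches above clamp requested sizes with `8 * math.ceil(p.height / 8)` so width and height land on the 8-pixel latent grid. As a stand-alone helper (our name):

```python
import math

def round_up(value, base=8):
    # smallest multiple of base that is >= value
    return base * math.ceil(value / base)
```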
@@ -66,15 +66,20 @@ def progressapi(req: ProgressRequest):
paused = shared.state.paused
if not active:
return InternalProgressResponse(job=shared.state.job, active=active, queued=queued, paused=paused, completed=completed, id_live_preview=-1, textinfo="Queued..." if queued else "Waiting...")
-progress = 0
-if shared.state.job_count > 0:
-progress += shared.state.job_no / shared.state.job_count
-if shared.state.sampling_steps > 0 and shared.state.job_count > 0:
-progress += 1 / (shared.state.job_count / 2 if shared.state.processing_has_refined_job_count else 1) * shared.state.sampling_step / shared.state.sampling_steps
-progress = min(progress, 1)
+if shared.state.job_no > shared.state.job_count:
+shared.state.job_count = shared.state.job_no
+batch_x = max(shared.state.job_no, 0)
+batch_y = max(shared.state.job_count, 1)
+step_x = max(shared.state.sampling_step, 0)
+step_y = max(shared.state.sampling_steps, 1)
+current = step_y * batch_x + step_x
+total = step_y * batch_y
+progress = min(1, current / total if total > 0 else 0)

elapsed_since_start = time.time() - shared.state.time_start
+predicted_duration = elapsed_since_start / progress if progress > 0 else None
+eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None

id_live_preview = req.id_live_preview
live_preview = None
shared.state.set_current_image()

@@ -83,4 +88,6 @@ def progressapi(req: ProgressRequest):
shared.state.current_image.save(buffered, format='jpeg')
live_preview = f'data:image/jpeg;base64,{base64.b64encode(buffered.getvalue()).decode("ascii")}'
id_live_preview = shared.state.id_live_preview
-return InternalProgressResponse(job=shared.state.job, active=active, queued=queued, paused=paused, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)

+res = InternalProgressResponse(job=shared.state.job, active=active, queued=queued, paused=paused, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
+return res
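The new progress formula above collapses the batch counter and the per-batch step counter into one ratio. Extracted into a pure function for clarity (our name; the arguments map to the `shared.state` fields):

```python
def overall_progress(job_no, job_count, sampling_step, sampling_steps):
    batch_x = max(job_no, 0)
    batch_y = max(job_count, 1)
    step_x = max(sampling_step, 0)
    step_y = max(sampling_steps, 1)
    current = step_y * batch_x + step_x  # steps completed so far across batches
    total = step_y * batch_y             # steps for the whole run
    return min(1, current / total if total > 0 else 0)
```

ETA then follows as in the hunk above: `predicted_duration = elapsed / progress`, `eta = predicted_duration - elapsed`.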
@@ -321,6 +321,7 @@ class ScriptRunner:
self.paste_field_names = []
self.script_load_ctr = 0
self.is_img2img = False
+self.inputs = [None]

def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing

@@ -355,6 +356,31 @@ class ScriptRunner:
except Exception as e:
log.error(f'Script initialize: {path} {e}')

+def create_script_ui(self, script):
+import modules.api.models as api_models
+script.args_from = len(self.inputs)
+script.args_to = len(self.inputs)
+controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
+if controls is None:
+return
+script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
+api_args = []
+for control in controls:
+control.custom_script_source = os.path.basename(script.filename)
+arg_info = api_models.ScriptArg(label=control.label or "")
+for field in ("value", "minimum", "maximum", "step", "choices"):
+v = getattr(control, field, None)
+if v is not None:
+setattr(arg_info, field, v)
+api_args.append(arg_info)
+script.api_info = api_models.ScriptInfo(name=script.name, is_img2img=script.is_img2img, is_alwayson=script.alwayson, args=api_args)
+if script.infotext_fields is not None:
+self.infotext_fields += script.infotext_fields
+if script.paste_field_names is not None:
+self.paste_field_names += script.paste_field_names
+self.inputs += controls
+script.args_to = len(self.inputs)

def setup_ui_for_section(self, section, scriptlist=None):
if scriptlist is None:
scriptlist = self.alwayson_scripts

@@ -377,7 +403,7 @@ class ScriptRunner:
inputs = []
inputs_alwayson = [True]

-def create_script_ui(script, inputs, inputs_alwayson):
+def create_script_ui(script, inputs, inputs_alwayson): # TODO this is legacy implementation, see self.create_script_ui
script.args_from = len(inputs)
script.args_to = len(inputs)
controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)

@@ -445,7 +471,7 @@ class ScriptRunner:
for script in self.alwayson_scripts:
t0 = time.time()
elem_id = f'script_{"txt2img" if script.is_txt2img else "img2img"}_{script.title().lower().replace(" ", "_")}'
-with gr.Group(elem_id=elem_id) as group:
+with gr.Group(elem_id=elem_id, elem_classes=['extension-script']) as group:
create_script_ui(script, inputs, inputs_alwayson)
script.group = group
time_setup[script.title()] = time_setup.get(script.title(), 0) + (time.time()-t0)
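The new `create_script_ui` above copies a fixed set of optional attributes (`value`, `minimum`, `maximum`, `step`, `choices`) from each UI control into an API description, skipping absent ones. The getattr loop reduces to this (our names; a plain dict stands in for the pydantic `ScriptArg` model):

```python
def describe_control(control, fields=('value', 'minimum', 'maximum', 'step', 'choices')):
    # collect only the attributes the control actually defines
    info = {'label': getattr(control, 'label', '') or ''}
    for field in fields:
        v = getattr(control, field, None)
        if v is not None:
            info[field] = v
    return info
```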
@@ -88,7 +88,7 @@ class ScriptPostprocessingRunner:
def setup_ui(self):
inputs = []
for script in self.scripts_in_preferred_order():
-with gr.Accordion(label=script.name, open=False) as group:
+with gr.Accordion(label=script.name, open=False, elem_classes=['postprocess']) as group:
self.create_script_ui(script, inputs)
script.group = group
self.ui_created = True
@@ -317,7 +317,7 @@ def get_xformers_flash_attention_op(q, k, v):
return None

try:
-flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
+flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp # pylint: disable=used-before-assignment
fw, _bw = flash_attention_op
if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
return flash_attention_op
@ -24,7 +24,7 @@ from modules import paths, shared, shared_items, shared_state, modelloader, devi
|
|||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||
from modules.timer import Timer
|
||||
from modules.memstats import memory_stats
|
||||
from modules.paths_internal import models_path, script_path
|
||||
from modules.paths import models_path, script_path
|
||||
|
||||
try:
|
||||
import diffusers
|
||||
|
|
@ -848,6 +848,8 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No
|
|||
vae = sd_vae.load_vae_diffusers(checkpoint_info.path, vae_file, vae_source)
|
||||
if vae is not None:
|
||||
diffusers_load_config["vae"] = vae
|
||||
if 'LCM' in checkpoint_info.path:
|
||||
diffusers_load_config['custom_pipeline'] = 'latent_consistency_txt2img'
|
||||
|
||||
if os.path.isdir(checkpoint_info.path):
|
||||
err1 = None
|
||||
|
|
@ -858,18 +860,21 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No
|
|||
sd_model.model_type = sd_model.__class__.__name__
|
||||
except Exception as e:
|
||||
err1 = e
|
||||
# shared.log.error(f'AutoPipeline: {e}')
|
||||
try: # try diffusion pipeline next second-best choice, works for most non-linked pipelines
|
||||
if err1 is not None:
|
||||
sd_model = diffusers.DiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
|
||||
sd_model.model_type = sd_model.__class__.__name__
|
||||
except Exception as e:
|
||||
err2 = e
|
||||
# shared.log.error(f'DiffusionPipeline: {e}')
|
||||
try: # try basic pipeline next just in case
|
||||
if err2 is not None:
|
||||
sd_model = diffusers.StableDiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
|
||||
sd_model.model_type = sd_model.__class__.__name__
|
||||
except Exception as e:
|
||||
err3 = e # ignore last error
|
||||
shared.log.error(f'StableDiffusionPipeline: {e}')
|
||||
if err3 is not None:
|
||||
shared.log.error(f'Failed loading {op}: {checkpoint_info.path} auto={err1} diffusion={err2}')
|
||||
return
|
||||
|
|
@@ -1155,7 +1160,7 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model')
        return None
    orig_state = copy.deepcopy(shared.state)
    shared.state = shared_state.State()
-    shared.state.begin(f'load-{op}')
+    shared.state.begin('load')
    if load_dict:
        shared.log.debug(f'Model dict: existing={sd_model is not None} target={checkpoint_info.filename} info={info}')
    else:
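The `shared.state.begin(...)` call above resets per-job bookkeeping before a model reload. A minimal sketch of such a state object (a simplified stand-in, not the real `shared_state.State` class):

```python
import datetime

class State:
    """Tiny stand-in for the job-state object used around model reloads."""
    def __init__(self):
        self.job = ''
        self.job_no = 0
        self.sampling_step = 0
        self.skipped = False
        self.paused = False
        self.job_timestamp = '0'

    def begin(self, job: str = ''):
        # reset per-job counters and stamp the start time,
        # mirroring what State.reset() does in the hunks below
        self.job = job
        self.job_no = 0
        self.sampling_step = 0
        self.skipped = False
        self.paused = False
        self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")

state = State()
state.begin('load')
```

The timestamp format `%Y%m%d%H%M%S` yields a fixed 14-character string, which is convenient for log file names.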
@@ -4,10 +4,10 @@ import torch

from modules import paths, sd_disable_initialization, devices

-sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+sd_repo_configs_path = 'configs'
config_default = paths.sd_default_config
-config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
-config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
+config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference-512-base.yaml")
+config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-768-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
@@ -173,7 +173,12 @@ class CFGDenoiser(torch.nn.Module):
        else:
            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
        if self.mask is not None:
-            denoised = self.init_latent * self.mask + self.nmask * denoised
+            if devices.backend == "directml":
+                self.init_latent = self.init_latent.float()
+                denoised = self.init_latent * self.mask + self.nmask * denoised
+                self.init_latent = self.init_latent.half()
+            else:
+                denoised = self.init_latent * self.mask + self.nmask * denoised
        after_cfg_callback_params = AfterCFGCallbackParams(denoised, shared.state.sampling_step, shared.state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
        denoised = after_cfg_callback_params.x
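The inpainting blend in this hunk keeps masked regions from the original latent and takes the rest from the denoised sample: `init_latent * mask + nmask * denoised`, where `nmask` is the complement of `mask`. A torch-free, purely illustrative sketch of the same element-wise blend over flat lists:

```python
def blend_masked(init_latent, denoised, mask):
    """Blend two latents element-wise.

    mask=1.0 preserves the original latent value, mask=0.0 takes the
    denoised value; intermediate values interpolate, as in the
    CFGDenoiser hunk above (nmask == 1 - mask).
    """
    return [i * m + (1.0 - m) * d for i, d, m in zip(init_latent, denoised, mask)]

out = blend_masked([1.0, 1.0, 1.0], [0.0, 0.5, 2.0], [1.0, 0.0, 0.5])
# out == [1.0, 0.5, 1.5]
```

The DirectML branch above performs exactly this blend, only round-tripping `init_latent` through fp32 first because the fp16 multiply is problematic on that backend.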
@@ -333,6 +338,7 @@ class KDiffusionSampler:
            's_min_uncond': self.s_min_uncond
        }
        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
+        samples = samples.type(devices.dtype)
        return samples

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
@@ -49,6 +49,7 @@ class CFGDenoiserTimesteps(CFGDenoiser):

        self.alphas = shared.sd_model.alphas_cumprod
        self.mask_before_denoising = True
+        self.model_wrap = None

    def get_pred_x0(self, x_in, x_out, sigma):
        ts = sigma.to(dtype=int)
@@ -3,7 +3,7 @@ import collections
import glob
from copy import deepcopy
import torch
-from modules import shared, paths, paths_internal, devices, script_callbacks, sd_models
+from modules import shared, paths, devices, script_callbacks, sd_models


vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
@@ -200,7 +200,7 @@ def load_vae_diffusers(model_file, vae_file=None, vae_source="unknown-source"):
    import diffusers
    if os.path.isfile(vae_file):
        _pipeline, model_type = sd_models.detect_pipeline(model_file, 'vae')
-        diffusers_load_config = { "config_file": paths_internal.sd_default_config if model_type != 'Stable Diffusion XL' else os.path.join(paths_internal.sd_configs_path, 'sd_xl_base.yaml')}
+        diffusers_load_config = { "config_file": paths.sd_default_config if model_type != 'Stable Diffusion XL' else os.path.join(paths.sd_configs_path, 'sd_xl_base.yaml')}
        vae = diffusers.AutoencoderKL.from_single_file(vae_file, **diffusers_load_config)
        vae = vae.to(devices.dtype_vae)
    else:
@@ -12,13 +12,13 @@ import gradio as gr
import fasteners
from rich.console import Console
from modules import errors, shared_items, shared_state, cmd_args, ui_components, theme
-from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611
+from modules.paths import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # pylint: disable=W0611
from modules.dml import memory_providers, default_memory_provider, directml_do_hijack
import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices # pylint: disable=R0402
-import modules.paths_internal as paths
+import modules.paths as paths
from installer import print_dict
from installer import log as central_logger # pylint: disable=E0611
@@ -337,7 +337,7 @@ options_templates.update(options_section(('diffusers', "Diffusers Settings"), {
    "diffusers_attention_slicing": OptionInfo(False, "Enable attention slicing"),
    "diffusers_model_load_variant": OptionInfo("default", "Diffusers model loading variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}),
    "diffusers_vae_load_variant": OptionInfo("default", "Diffusers VAE loading variant", gr.Radio, {"choices": ['default', 'fp32', 'fp16']}),
-    "custom_diffusers_pipeline": OptionInfo('hf-internal-testing/diffusers-dummy-pipeline', 'Custom Diffusers pipeline to use'),
+    "custom_diffusers_pipeline": OptionInfo('', 'Load custom Diffusers pipeline'),
    "diffusers_lora_loader": OptionInfo("diffusers" if cmd_opts.use_openvino else "sequential apply", "Diffusers LoRA loading variant", gr.Radio, {"choices": ['diffusers', 'sequential apply', 'merge and apply']}),
    "diffusers_force_zeros": OptionInfo(True, "Force zeros for prompts when empty"),
    "diffusers_aesthetics_score": OptionInfo(False, "Require aesthetics score"),
@@ -346,8 +346,8 @@ options_templates.update(options_section(('diffusers', "Diffusers Settings"), {
}))

options_templates.update(options_section(('system-paths', "System Paths"), {
-    "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default", folder=True),
-    "clean_temp_dir_at_start": OptionInfo(True, "Cleanup non-default temporary directory when starting webui"),
+    "models_paths_sep_options": OptionInfo("<h2>Models paths</h2>", "", gr.HTML),
+    "models_dir": OptionInfo('models', "Base path where all models are stored", folder=True),
    "ckpt_dir": OptionInfo(os.path.join(paths.models_path, 'Stable-diffusion'), "Folder with stable diffusion models", folder=True),
    "diffusers_dir": OptionInfo(os.path.join(paths.models_path, 'Diffusers'), "Folder with Hugggingface models", folder=True),
    "vae_dir": OptionInfo(os.path.join(paths.models_path, 'VAE'), "Folder with VAE files", folder=True),
@@ -366,6 +366,10 @@ options_templates.update(options_section(('system-paths', "System Paths"), {
    "swinir_models_path": OptionInfo(os.path.join(paths.models_path, 'SwinIR'), "Folder with SwinIR models", folder=True),
    "ldsr_models_path": OptionInfo(os.path.join(paths.models_path, 'LDSR'), "Folder with LDSR models", folder=True),
    "clip_models_path": OptionInfo(os.path.join(paths.models_path, 'CLIP'), "Folder with CLIP models", folder=True),
+
+    "other_paths_sep_options": OptionInfo("<h2>Other paths</h2>", "", gr.HTML),
+    "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default", folder=True),
+    "clean_temp_dir_at_start": OptionInfo(True, "Cleanup non-default temporary directory when starting webui"),
}))

options_templates.update(options_section(('saving-images', "Image Options"), {
@@ -474,7 +478,7 @@ options_templates.update(options_section(('sampler-params', "Sampler Settings"),
    "schedulers_use_karras": OptionInfo(True, "Use Karras sigmas", gr.Checkbox, {"visible": False}),
    "schedulers_use_thresholding": OptionInfo(False, "Use dynamic thresholding", gr.Checkbox, {"visible": False}),
    "schedulers_use_loworder": OptionInfo(True, "Use simplified solvers in final steps", gr.Checkbox, {"visible": False}),
-    "schedulers_prediction_type": OptionInfo("default", "Override model prediction type", gr.Radio, {"choices": ['default', 'epsilon', 'sample', 'v_prediction'], "visible": False}),
+    "schedulers_prediction_type": OptionInfo("default", "Override model prediction type", gr.Radio, {"choices": ['default', 'epsilon', 'sample', 'v_prediction']}),

    # managed from ui.py for backend diffusers
    "schedulers_sep_diffusers": OptionInfo("<h2>Diffusers specific config</h2>", "", gr.HTML),
@@ -13,7 +13,6 @@ class State:
    job_no = 0
    job_count = 0
    total_jobs = 0
    processing_has_refined_job_count = False
    job_timestamp = '0'
    sampling_step = 0
    sampling_steps = 0
@@ -72,7 +71,6 @@ class State:
        self.job_no = 0
        self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        self.paused = False
        self.processing_has_refined_job_count = False
        self.sampling_step = 0
        self.skipped = False
        self.textinfo = None
@@ -9,7 +9,7 @@ from modules import paths


class Style():
-    def __init__(self, name: str, desc: str = "", prompt: str = "", negative_prompt: str = "", extra: str = "", filename: str = "", preview: str = ""):
+    def __init__(self, name: str, desc: str = "", prompt: str = "", negative_prompt: str = "", extra: str = "", filename: str = "", preview: str = "", mtime: float = 0):
        self.name = name
        self.description = desc
        self.prompt = prompt
@@ -17,6 +17,7 @@ class Style():
        self.extra = extra
        self.filename = filename
        self.preview = preview
+        self.mtime = mtime

def merge_prompts(style_prompt: str, prompt: str) -> str:
    if "{prompt}" in style_prompt:
@@ -105,7 +106,8 @@ class StyleDatabase:
                negative_prompt=style.get("negative", ""),
                extra=style.get("extra", ""),
                preview=style.get("preview", None),
-                filename=fn
+                filename=fn,
+                mtime=os.path.getmtime(fn),
            )
        except Exception as e:
            log.error(f'Failed to load style: file={fn} error={e}')
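The style hunks above add an `mtime` field to each `Style`, filled from `os.path.getmtime(fn)` when the style file is parsed. A self-contained sketch with a trimmed-down `Style` (hypothetical fields, not the full class):

```python
import json
import os
import tempfile

class Style:
    # trimmed-down version of the Style class from the hunks above
    def __init__(self, name, prompt='', filename='', mtime=0.0):
        self.name = name
        self.prompt = prompt
        self.filename = filename
        self.mtime = mtime

def load_style(fn):
    """Parse one style json file and stamp it with its modification time."""
    with open(fn, encoding='utf-8') as f:
        data = json.load(f)
    return Style(name=data.get('name', ''), prompt=data.get('prompt', ''),
                 filename=fn, mtime=os.path.getmtime(fn))

# usage: write a throwaway style file, load it, clean up
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump({'name': 'cinematic', 'prompt': 'cinematic shot of {prompt}'}, f)
    fn = f.name
style = load_style(fn)
os.unlink(fn)
```

Recording the mtime at load time means the UI can later sort cards by date without touching the filesystem again.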
@@ -6,7 +6,7 @@ https://github.com/madebyollin/taesd
"""
import os
from PIL import Image
-from modules import devices, paths_internal
+from modules import devices, paths
from modules.taesd.taesd import TAESD

taesd_models = { 'sd-decoder': None, 'sd-encoder': None, 'sdxl-decoder': None, 'sdxl-encoder': None }
@@ -25,7 +25,7 @@ def download_model(model_path):
def model(model_class = 'sd', model_type = 'decoder'):
    vae = taesd_models[f'{model_class}-{model_type}']
    if vae is None:
-        model_path = os.path.join(paths_internal.models_path, "TAESD", f"tae{model_class}_{model_type}.pth")
+        model_path = os.path.join(paths.models_path, "TAESD", f"tae{model_class}_{model_type}.pth")
        download_model(model_path)
        if os.path.exists(model_path):
            from modules.shared import log
@@ -52,7 +52,7 @@ def decode(latents):
        return Image.new('RGB', (8, 8), color = (0, 0, 0))
    vae = taesd_models[f'{model_class}-decoder']
    if vae is None:
-        model_path = os.path.join(paths_internal.models_path, "TAESD", f"tae{model_class}_decoder.pth")
+        model_path = os.path.join(paths.models_path, "TAESD", f"tae{model_class}_decoder.pth")
        download_model(model_path)
        if os.path.exists(model_path):
            taesd_models[f'{model_class}-decoder'] = TAESD(decoder_path=model_path, encoder_path=None)
@@ -73,7 +73,7 @@ def encode(image):
        return Image.new('RGB', (8, 8), color = (0, 0, 0))
    vae = taesd_models[f'{model_class}-encoder']
    if vae is None:
-        model_path = os.path.join(paths_internal.models_path, "TAESD", f"tae{model_class}_encoder.pth")
+        model_path = os.path.join(paths.models_path, "TAESD", f"tae{model_class}_encoder.pth")
        download_model(model_path)
        if os.path.exists(model_path):
            taesd_models[f'{model_class}-encoder'] = TAESD(encoder_path=model_path, decoder_path=None)
@@ -133,7 +133,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
    image = srcimage.copy()
    fontsize = 32
    if textfont is None:
-        textfont = opts.font or 'html/roboto.ttf'
+        textfont = opts.font or 'javascript/roboto.ttf'

    factor = 1.5
    gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
@@ -425,7 +425,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
    log_directory = f"{os.path.join(shared.cmd_opts.data_dir, 'train/log/embeddings')}"
    template_file = template_file.path

-    shared.state.job = "train-embedding"
+    shared.state.job = "train"
    shared.state.textinfo = "Initializing textual inversion training..."
    shared.state.job_count = steps
@@ -644,7 +644,7 @@ def create_ui(startup_timer = None):
    steps, sampler_index = create_sampler_and_steps_selection(modules.sd_samplers.samplers_for_img2img, "img2img")

    with gr.Accordion(open=False, label="Resize", elem_classes=["small-accordion"], elem_id="img2img_resize_group"):
-        with FormRow():
+        with gr.Row():
            resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["None", "Resize fixed", "Crop and resize", "Resize and fill", "Latent upscale"], type="index", value="None")

        with FormRow():
@@ -118,14 +118,14 @@ class ExtraNetworksPage:
        self.list_time = 0
        # class additional is to keep old extensions happy
        self.card = '''
-            <div class='card' onclick={card_click} title='{name}' data-tab='{tabname}' data-page='{page}' data-name='{name}' data-filename='(unknown)' data-tags='{tags}'>
+            <div class='card' onclick={card_click} title='{name}' data-tab='{tabname}' data-page='{page}' data-name='{name}' data-filename='(unknown)' data-tags='{tags}' data-mtime='{mtime}' data-size='{size}'>
                <div class='overlay'>
                    <span style="display:none" class='search_term'>{search_term}</span>
                    <div class='tags'></div>
                    <div class='name'>{title}</div>
                </div>
                <div class='actions'>
-                    <span title="Get details" onclick="showCardDetails(event)">🛈</span>
+                    <span class='details' title="Get details" onclick="showCardDetails(event)">🛈</span>
                    <div class='additional'><ul></ul></div>
                </div>
                <img class='preview' src='{preview}' style='width: {width}px; height: {height}px; object-fit: {fit}' loading='lazy'></img>
@@ -282,6 +282,8 @@ class ExtraNetworksPage:
            "search_term": item.get("search_term", ""),
            "description": item.get("description") or "",
            "card_click": item.get("onclick", '"' + html.escape(f'return cardClicked({item.get("prompt", None)}, {"true" if self.allow_negative_prompt else "false"})') + '"'),
+            "mtime": item.get("mtime", 0),
+            "size": item.get("size", 0),
        }
        alias = item.get("alias", None)
        if alias is not None:
@@ -392,6 +394,7 @@ class ExtraNetworksUi:
        self.button_refresh: gr.Button = None
        self.button_scan: gr.Button = None
        self.button_save: gr.Button = None
+        self.button_sort: gr.Button = None
        self.button_apply: gr.Button = None
        self.button_close: gr.Button = None
        self.button_model: gr.Checkbox = None
@@ -485,7 +488,8 @@ def create_ui(container, button_parent, tabname, skip_indexing = False):
    ui.button_refresh = ToolButton(symbols.refresh, elem_id=tabname+"_extra_refresh")
    ui.button_scan = ToolButton(symbols.scan, elem_id=tabname+"_extra_scan", visible=True)
    ui.button_save = ToolButton(symbols.book, elem_id=tabname+"_extra_save", visible=False)
-    ui.button_close = ToolButton(symbols.close, elem_id=tabname+"_extra_close")
+    ui.button_sort = ToolButton(symbols.sort, elem_id=tabname+"_extra_sort", visible=True)
+    ui.button_close = ToolButton(symbols.close, elem_id=tabname+"_extra_close", visible=True)
    ui.button_model = ToolButton(symbols.refine, elem_id=tabname+"_extra_model", visible=True)
    ui.search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", elem_classes="textbox", lines=2, container=False)
    ui.description = gr.Textbox('', show_label=False, elem_id=tabname+"_description", elem_classes="textbox", lines=2, interactive=False, container=False)
@@ -700,9 +704,14 @@ def create_ui(container, button_parent, tabname, skip_indexing = False):
        res = show_details(text=None, img=None, desc=None, info=None, meta=None, params=params)
        return res

+    def ui_sort_cards(msg):
+        shared.log.debug(f'Extra networks: {msg}')
+        return msg
+
    dummy_state = gr.State(value=False) # pylint: disable=abstract-class-instantiated
    button_parent.click(fn=toggle_visibility, inputs=[ui.visible], outputs=[ui.visible, container, button_parent])
    ui.button_close.click(fn=toggle_visibility, inputs=[ui.visible], outputs=[ui.visible, container])
+    ui.button_sort.click(fn=ui_sort_cards, _js='sortExtraNetworks', inputs=[ui.search], outputs=[ui.description])
    ui.button_refresh.click(fn=ui_refresh_click, _js='getENActivePage', inputs=[ui.search], outputs=ui.pages)
    ui.button_scan.click(fn=ui_scan_click, _js='getENActivePage', inputs=[ui.search], outputs=ui.pages)
    ui.button_save.click(fn=ui_save_click, inputs=[], outputs=ui.details_components + [ui.details])
@@ -30,6 +30,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
                "info": self.find_info(fn),
                "metadata": checkpoint.metadata,
                "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
+                "mtime": os.path.getmtime(checkpoint.filename),
+                "size": os.path.getsize(checkpoint.filename),
            }
            yield record
        except Exception as e:
@@ -25,6 +25,8 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
                "search_term": self.search_terms_from_path(name),
                "prompt": json.dumps(f"<hypernet:{name}:{shared.opts.extra_networks_default_multiplier}>"),
                "local_preview": f"{fn}.{shared.opts.samples_format}",
+                "mtime": os.path.getmtime(path),
+                "size": os.path.getsize(path),
            }
        except Exception as e:
            shared.log.debug(f"Extra networks error: type=hypernetwork file={path} {e}")
@@ -85,6 +85,8 @@ class ExtraNetworksPageStyles(ui_extra_networks.ExtraNetworksPage):
                "extra": getattr(style, 'extra', ''),
                "local_preview": f"{fn}.{shared.opts.samples_format}",
                "onclick": '"' + html.escape(f"""return selectStyle({json.dumps(name)})""") + '"',
+                "mtime": getattr(style, 'mtime', 0),
+                "size": os.path.getsize(style.filename),
            }
        except Exception as e:
            shared.log.debug(f"Extra networks error: type=style file={k} {e}")
@@ -56,6 +56,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
                "prompt": json.dumps(os.path.splitext(embedding.name)[0]),
                "local_preview": f"{path}.{shared.opts.samples_format}",
                "tags": tags,
+                "mtime": os.path.getmtime(embedding.filename),
+                "size": os.path.getsize(embedding.filename),
            }
        except Exception as e:
            shared.log.debug(f"Extra networks error: type=embedding file={embedding.filename} {e}")
@@ -28,6 +28,8 @@ class ExtraNetworksPageVAEs(ui_extra_networks.ExtraNetworksPage):
                "info": self.find_info(fn),
                "metadata": {},
                "onclick": '"' + html.escape(f"""return selectVAE({json.dumps(name)})""") + '"',
+                "mtime": os.path.getmtime(filename),
+                "size": os.path.getsize(filename),
            }
            yield record
        except Exception as e:
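Each extra-networks page (checkpoints, hypernetworks, styles, embeddings, VAEs) now attaches `mtime` and `size` to its card records the same way. A generic, illustrative sketch of building such a record for an arbitrary file (field names follow the hunks above; the helper itself is hypothetical):

```python
import os
import tempfile

def make_record(name, filename):
    """Minimal card record in the shape used by the extra-networks pages."""
    return {
        'name': name,
        'filename': filename,
        'mtime': os.path.getmtime(filename),  # feeds the data-mtime card attribute
        'size': os.path.getsize(filename),    # feeds the data-size card attribute
    }

# usage: stamp a throwaway 128-byte file, build its record, clean up
with tempfile.NamedTemporaryFile('wb', suffix='.pt', delete=False) as f:
    f.write(b'\x00' * 128)
    path = f.name
record = make_record('demo-vae', path)
os.unlink(path)
```

Because both values are captured while the page is listed, the new sort button in the UI can reorder cards by date or size from the `data-mtime`/`data-size` attributes alone.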
@@ -1,8 +1,7 @@
# TODO: a1111 compatibility item, not used

import gradio as gr
-from modules import shared, ui_common, ui_components, styles
+from modules import shared, styles

styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
styles_materialize_symbol = '\U0001f4cb' # 📋
@@ -34,7 +33,7 @@ def delete_style(name):
    return '', '', ''


-def materialize_styles(prompt, negative_prompt, styles):
+def materialize_styles(prompt, negative_prompt, styles): # pylint: disable=redefined-outer-name
    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
    negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
@@ -45,7 +44,7 @@ def refresh_styles():


class UiPromptStyles:
-    def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
+    def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): # pylint: disable=unused-argument
        self.dropdown = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles", choices=[style.name for style in shared.prompt_styles.styles.values()], value=[], multiselect=True)

        """
@@ -11,6 +11,7 @@ networks = '🌐'
paste = '⇦'
refine = '⌾'
switch = '⇅'
+sort = '⇕'
detect = '📐'
folder = '📂'
random = '🎲️'
@@ -34,7 +34,11 @@ exclude = [
    "extensions-builtin",
    "modules/lora",
    "modules/dml",
    "modules/models/diffusion",
    "modules/k-diffusion",
    "repositories/ldm",
    "repositories/taming",
    "repositories/blip",
    "repositories/codeformer",
]
ignore = [
    "A003", # Class attirbute shadowing builtin
@@ -0,0 +1,2 @@
# Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
#ECCN:Open Source
@@ -0,0 +1,105 @@
# Salesforce Open Source Community Code of Conduct

## About the Code of Conduct

Equality is a core value at Salesforce. We believe a diverse and inclusive
community fosters innovation and creativity, and are committed to building a
culture where everyone feels included.

Salesforce open-source projects are committed to providing a friendly, safe, and
welcoming environment for all, regardless of gender identity and expression,
sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
race, age, religion, level of experience, education, socioeconomic status, or
other similar personal characteristics.

The goal of this code of conduct is to specify a baseline standard of behavior so
that people with different social values and communication styles can work
together effectively, productively, and respectfully in our open source community.
It also establishes a mechanism for reporting issues and resolving conflicts.

All questions and reports of abusive, harassing, or otherwise unacceptable behavior
in a Salesforce open-source project may be reported by contacting the Salesforce
Open Source Conduct Committee at ossconduct@salesforce.com.

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of gender
identity and expression, sexual orientation, disability, physical appearance,
body size, ethnicity, nationality, race, age, religion, level of experience, education,
socioeconomic status, or other similar personal characteristics.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy toward other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Personal attacks, insulting/derogatory comments, or trolling
* Public or private harassment
* Publishing, or threatening to publish, others' private information—such as
  a physical or electronic address—without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting
* Advocating for or encouraging any of the above behaviors

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned with this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project email
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the Salesforce Open Source Conduct Committee
at ossconduct@salesforce.com. All complaints will be reviewed and investigated
and will result in a response that is deemed necessary and appropriate to the
circumstances. The committee is obligated to maintain confidentiality with
regard to the reporter of an incident. Further details of specific enforcement
policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership and the Salesforce Open Source Conduct
Committee.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
It includes adaptions and additions from [Go Community Code of Conduct][golang-coc],
[CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].

This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].

[contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
[golang-coc]: https://golang.org/conduct
[cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
[microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
[cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
@@ -0,0 +1,12 @@
Copyright (c) 2022, Salesforce.com, Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,116 @@
## BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation

## Announcement: BLIP is now officially integrated into [LAVIS](https://github.com/salesforce/LAVIS) - a one-stop library for language-and-vision research and applications!

<img src="BLIP.gif" width="700">

This is the PyTorch code of the <a href="https://arxiv.org/abs/2201.12086">BLIP paper</a> [[blog](https://blog.salesforceairesearch.com/blip-bootstrapping-language-image-pretraining/)]. The code has been tested on PyTorch 1.10.
To install the dependencies, run <pre/>pip install -r requirements.txt</pre>

Catalog:
- [x] Inference demo
- [x] Pre-trained and finetuned checkpoints
- [x] Finetuning code for Image-Text Retrieval, Image Captioning, VQA, and NLVR2
- [x] Pre-training code
- [x] Zero-shot video-text retrieval
- [x] Download of bootstrapped pre-training datasets


### Inference demo:
Run our interactive demo using [Colab notebook](https://colab.research.google.com/github/salesforce/BLIP/blob/main/demo.ipynb) (no GPU needed).
The demo includes code for:
1. Image captioning
2. Open-ended visual question answering
3. Multimodal / unimodal feature extraction
4. Image-text matching

Try out the [Web demo](https://huggingface.co/spaces/Salesforce/BLIP), integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio).

Replicate web demo and Docker image is also available at [Replicate](https://replicate.com/salesforce/blip)

### Pre-trained checkpoints:
Num. pre-train images | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
--- | :---: | :---: | :---:
14M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_14M.pth">Download</a>| - | -
129M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth">Download</a> | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth">Download</a>

### Finetuned checkpoints:
Task | BLIP w/ ViT-B | BLIP w/ ViT-B and CapFilt-L | BLIP w/ ViT-L
--- | :---: | :---: | :---:
Image-Text Retrieval (COCO) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth">Download</a>
Image-Text Retrieval (Flickr30k) | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth">Download</a>| - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_flickr.pth">Download</a>
Image Captioning (COCO) | - | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth">Download</a> |
VQA | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_vqa.pth">Download</a>| <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth">Download</a> | -
NLVR2 | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth">Download</a>| - | -


### Image-Text Retrieval:
1. Download COCO and Flickr30k datasets from the original websites, and set 'image_root' in configs/retrieval_{dataset}.yaml accordingly.
2. To evaluate the finetuned BLIP model on COCO, run:
<pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
--config ./configs/retrieval_coco.yaml \
--output_dir output/retrieval_coco \
--evaluate</pre>
3. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/retrieval_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
|
||||
<pre>python -m torch.distributed.run --nproc_per_node=8 train_retrieval.py \
|
||||
--config ./configs/retrieval_coco.yaml \
|
||||
--output_dir output/retrieval_coco </pre>
|
||||
|
||||
### Image-Text Captioning:
|
||||
1. Download COCO and NoCaps datasets from the original websites, and set 'image_root' in configs/caption_coco.yaml and configs/nocaps.yaml accordingly.
|
||||
2. To evaluate the finetuned BLIP model on COCO, run:
|
||||
<pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py --evaluate</pre>
|
||||
3. To evaluate the finetuned BLIP model on NoCaps, generate results with: (evaluation needs to be performed on official server)
|
||||
<pre>python -m torch.distributed.run --nproc_per_node=8 eval_nocaps.py </pre>
|
||||
4. To finetune the pre-trained checkpoint using 8 A100 GPUs, first set 'pretrained' in configs/caption_coco.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
|
||||
<pre>python -m torch.distributed.run --nproc_per_node=8 train_caption.py </pre>
|
||||
|
||||
### VQA:
|
||||
1. Download VQA v2 dataset and Visual Genome dataset from the original websites, and set 'vqa_root' and 'vg_root' in configs/vqa.yaml.
|
||||
2. To evaluate the finetuned BLIP model, generate results with: (evaluation needs to be performed on official server)
|
||||
<pre>python -m torch.distributed.run --nproc_per_node=8 train_vqa.py --evaluate</pre>
|
||||
3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/vqa.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth". Then run:
|
||||
<pre>python -m torch.distributed.run --nproc_per_node=16 train_vqa.py </pre>
|
||||
|
||||
### NLVR2:
|
||||
1. Download NLVR2 dataset from the original websites, and set 'image_root' in configs/nlvr.yaml.
|
||||
2. To evaluate the finetuned BLIP model, run
|
||||
<pre>python -m torch.distributed.run --nproc_per_node=8 train_nlvr.py --evaluate</pre>
|
||||
3. To finetune the pre-trained checkpoint using 16 A100 GPUs, first set 'pretrained' in configs/nlvr.yaml as "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth". Then run:
|
||||
<pre>python -m torch.distributed.run --nproc_per_node=16 train_nlvr.py </pre>
|
||||
|
||||
### Finetune with ViT-L:
|
||||
In order to finetune a model with ViT-L, simply change the config file to set 'vit' as large. Batch size and learning rate may also need to be adjusted accordingly (please see the paper's appendix for hyper-parameter details). <a href="https://github.com/facebookresearch/fairscale">Gradient checkpoint</a> can also be activated in the config file to reduce GPU memory usage.
|
||||
|
||||
### Pre-train:
|
||||
1. Prepare training json files where each json file contains a list. Each item in the list is a dictonary with two key-value pairs: {'image': path_of_image, 'caption': text_of_image}.
|
||||
2. In configs/pretrain.yaml, set 'train_file' as the paths for the json files .
|
||||
3. Pre-train the model using 8 A100 GPUs:
|
||||
<pre>python -m torch.distributed.run --nproc_per_node=8 pretrain.py --config ./configs/Pretrain.yaml --output_dir output/Pretrain </pre>
|
||||
|
||||
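The training json described in step 1 can be produced with a few lines of Python. The file name, paths, and captions below are placeholders for illustration, not real dataset entries:

```python
import json

# Hypothetical (image path, caption) pairs; replace with your own data.
pairs = [
    ("images/0001.jpg", "a dog running on the beach"),
    ("images/0002.jpg", "two people riding bicycles"),
]

# Each item is a dictionary with the two keys the pretrain loader expects.
annotations = [{"image": path, "caption": caption} for path, caption in pairs]

with open("my_pretrain_data.json", "w") as f:
    json.dump(annotations, f)
```

The resulting file is a flat list of {'image': ..., 'caption': ...} dictionaries, which is the structure pretrain_dataset loads.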
### Zero-shot video-text retrieval:
1. Download the MSRVTT dataset following the instructions from https://github.com/salesforce/ALPRO, and set 'video_root' accordingly in configs/retrieval_msrvtt.yaml.
2. Install [decord](https://github.com/dmlc/decord) with <pre>pip install decord</pre>
3. To perform zero-shot evaluation, run:
<pre>python -m torch.distributed.run --nproc_per_node=8 eval_retrieval_video.py</pre>

### Pre-training datasets download:
We provide bootstrapped pre-training datasets as json files. Each json file contains a list. Each item in the list is a dictionary with two key-value pairs: {'url': url_of_image, 'caption': text_of_image}.

Image source | Filtered web caption | Filtered synthetic caption by ViT-B | Filtered synthetic caption by ViT-L
--- | :---: | :---: | :---:
CC3M+CC12M+SBU | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_filtered.json">Download</a> | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered.json">Download</a> | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/ccs_synthetic_filtered_large.json">Download</a>
LAION115M | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_filtered.json">Download</a> | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered.json">Download</a> | <a href="https://storage.googleapis.com/sfr-vision-language-research/BLIP/datasets/laion_synthetic_filtered_large.json">Download</a>
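All the released json files share this flat {'url', 'caption'} structure, so a quick sanity check before launching a large download job is straightforward. The file and entries below are an illustrative stand-in, not the real data:

```python
import json

# Write a tiny stand-in file with the same structure as the released jsons.
sample = [
    {"url": "http://example.com/a.jpg", "caption": "a red bridge over a river"},
    {"url": "http://example.com/b.jpg", "caption": "a bowl of fruit on a table"},
]
with open("ccs_sample.json", "w") as f:
    json.dump(sample, f)

# Validate: every item must be a dict with exactly the 'url' and 'caption' keys.
ann = json.load(open("ccs_sample.json"))
bad = [a for a in ann if set(a) != {"url", "caption"}]
print(f"{len(ann)} entries, {len(bad)} malformed")
```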
### Citation
If you find this code to be useful for your research, please consider citing.
<pre>
@inproceedings{li2022blip,
      title={BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation},
      author={Junnan Li and Dongxu Li and Caiming Xiong and Steven Hoi},
      year={2022},
      booktitle={ICML},
}</pre>

### Acknowledgement
The implementation of BLIP relies on resources from <a href="https://github.com/salesforce/ALBEF">ALBEF</a>, <a href="https://github.com/huggingface/transformers">Huggingface Transformers</a>, and <a href="https://github.com/rwightman/pytorch-image-models/tree/master/timm">timm</a>. We thank the original authors for their open-sourcing.
@@ -0,0 +1,7 @@
## Security

Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
as soon as it is discovered. This library limits its runtime dependencies in
order to reduce the total cost of ownership as much as possible, but all consumers
should remain vigilant and have their security stakeholders review all third-party
products (3PP) like this one and their dependencies.
@@ -0,0 +1,17 @@
build:
  gpu: true
  cuda: "11.1"
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "ipython==7.30.1"
    - "torchvision==0.11.1"
    - "torch==1.10.0"
    - "timm==0.4.12"
    - "transformers==4.15.0"
    - "fairscale==0.4.4"
    - "pycocoevalcap==1.2"

predict: "predict.py:Predictor"
@@ -0,0 +1,21 @@
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522,
  "encoder_width": 768,
  "add_cross_attention": true
}
@@ -0,0 +1,33 @@
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
coco_gt_root: 'annotation/coco_gt'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'

# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0
batch_size: 32
init_lr: 1e-5

# vit: 'large'
# vit_grad_ckpt: True
# vit_ckpt_layer: 5
# batch_size: 16
# init_lr: 2e-6

image_size: 384

# generation configs
max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 5
@@ -0,0 +1,21 @@
{
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30524,
  "encoder_width": 768,
  "add_cross_attention": true
}
@@ -0,0 +1,21 @@
image_root: '/export/share/datasets/vision/NLVR2/'
ann_root: 'annotation'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'

# size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 64
vit_grad_ckpt: False
vit_ckpt_layer: 0
max_epoch: 15

image_size: 384

# optimizer
weight_decay: 0.05
init_lr: 3e-5
min_lr: 0
@@ -0,0 +1,15 @@
image_root: '/export/share/datasets/vision/nocaps/'
ann_root: 'annotation'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'

vit: 'base'
batch_size: 32

image_size: 384

max_length: 20
min_length: 5
num_beams: 3
prompt: 'a picture of '
@@ -0,0 +1,27 @@
train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
             '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
            ]
laion_path: ''

# size of vit model; base or large
vit: 'base'
vit_grad_ckpt: False
vit_ckpt_layer: 0

image_size: 224
batch_size: 75

queue_size: 57600
alpha: 0.4

# optimizer
weight_decay: 0.05
init_lr: 3e-4
min_lr: 1e-6
warmup_lr: 1e-6
lr_decay_rate: 0.9
max_epoch: 20
warmup_steps: 3000
@@ -0,0 +1,34 @@
image_root: '/export/share/datasets/vision/coco/images/'
ann_root: 'annotation'
dataset: 'coco'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'

# size of vit model; base or large
vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5

# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 12
# init_lr: 5e-6

image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 256
negative_all_rank: True

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6
@@ -0,0 +1,34 @@
image_root: '/export/share/datasets/vision/flickr30k/'
ann_root: 'annotation'
dataset: 'flickr'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'

# size of vit model; base or large
vit: 'base'
batch_size_train: 32
batch_size_test: 64
vit_grad_ckpt: True
vit_ckpt_layer: 4
init_lr: 1e-5

# vit: 'large'
# batch_size_train: 16
# batch_size_test: 32
# vit_grad_ckpt: True
# vit_ckpt_layer: 10
# init_lr: 5e-6

image_size: 384
queue_size: 57600
alpha: 0.4
k_test: 128
negative_all_rank: False

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 6
@@ -0,0 +1,12 @@
video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
ann_root: 'annotation'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'

# size of vit model; base or large
vit: 'base'
batch_size: 64
k_test: 128
image_size: 384
num_frm_test: 8
@@ -0,0 +1,25 @@
vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' # followed by train2014/
vg_root: '/export/share/datasets/vision/visual-genome/' # followed by image/
train_files: ['vqa_train','vqa_val','vg_qa']
ann_root: 'annotation'

# set pretrained as a file path or an url
pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'

# size of vit model; base or large
vit: 'base'
batch_size_train: 16
batch_size_test: 32
vit_grad_ckpt: False
vit_ckpt_layer: 0
init_lr: 2e-5

image_size: 480

k_test: 128
inference: 'rank'

# optimizer
weight_decay: 0.05
min_lr: 0
max_epoch: 10
@@ -0,0 +1,101 @@
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
from data.nocaps_dataset import nocaps_eval
from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
from data.vqa_dataset import vqa_dataset
from data.nlvr_dataset import nlvr_dataset
from data.pretrain_dataset import pretrain_dataset
from transform.randaugment import RandomAugment


def create_dataset(dataset, config, min_scale=0.5):
    # Normalization statistics shared by all BLIP datasets
    normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(config['image_size'], scale=(min_scale, 1.0), interpolation=InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(),
        RandomAugment(2, 5, isPIL=True, augs=['Identity', 'AutoContrast', 'Brightness', 'Sharpness', 'Equalize',
                                              'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.Resize((config['image_size'], config['image_size']), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
    ])

    if dataset == 'pretrain':
        dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
        return dataset

    elif dataset == 'caption_coco':
        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
        val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset == 'nocaps':
        val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return val_dataset, test_dataset

    elif dataset == 'retrieval_coco':
        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
        val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset == 'retrieval_flickr':
        train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
        val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset

    elif dataset == 'vqa':
        train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
                                    train_files=config['train_files'], split='train')
        test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
        return train_dataset, test_dataset

    elif dataset == 'nlvr':
        train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'], 'train')
        val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'], 'val')
        test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'], 'test')
        return train_dataset, val_dataset, test_dataset


def create_sampler(datasets, shuffles, num_tasks, global_rank):
    # One DistributedSampler per dataset so each process sees a distinct shard
    samplers = []
    for dataset, shuffle in zip(datasets, shuffles):
        sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
        samplers.append(sampler)
    return samplers


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
        if is_train:
            # Shuffle only when no sampler is supplied (a DistributedSampler shuffles itself)
            shuffle = (sampler is None)
            drop_last = True
        else:
            shuffle = False
            drop_last = False
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=True,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
        )
        loaders.append(loader)
    return loaders
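`create_sampler` delegates sharding to `torch.utils.data.DistributedSampler`. The partitioning idea it relies on can be sketched in plain Python; this is only an illustration of the rank-strided split (the real sampler additionally pads the index list so every rank receives the same number of samples):

```python
import random

def shard_indices(num_samples, num_replicas, rank, shuffle=True, seed=0):
    """Give each rank a strided slice of the (optionally shuffled) index list."""
    indices = list(range(num_samples))
    if shuffle:
        # Same seed on every rank, so all ranks agree on the permutation.
        random.Random(seed).shuffle(indices)
    # Rank r takes indices r, r + num_replicas, r + 2*num_replicas, ...
    return indices[rank::num_replicas]

shards = [shard_indices(10, num_replicas=4, rank=r, shuffle=False) for r in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```

Every sample lands on exactly one rank, which is what makes per-epoch distributed training iterate the whole dataset once.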
@@ -0,0 +1,126 @@
import os
import json

from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from PIL import Image

from data.utils import pre_caption


class coco_karpathy_train(Dataset):
    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
        '''
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        '''
        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
        filename = 'coco_karpathy_train.json'

        download_url(url, ann_root)

        self.annotation = json.load(open(os.path.join(ann_root, filename), 'r'))
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.prompt = prompt

        # Map each COCO image id to a dense index (an image appears once per caption)
        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            img_id = ann['image_id']
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        image_path = os.path.join(self.image_root, ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        caption = self.prompt + pre_caption(ann['caption'], self.max_words)

        return image, caption, self.img_ids[ann['image_id']]


class coco_karpathy_caption_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split):
        '''
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        split (string): val or test
        '''
        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
        filenames = {'val': 'coco_karpathy_val.json', 'test': 'coco_karpathy_test.json'}

        download_url(urls[split], ann_root)

        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))
        self.transform = transform
        self.image_root = image_root

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        image_path = os.path.join(self.image_root, ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        # Recover the numeric COCO id from filenames like 'val2014/COCO_val2014_000000391895.jpg'
        img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]

        return image, int(img_id)


class coco_karpathy_retrieval_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split, max_words=30):
        '''
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        split (string): val or test
        '''
        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
        filenames = {'val': 'coco_karpathy_val.json', 'test': 'coco_karpathy_test.json'}

        download_url(urls[split], ann_root)

        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))
        self.transform = transform
        self.image_root = image_root

        # Build image<->text index mappings for retrieval evaluation
        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for i, caption in enumerate(ann['caption']):
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        image_path = os.path.join(self.image_root, self.annotation[index]['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        return image, index
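`pre_caption` is imported from `data/utils.py`, which is not included here. As a rough sketch of what such a caption cleaner typically does (lowercasing, punctuation stripping, whitespace collapsing, truncation to `max_words`) — this is an assumption about its behavior, not the repo's actual implementation:

```python
import re

def pre_caption_sketch(caption, max_words=30):
    # Assumed behavior: lowercase and replace common punctuation with spaces.
    caption = re.sub(r'[.!"()*#:;~]', ' ', caption.lower())
    # Collapse runs of whitespace and trim the ends.
    caption = re.sub(r'\s{2,}', ' ', caption).strip()
    # Truncate to at most max_words words.
    words = caption.split(' ')
    if len(words) > max_words:
        caption = ' '.join(words[:max_words])
    return caption

print(pre_caption_sketch("A dog!  A BIG dog, running."))
```

Normalizing captions this way keeps tokenization consistent between the bootstrapped web captions and the human-written ones.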
@@ -0,0 +1,93 @@
import os
import json

from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from PIL import Image

from data.utils import pre_caption


class flickr30k_train(Dataset):
    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
        '''
        image_root (string): Root directory of images (e.g. flickr30k/)
        ann_root (string): directory to store the annotation file
        '''
        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
        filename = 'flickr30k_train.json'

        download_url(url, ann_root)

        self.annotation = json.load(open(os.path.join(ann_root, filename), 'r'))
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.prompt = prompt

        # Map each image id to a dense index (an image appears once per caption)
        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            img_id = ann['image_id']
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        image_path = os.path.join(self.image_root, ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        caption = self.prompt + pre_caption(ann['caption'], self.max_words)

        return image, caption, self.img_ids[ann['image_id']]


class flickr30k_retrieval_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split, max_words=30):
        '''
        image_root (string): Root directory of images (e.g. flickr30k/)
        ann_root (string): directory to store the annotation file
        split (string): val or test
        '''
        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
        filenames = {'val': 'flickr30k_val.json', 'test': 'flickr30k_test.json'}

        download_url(urls[split], ann_root)

        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))
        self.transform = transform
        self.image_root = image_root

        # Build image<->text index mappings for retrieval evaluation
        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for i, caption in enumerate(ann['caption']):
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        image_path = os.path.join(self.image_root, self.annotation[index]['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        return image, index
@@ -0,0 +1,78 @@
import os
import json
import random

from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from PIL import Image

from data.utils import pre_caption


class nlvr_dataset(Dataset):
    def __init__(self, transform, image_root, ann_root, split):
        '''
        image_root (string): Root directory of images
        ann_root (string): directory to store the annotation file
        split (string): train, val or test
        '''
        urls = {'train': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
                'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
        filenames = {'train': 'nlvr_train.json', 'val': 'nlvr_dev.json', 'test': 'nlvr_test.json'}

        download_url(urls[split], ann_root)
        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))

        self.transform = transform
        self.image_root = image_root

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        image0_path = os.path.join(self.image_root, ann['images'][0])
        image0 = Image.open(image0_path).convert('RGB')
        image0 = self.transform(image0)

        image1_path = os.path.join(self.image_root, ann['images'][1])
        image1 = Image.open(image1_path).convert('RGB')
        image1 = self.transform(image1)

        sentence = pre_caption(ann['sentence'], 40)

        if ann['label'] == 'True':
            label = 1
        else:
            label = 0

        words = sentence.split(' ')

        # Augmentation: randomly swap the image pair; if the sentence mentions
        # 'left' or 'right', mirror those words so the label stays correct.
        if 'left' not in words and 'right' not in words:
            if random.random() < 0.5:
                return image0, image1, sentence, label
            else:
                return image1, image0, sentence, label
        else:
            if random.random() < 0.5:
                return image0, image1, sentence, label
            else:
                new_words = []
                for word in words:
                    if word == 'left':
                        new_words.append('right')
                    elif word == 'right':
                        new_words.append('left')
                    else:
                        new_words.append(word)

                sentence = ' '.join(new_words)
                return image1, image0, sentence, label
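The swap branch in `__getitem__` above mirrors spatial words so the label stays valid when the image pair is reversed. Isolated, the word-level rewrite is just the following (the example sentence is hypothetical):

```python
def swap_left_right(sentence):
    # Mirror 'left' <-> 'right' so the sentence still matches a swapped image pair.
    mapping = {"left": "right", "right": "left"}
    return " ".join(mapping.get(word, word) for word in sentence.split(" "))

print(swap_left_right("the dog on the left is larger"))
```

Because "left"/"right" are the only words in the sentence whose truth value depends on image order, this rewrite preserves the NLVR2 label under the swap.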
@@ -0,0 +1,32 @@
import os
import json

from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from PIL import Image


class nocaps_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split):
        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
        filenames = {'val': 'nocaps_val.json', 'test': 'nocaps_test.json'}

        download_url(urls[split], ann_root)

        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))
        self.transform = transform
        self.image_root = image_root

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        image_path = os.path.join(self.image_root, ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        return image, int(ann['img_id'])
@ -0,0 +1,59 @@
import json
import os
import random
import glob

from torch.utils.data import Dataset

from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

from data.utils import pre_caption


class pretrain_dataset(Dataset):
    def __init__(self, ann_file, laion_path, transform):
        self.ann_pretrain = []
        for f in ann_file:
            print('loading ' + f)
            ann = json.load(open(f, 'r'))
            self.ann_pretrain += ann

        self.laion_path = laion_path
        if self.laion_path:
            self.laion_files = glob.glob(os.path.join(laion_path, '*.json'))

            print('loading ' + self.laion_files[0])
            with open(self.laion_files[0], 'r') as f:
                self.ann_laion = json.load(f)

            self.annotation = self.ann_pretrain + self.ann_laion
        else:
            self.annotation = self.ann_pretrain

        self.transform = transform

    def reload_laion(self, epoch):
        # cycle through the LAION annotation shards, one per epoch
        n = epoch % len(self.laion_files)
        print('loading ' + self.laion_files[n])
        with open(self.laion_files[n], 'r') as f:
            self.ann_laion = json.load(f)

        self.annotation = self.ann_pretrain + self.ann_laion

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        image = Image.open(ann['image']).convert('RGB')
        image = self.transform(image)
        caption = pre_caption(ann['caption'], 30)

        return image, caption
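`reload_laion` above never runs out of shards because the shard index is taken modulo the number of files; the selection logic in isolation (shard names here are hypothetical placeholders):

```python
laion_files = ['laion_0.json', 'laion_1.json', 'laion_2.json']  # hypothetical shard names

def shard_for_epoch(epoch: int) -> str:
    # epoch % len(files) wraps around, so every shard is revisited cyclically
    return laion_files[epoch % len(laion_files)]

print(shard_for_epoch(4))  # 4 % 3 == 1 -> laion_1.json
```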
@ -0,0 +1,112 @@
import re
import json
import os

import torch
import torch.distributed as dist

import utils


def pre_caption(caption, max_words=50):
    caption = re.sub(
        r"([.!\"()*#:;~])",
        ' ',
        caption.lower(),
    )
    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n')
    caption = caption.strip(' ')

    # truncate caption
    caption_words = caption.split(' ')
    if len(caption_words) > max_words:
        caption = ' '.join(caption_words[:max_words])

    return caption


def pre_question(question, max_ques_words=50):
    question = re.sub(
        r"([.!\"()*#:;~])",
        '',
        question.lower(),
    )
    question = question.rstrip(' ')

    # truncate question
    question_words = question.split(' ')
    if len(question_words) > max_ques_words:
        question = ' '.join(question_words[:max_ques_words])

    return question


def save_result(result, result_dir, filename, remove_duplicate=''):
    result_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, utils.get_rank()))
    final_result_file = os.path.join(result_dir, '%s.json' % filename)

    json.dump(result, open(result_file, 'w'))

    dist.barrier()

    if utils.is_main_process():
        # combine results from all processes
        result = []

        for rank in range(utils.get_world_size()):
            result_file = os.path.join(result_dir, '%s_rank%d.json' % (filename, rank))
            res = json.load(open(result_file, 'r'))
            result += res

        if remove_duplicate:
            result_new = []
            id_list = []
            for res in result:
                if res[remove_duplicate] not in id_list:
                    id_list.append(res[remove_duplicate])
                    result_new.append(res)
            result = result_new

        json.dump(result, open(final_result_file, 'w'))
        print('result file saved to %s' % final_result_file)

    return final_result_file


from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
from torchvision.datasets.utils import download_url


def coco_caption_eval(coco_gt_root, results_file, split):
    urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
            'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
    filenames = {'val': 'coco_karpathy_val_gt.json', 'test': 'coco_karpathy_test_gt.json'}

    download_url(urls[split], coco_gt_root)
    annotation_file = os.path.join(coco_gt_root, filenames[split])

    # create coco object and coco_result object
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    # create coco_eval object by taking coco and coco_result
    coco_eval = COCOEvalCap(coco, coco_result)

    # evaluate on a subset of images by setting
    # coco_eval.params['image_id'] = coco_result.getImgIds()
    # please remove this line when evaluating the full validation set
    # coco_eval.params['image_id'] = coco_result.getImgIds()

    # evaluate results
    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    # print output evaluation scores
    for metric, score in coco_eval.eval.items():
        print(f'{metric}: {score:.3f}')

    return coco_eval
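`pre_caption` lowercases, replaces a fixed set of punctuation with spaces, collapses runs of whitespace, and truncates to `max_words`. Reproduced standalone for illustration (same regexes as above, renamed to avoid shadowing):

```python
import re

def pre_caption_demo(caption, max_words=50):
    # punctuation -> space, then collapse multiple spaces
    caption = re.sub(r"([.!\"()*#:;~])", ' ', caption.lower())
    caption = re.sub(r"\s{2,}", ' ', caption)
    caption = caption.rstrip('\n').strip(' ')

    # truncate to at most max_words words
    words = caption.split(' ')
    if len(words) > max_words:
        caption = ' '.join(words[:max_words])
    return caption

print(pre_caption_demo('A Dog! Running  (fast).'))  # -> a dog running fast
```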
@ -0,0 +1,110 @@
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from PIL import Image
import torch
import numpy as np
import random
import decord
from decord import VideoReader
import json
import os
from data.utils import pre_caption

decord.bridge.set_bridge("torch")


class ImageNorm(object):
    """Apply Normalization to Image Pixels on GPU
    """
    def __init__(self, mean, std):
        self.mean = torch.tensor(mean).view(1, 3, 1, 1)
        self.std = torch.tensor(std).view(1, 3, 1, 1)

    def __call__(self, img):
        if torch.max(img) > 1 and self.mean.max() <= 1:
            img.div_(255.)
        return img.sub_(self.mean).div_(self.std)


def load_jsonl(filename):
    with open(filename, "r") as f:
        return [json.loads(l.strip("\n")) for l in f.readlines()]


class VideoDataset(Dataset):

    def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'):
        '''
        video_root (string): Root directory of videos
        ann_root (string): directory to store the annotation file
        '''
        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl'
        filename = 'msrvtt_test.jsonl'

        download_url(url, ann_root)
        self.annotation = load_jsonl(os.path.join(ann_root, filename))

        self.num_frm = num_frm
        self.frm_sampling_strategy = frm_sampling_strategy
        self.max_img_size = max_img_size
        self.video_root = video_root
        self.video_fmt = video_fmt
        self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))

        self.text = [pre_caption(ann['caption'], 40) for ann in self.annotation]
        self.txt2video = [i for i in range(len(self.annotation))]
        self.video2txt = self.txt2video

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt)

        vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)

        video = self.img_norm(vid_frm_array.float())

        return video, ann['clip_name']

    def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
        try:
            if not height or not width:
                vr = VideoReader(video_path)
            else:
                vr = VideoReader(video_path, width=width, height=height)

            vlen = len(vr)

            if start_time or end_time:
                assert fps > 0, 'must provide video fps if specifying start and end time.'

                start_idx = min(int(start_time * fps), vlen)
                end_idx = min(int(end_time * fps), vlen)
            else:
                start_idx, end_idx = 0, vlen

            if self.frm_sampling_strategy == 'uniform':
                frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)
            elif self.frm_sampling_strategy == 'rand':
                frame_indices = sorted(random.sample(range(vlen), self.num_frm))
            elif self.frm_sampling_strategy == 'headtail':
                frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
                frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
                frame_indices = frame_indices_head + frame_indices_tail
            else:
                raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))

            raw_sample_frms = vr.get_batch(frame_indices)
        except Exception:
            # decoding or sampling failed; the caller receives None for this clip
            return None

        raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)

        return raw_sample_frms
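Of the three sampling strategies above, only `'uniform'` is deterministic: it strides through the clip with a step of `vlen / num_frm`. A pure-Python sketch of that arithmetic (the numpy `arange` call above truncates the same way for non-negative indices):

```python
def uniform_indices(vlen, num_frm):
    # a fixed stride of vlen / num_frm, truncated to integer frame indices,
    # mirroring np.arange(0, vlen, vlen / num_frm, dtype=int)
    step = vlen / num_frm
    return [int(i * step) for i in range(num_frm)]

print(uniform_indices(100, 4))  # -> [0, 25, 50, 75]
```

The `'rand'` and `'headtail'` branches instead sample without replacement from the whole clip or from its two halves, respectively, so repeated calls give different frames.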
@ -0,0 +1,88 @@
import os
import json
import random
from PIL import Image

import torch
from torch.utils.data import Dataset
from data.utils import pre_question

from torchvision.datasets.utils import download_url


class vqa_dataset(Dataset):
    def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
        self.split = split

        self.transform = transform
        self.vqa_root = vqa_root
        self.vg_root = vg_root

        if split == 'train':
            urls = {'vqa_train': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
                    'vqa_val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
                    'vg_qa': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}

            self.annotation = []
            for f in train_files:
                download_url(urls[f], ann_root)
                self.annotation += json.load(open(os.path.join(ann_root, '%s.json' % f), 'r'))
        else:
            download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json', ann_root)
            self.annotation = json.load(open(os.path.join(ann_root, 'vqa_test.json'), 'r'))

            download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json', ann_root)
            self.answer_list = json.load(open(os.path.join(ann_root, 'answer_list.json'), 'r'))

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        if ann['dataset'] == 'vqa':
            image_path = os.path.join(self.vqa_root, ann['image'])
        elif ann['dataset'] == 'vg':
            image_path = os.path.join(self.vg_root, ann['image'])

        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        if self.split == 'test':
            question = pre_question(ann['question'])
            question_id = ann['question_id']
            return image, question, question_id

        elif self.split == 'train':
            question = pre_question(ann['question'])

            if ann['dataset'] == 'vqa':
                answer_weight = {}
                for answer in ann['answer']:
                    if answer in answer_weight.keys():
                        answer_weight[answer] += 1 / len(ann['answer'])
                    else:
                        answer_weight[answer] = 1 / len(ann['answer'])

                answers = list(answer_weight.keys())
                weights = list(answer_weight.values())

            elif ann['dataset'] == 'vg':
                answers = [ann['answer']]
                weights = [0.2]

            return image, question, answers, weights


def vqa_collate_fn(batch):
    image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
    for image, question, answer, weights in batch:
        image_list.append(image)
        question_list.append(question)
        weight_list += weights
        answer_list += answer
        n.append(len(answer))
    return torch.stack(image_list, dim=0), question_list, answer_list, torch.Tensor(weight_list), n
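For the `'vqa'` branch above, the raw annotator answers are collapsed into unique answers weighted by frequency: each of the N raw answers contributes 1/N, so the weights sum to 1 and repeated answers dominate. The weighting in isolation (function name is mine):

```python
def weight_answers(answers):
    # each raw answer contributes 1/len(answers) to its entry,
    # so repeated answers accumulate higher weight
    answer_weight = {}
    for answer in answers:
        answer_weight[answer] = answer_weight.get(answer, 0) + 1 / len(answers)
    return list(answer_weight.keys()), list(answer_weight.values())

answers, weights = weight_answers(['cat', 'cat', 'dog', 'cat'])
print(answers, weights)  # ['cat', 'dog'] [0.75, 0.25]
```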
File diff suppressed because one or more lines are too long
@ -0,0 +1,118 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader

from models.blip import blip_decoder
import utils
from data import create_dataset, create_sampler, create_loader
from data.utils import save_result


@torch.no_grad()
def evaluate(model, data_loader, device, config):
    # evaluate
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Evaluation:'
    print_freq = 10

    result = []
    for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
        image = image.to(device)

        captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
                                  min_length=config['min_length'], repetition_penalty=1.1)

        for caption, img_id in zip(captions, image_id):
            result.append({"image_id": img_id.item(), "caption": caption})

    return result


def main(args, config):
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating captioning dataset")
    val_dataset, test_dataset = create_dataset('nocaps', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler([val_dataset, test_dataset], [False, False], num_tasks, global_rank)
    else:
        samplers = [None, None]

    val_loader, test_loader = create_loader([val_dataset, test_dataset], samplers,
                                            batch_size=[config['batch_size']] * 2, num_workers=[4, 4],
                                            is_trains=[False, False], collate_fns=[None, None])

    #### Model ####
    print("Creating model")
    model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
                         prompt=config['prompt'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    val_result = evaluate(model_without_ddp, val_loader, device, config)
    val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
    test_result = evaluate(model_without_ddp, test_loader, device, config)
    test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/nocaps.yaml')
    parser.add_argument('--output_dir', default='output/NoCaps')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    main(args, config)
@ -0,0 +1,250 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader

from models.blip_retrieval import blip_retrieval
import utils
from data.video_dataset import VideoDataset


@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
    # test
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Evaluation:'

    print('Computing features for evaluation...')
    start_time = time.time()

    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_ids = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i: min(num_text, i + text_bs)]
        text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
        text_output = model.text_encoder(text_input.input_ids, attention_mask=text_input.attention_mask, mode='text')
        text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:, 0, :]))
        text_embeds.append(text_embed)
        text_ids.append(text_input.input_ids)
        text_atts.append(text_input.attention_mask)

    text_embeds = torch.cat(text_embeds, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    text_atts = torch.cat(text_atts, dim=0)
    text_ids[:, 0] = tokenizer.additional_special_tokens_ids[0]

    video_feats = []
    video_embeds = []
    for video, video_id in data_loader:
        B, N, C, W, H = video.size()
        video = video.view(-1, C, W, H)
        video = video.to(device, non_blocking=True)
        video_feat = model.visual_encoder(video)
        video_embed = model.vision_proj(video_feat[:, 0, :])
        video_embed = video_embed.view(B, N, -1).mean(dim=1)
        video_embed = F.normalize(video_embed, dim=-1)

        video_feat = video_feat.view(B, -1, video_feat.shape[-1])
        video_feats.append(video_feat.cpu())
        video_embeds.append(video_embed)

    video_feats = torch.cat(video_feats, dim=0)
    video_embeds = torch.cat(video_embeds, dim=0)

    sims_matrix = video_embeds @ text_embeds.t()
    score_matrix_v2t = torch.full((len(texts), len(texts)), -100.0).to(device)

    num_tasks = utils.get_world_size()
    rank = utils.get_rank()
    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)

        encoder_output = video_feats[start + i].repeat(config['k_test'], 1, 1).to(device, non_blocking=True)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
        output = model.text_encoder(text_ids[topk_idx],
                                    attention_mask=text_atts[topk_idx],
                                    encoder_hidden_states=encoder_output,
                                    encoder_attention_mask=encoder_att,
                                    return_dict=True,
                                    )
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_v2t[start + i, topk_idx] = score + topk_sim

    sims_matrix = sims_matrix.t()
    score_matrix_t2v = torch.full((len(texts), len(texts)), -100.0).to(device)

    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
        encoder_output = video_feats[topk_idx].to(device, non_blocking=True)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
        output = model.text_encoder(text_ids[start + i].repeat(config['k_test'], 1),
                                    attention_mask=text_atts[start + i].repeat(config['k_test'], 1),
                                    encoder_hidden_states=encoder_output,
                                    encoder_attention_mask=encoder_att,
                                    return_dict=True,
                                    )
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_t2v[start + i, topk_idx] = score + topk_sim

    if args.distributed:  # note: relies on the global `args` parsed in __main__
        dist.barrier()
        torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Evaluation time {}'.format(total_time_str))

    return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()


@torch.no_grad()
def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt):
    # Video->Text
    ranks = np.zeros(scores_v2t.shape[0])
    for index, score in enumerate(scores_v2t):
        inds = np.argsort(score)[::-1]
        ranks[index] = np.where(inds == vid2txt[index])[0][0]

    # Compute metrics
    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    # Text->Video
    ranks = np.zeros(scores_t2v.shape[0])

    for index, score in enumerate(scores_t2v):
        inds = np.argsort(score)[::-1]
        ranks[index] = np.where(inds == txt2vmg[index])[0][0]

    mdR = np.median(ranks + 1)

    # Compute metrics
    vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    vr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)

    tr_mean = (tr1 + tr5 + tr10) / 3
    vr_mean = (vr1 + vr5 + vr10) / 3
    r_mean = (tr_mean + vr_mean) / 2

    eval_result = {'txt_r1': tr1,
                   'txt_r5': tr5,
                   'txt_r10': tr10,
                   'txt_r_mean': tr_mean,
                   'vid_r1': vr1,
                   'vid_r5': vr5,
                   'vid_r10': vr10,
                   'vid_r_mean': vr_mean,
                   'vid_mdR': mdR,
                   'r_mean': r_mean}
    return eval_result


def main(args, config):
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating retrieval dataset")
    test_dataset = VideoDataset(config['video_root'], config['ann_root'], num_frm=config['num_frm_test'],
                                max_img_size=config['image_size'], frm_sampling_strategy='uniform')

    test_loader = DataLoader(
        test_dataset,
        batch_size=config['batch_size'],
        num_workers=4,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
    )

    #### Model ####
    print("Creating model")
    model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    score_v2t, score_t2v = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config)

    if utils.is_main_process():
        test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt)
        print(test_result)

        log_stats = {**{f'{k}': v for k, v in test_result.items()}}
        with open(os.path.join(args.output_dir, "test_result.txt"), "a") as f:
            f.write(json.dumps(log_stats) + "\n")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml')
    parser.add_argument('--output_dir', default='output/Retrieval_msrvtt')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    main(args, config)
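`itm_eval` above turns each row of scores into the rank of the ground-truth match (position of the true index in the descending sort), then reports R@k as the percentage of queries whose rank is below k. The metric in isolation, without numpy (function name is mine):

```python
def recall_at_k(ranks, k):
    # percentage of queries whose ground-truth item landed in the top k
    # (ranks are 0-based, so "in the top k" means rank < k)
    return 100.0 * sum(1 for r in ranks if r < k) / len(ranks)

ranks = [0, 3, 7, 12]  # hypothetical 0-based ranks of the true match
print(recall_at_k(ranks, 1), recall_at_k(ranks, 5), recall_at_k(ranks, 10))  # 25.0 50.0 75.0
```

This matches the `100.0 * len(np.where(ranks < k)[0]) / len(ranks)` expressions above; `mdR` is simply the median of the 1-based ranks.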
@ -0,0 +1,238 @@
|
|||
'''
|
||||
* Copyright (c) 2022, salesforce.com, inc.
|
||||
* All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
* By Junnan Li
|
||||
'''
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
from models.vit import VisionTransformer, interpolate_pos_embed
|
||||
from models.med import BertConfig, BertModel, BertLMHeadModel
|
||||
from transformers import BertTokenizer
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
from timm.models.hub import download_cached_file
|
||||
|
||||
class BLIP_Base(nn.Module):
|
||||
def __init__(self,
|
||||
med_config = 'configs/med_config.json',
|
||||
image_size = 224,
|
||||
vit = 'base',
|
||||
vit_grad_ckpt = False,
|
||||
vit_ckpt_layer = 0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
||||
image_size (int): input image size
|
||||
vit (str): model size of vision transformer
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
||||
self.tokenizer = init_tokenizer()
|
||||
med_config = BertConfig.from_json_file(med_config)
|
||||
med_config.encoder_width = vision_width
|
||||
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
||||
|
||||
|
||||
def forward(self, image, caption, mode):
|
||||
|
||||
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
|
||||
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
|
||||
|
||||
if mode=='image':
|
||||
# return image features
|
||||
image_embeds = self.visual_encoder(image)
|
||||
return image_embeds
|
||||
|
||||
elif mode=='text':
|
||||
# return text features
|
||||
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
||||
return_dict = True, mode = 'text')
|
||||
return text_output.last_hidden_state
|
||||
|
||||
elif mode=='multimodal':
|
||||
# return multimodel features
|
||||
image_embeds = self.visual_encoder(image)
|
||||
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
||||
|
||||
text.input_ids[:,0] = self.tokenizer.enc_token_id
|
||||
output = self.text_encoder(text.input_ids,
|
||||
attention_mask = text.attention_mask,
|
||||
encoder_hidden_states = image_embeds,
|
||||
encoder_attention_mask = image_atts,
|
||||
return_dict = True,
|
||||
)
|
||||
return output.last_hidden_state
|
||||
|
||||
|
||||
|
||||
class BLIP_Decoder(nn.Module):
|
||||
def __init__(self,
|
||||
med_config = 'configs/med_config.json',
|
||||
image_size = 384,
|
||||
vit = 'base',
|
||||
vit_grad_ckpt = False,
|
||||
                 vit_ckpt_layer = 0,
                 prompt = 'a picture of ',
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_decoder = BertLMHeadModel(config=med_config)

        self.prompt = prompt
        self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1

    def forward(self, image, caption):
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
        text.input_ids[:, 0] = self.tokenizer.bos_token_id

        decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
        decoder_targets[:, :self.prompt_length] = -100

        decoder_output = self.text_decoder(text.input_ids,
                                           attention_mask = text.attention_mask,
                                           encoder_hidden_states = image_embeds,
                                           encoder_attention_mask = image_atts,
                                           labels = decoder_targets,
                                           return_dict = True,
                                           )
        loss_lm = decoder_output.loss
        return loss_lm

    def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
        image_embeds = self.visual_encoder(image)

        if not sample:
            image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask": image_atts}

        prompt = [self.prompt] * image.size(0)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
        input_ids[:, 0] = self.tokenizer.bos_token_id
        input_ids = input_ids[:, :-1]

        if sample:
            # nucleus sampling
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                 max_length=max_length,
                                                 min_length=min_length,
                                                 do_sample=True,
                                                 top_p=top_p,
                                                 num_return_sequences=1,
                                                 eos_token_id=self.tokenizer.sep_token_id,
                                                 pad_token_id=self.tokenizer.pad_token_id,
                                                 repetition_penalty=1.1,
                                                 **model_kwargs)
        else:
            # beam search
            outputs = self.text_decoder.generate(input_ids=input_ids,
                                                 max_length=max_length,
                                                 min_length=min_length,
                                                 num_beams=num_beams,
                                                 eos_token_id=self.tokenizer.sep_token_id,
                                                 pad_token_id=self.tokenizer.pad_token_id,
                                                 repetition_penalty=repetition_penalty,
                                                 **model_kwargs)

        captions = []
        for output in outputs:
            caption = self.tokenizer.decode(output, skip_special_tokens=True)
            captions.append(caption[len(self.prompt):])
        return captions

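In `BLIP_Decoder.forward` above, padding positions and the prompt prefix are excluded from the language-modeling loss by setting their labels to -100, the ignore index of PyTorch cross-entropy. A pure-Python sketch of that target construction (token ids and the prompt length are made-up toy values, not real tokenizer output):

```python
# Toy illustration of BLIP_Decoder's decoder_targets masking:
# PAD tokens and the first prompt_length positions get label -100
# so the loss is computed only over the caption tokens.
pad_token_id = 0      # hypothetical pad id
prompt_length = 3     # BOS + prompt tokens, hypothetical count
input_ids = [101, 7, 8, 9, 10, 0, 0]

targets = [-100 if (tok == pad_token_id or i < prompt_length) else tok
           for i, tok in enumerate(input_ids)]
print(targets)  # [-100, -100, -100, 9, 10, -100, -100]
```

Only positions 3 and 4 (the caption tokens 9 and 10) contribute to the loss.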
def blip_decoder(pretrained='', **kwargs):
    model = BLIP_Decoder(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        assert len(msg.missing_keys) == 0
    return model


def blip_feature_extractor(pretrained='', **kwargs):
    model = BLIP_Base(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        assert len(msg.missing_keys) == 0
    return model


def init_tokenizer():
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokenizer.add_special_tokens({'bos_token': '[DEC]'})
    tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']})
    tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
    return tokenizer

def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
    assert vit in ['base', 'large'], "vit parameter must be base or large"
    if vit == 'base':
        vision_width = 768
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
                                           num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0 or drop_path_rate
                                           )
    elif vit == 'large':
        vision_width = 1024
        visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
                                           num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
                                           drop_path_rate=0.1 or drop_path_rate
                                           )
    return visual_encoder, vision_width

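A quirk worth noting in `create_vit` above: the expressions `0 or drop_path_rate` and `0.1 or drop_path_rate` rely on Python's short-circuit `or`, which returns the first truthy operand. `0 or x` falls through to `x`, but `0.1 or x` is always `0.1`, so in the `large` branch the caller's `drop_path_rate` argument is effectively ignored. A standalone illustration (the helper name is hypothetical):

```python
# Python's `or` returns the first truthy operand, so a non-zero
# literal on the left always wins and the argument is never used.
def effective_drop_path(default, requested):
    return default or requested

print(effective_drop_path(0, 0.2))    # 0.2 -> 'base' branch honors the argument
print(effective_drop_path(0.1, 0.2))  # 0.1 -> 'large' branch ignores it
```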
def is_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def load_checkpoint(model, url_or_filename):
    if is_url(url_or_filename):
        cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
        checkpoint = torch.load(cached_file, map_location='cpu')
    elif os.path.isfile(url_or_filename):
        checkpoint = torch.load(url_or_filename, map_location='cpu')
    else:
        raise RuntimeError('checkpoint url or path is invalid')

    state_dict = checkpoint['model']

    state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
    if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
        state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
                                                                         model.visual_encoder_m)
    for key in model.state_dict().keys():
        if key in state_dict.keys():
            if state_dict[key].shape != model.state_dict()[key].shape:
                del state_dict[key]

    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % url_or_filename)
    return model, msg
@@ -0,0 +1,76 @@
from models.med import BertConfig, BertModel
from transformers import BertTokenizer

import torch
from torch import nn
import torch.nn.functional as F

from models.blip import create_vit, init_tokenizer, load_checkpoint


class BLIP_ITM(nn.Module):
    def __init__(self,
                 med_config = 'configs/med_config.json',
                 image_size = 384,
                 vit = 'base',
                 vit_grad_ckpt = False,
                 vit_ckpt_layer = 0,
                 embed_dim = 256,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()
        med_config = BertConfig.from_json_file(med_config)
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.itm_head = nn.Linear(text_width, 2)

    def forward(self, image, caption, match_head='itm'):
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
                              return_tensors="pt").to(image.device)

        if match_head == 'itm':
            output = self.text_encoder(text.input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True,
                                       )
            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])
            return itm_output

        elif match_head == 'itc':
            text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
                                            return_dict = True, mode = 'text')
            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
            text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)

            sim = image_feat @ text_feat.t()
            return sim


def blip_itm(pretrained='', **kwargs):
    model = BLIP_ITM(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        assert len(msg.missing_keys) == 0
    return model