diff --git a/patches/external_pr/ui.py b/patches/external_pr/ui.py index 9749203..1b097aa 100644 --- a/patches/external_pr/ui.py +++ b/patches/external_pr/ui.py @@ -5,10 +5,9 @@ import os import random from modules import shared, sd_hijack, devices -from modules.call_queue import wrap_gradio_call +from modules.call_queue import wrap_gradio_call, wrap_gradio_gpu_call from modules.paths import script_path from modules.ui import create_refresh_button, gr_show -from webui import wrap_gradio_gpu_call from .textual_inversion import train_embedding as train_embedding_external from .hypernetwork import train_hypernetwork as train_hypernetwork_external, train_hypernetwork_tuning import gradio as gr diff --git a/scripts/hypernetwork-extensions.py b/scripts/hypernetwork-extensions.py index 17a954a..a3231aa 100644 --- a/scripts/hypernetwork-extensions.py +++ b/scripts/hypernetwork-extensions.py @@ -3,17 +3,15 @@ import os from modules.call_queue import wrap_gradio_call from modules.hypernetworks.ui import keys import modules.scripts as scripts -from modules import script_callbacks, shared, sd_hijack +from modules import script_callbacks, shared import gradio as gr -from modules.paths import script_path -from modules.ui import create_refresh_button, gr_show +from modules.ui import gr_show import patches.clip_hijack as clip_hijack import patches.textual_inversion as textual_inversion import patches.ui as ui import patches.shared as shared_patch import patches.external_pr.ui as external_patch_ui -from webui import wrap_gradio_gpu_call setattr(shared.opts,'pin_memory', False) @@ -125,6 +123,24 @@ def create_extension_tab2(params=None): ], outputs=[] ) + with gr.Row(): + def track_vram_usage(*args): + import torch + import gc + torch.cuda.empty_cache() + gc.collect() + for obj in gc.get_objects(): + try: + if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): + if obj.is_cuda: + print(type(obj), obj.size()) + except: pass + track_vram_usage_button = gr.Button(value="Track VRAM usage") + track_vram_usage_button.click( + fn = track_vram_usage, + inputs=[], + outputs=[] + ) return [(CLIP_test_interface, "CLIP_test", "clip_test")] def on_ui_settings():