diff --git a/README.md b/README.md index 5ec6c4f..339b967 100644 --- a/README.md +++ b/README.md @@ -219,6 +219,7 @@ This extension adds these command line arguments to the webui: --no-half-controlnet load controlnet models in full precision --controlnet-preprocessor-cache-size Cache size for controlnet preprocessor results --controlnet-loglevel Log level for the controlnet extension + --controlnet-tracemalloc Enable malloc memory tracing ``` # MacOS Support diff --git a/preload.py b/preload.py index 9a87a82..9bc15b7 100644 --- a/preload.py +++ b/preload.py @@ -31,3 +31,9 @@ def preload(parser): choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", ) + parser.add_argument( + "--controlnet-tracemalloc", + action="store_true", + help="Enable memory tracing.", + default=None, + ) diff --git a/scripts/controlnet.py b/scripts/controlnet.py index 2ec1cbc..6a46363 100644 --- a/scripts/controlnet.py +++ b/scripts/controlnet.py @@ -1,7 +1,7 @@ import gc +import tracemalloc import os import logging -import re from collections import OrderedDict from copy import copy from typing import Dict, Optional, Tuple @@ -10,9 +10,8 @@ from modules import shared, devices, script_callbacks, processing, masking, imag import gradio as gr import time - from einops import rearrange -from scripts import global_state, hook, external_code, processor, batch_hijack, controlnet_version, utils +from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils from scripts.controlnet_lora import bind_control_lora, unbind_control_lora from scripts.processor import * from scripts.adapter import Adapter, StyleAdapter, Adapter_light @@ -32,7 +31,6 @@ import cv2 import numpy as np import torch -from pathlib import Path from PIL import Image, ImageFilter, ImageOps from scripts.lvminthin import lvmin_thin, nake_nms from scripts.processor import model_free_preprocessors @@ -1069,10 +1067,18 @@ class Script(scripts.Script, metaclass=( def controlnet_hack(self, p): t = time.time() + if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False): + tracemalloc.start() + setattr(self, "malloc_begin", tracemalloc.take_snapshot()) + self.controlnet_main_entry(p) + if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False): + logger.info("After hook malloc:") + for stat in tracemalloc.take_snapshot().compare_to(self.malloc_begin, "lineno")[:10]: + logger.info(stat) + if len(self.enabled_units) > 0: logger.info(f'ControlNet Hooked - Time = {time.time() - t}') - return @staticmethod def process_has_sdxl_refiner(p): @@ -1156,6 +1162,11 @@ class Script(scripts.Script, metaclass=( gc.collect() devices.torch_gc() + if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False): + logger.info("After generation:") + for stat in tracemalloc.take_snapshot().compare_to(self.malloc_begin, "lineno")[:10]: + logger.info(stat) + tracemalloc.stop() def batch_tab_process(self, p, batches, *args, **kwargs): self.enabled_units = Script.get_enabled_units(p)