🔊 Add tracemalloc to track memory usage (#2373)

pull/2376/head
Chenlei Hu 2023-12-31 04:31:58 +00:00 committed by GitHub
parent 218f075023
commit cd0d5870f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 5 deletions

View File

@ -219,6 +219,7 @@ This extension adds these command line arguments to the webui:
--no-half-controlnet load controlnet models in full precision --no-half-controlnet load controlnet models in full precision
--controlnet-preprocessor-cache-size Cache size for controlnet preprocessor results --controlnet-preprocessor-cache-size Cache size for controlnet preprocessor results
--controlnet-loglevel Log level for the controlnet extension --controlnet-loglevel Log level for the controlnet extension
--controlnet-tracemalloc Enable malloc memory tracing
``` ```
# MacOS Support # MacOS Support

View File

@ -31,3 +31,9 @@ def preload(parser):
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)", help="Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
) )
parser.add_argument(
"--controlnet-tracemalloc",
action="store_true",
help="Enable memory tracing.",
default=None,
)

View File

@ -1,7 +1,7 @@
import gc import gc
import tracemalloc
import os import os
import logging import logging
import re
from collections import OrderedDict from collections import OrderedDict
from copy import copy from copy import copy
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
@ -10,9 +10,8 @@ from modules import shared, devices, script_callbacks, processing, masking, imag
import gradio as gr import gradio as gr
import time import time
from einops import rearrange from einops import rearrange
from scripts import global_state, hook, external_code, processor, batch_hijack, controlnet_version, utils from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils
from scripts.controlnet_lora import bind_control_lora, unbind_control_lora from scripts.controlnet_lora import bind_control_lora, unbind_control_lora
from scripts.processor import * from scripts.processor import *
from scripts.adapter import Adapter, StyleAdapter, Adapter_light from scripts.adapter import Adapter, StyleAdapter, Adapter_light
@ -32,7 +31,6 @@ import cv2
import numpy as np import numpy as np
import torch import torch
from pathlib import Path
from PIL import Image, ImageFilter, ImageOps from PIL import Image, ImageFilter, ImageOps
from scripts.lvminthin import lvmin_thin, nake_nms from scripts.lvminthin import lvmin_thin, nake_nms
from scripts.processor import model_free_preprocessors from scripts.processor import model_free_preprocessors
@ -1069,10 +1067,18 @@ class Script(scripts.Script, metaclass=(
def controlnet_hack(self, p): def controlnet_hack(self, p):
t = time.time() t = time.time()
if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False):
tracemalloc.start()
setattr(self, "malloc_begin", tracemalloc.take_snapshot())
self.controlnet_main_entry(p) self.controlnet_main_entry(p)
if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False):
logger.info("After hook malloc:")
for stat in tracemalloc.take_snapshot().compare_to(self.malloc_begin, "lineno")[:10]:
logger.info(stat)
if len(self.enabled_units) > 0: if len(self.enabled_units) > 0:
logger.info(f'ControlNet Hooked - Time = {time.time() - t}') logger.info(f'ControlNet Hooked - Time = {time.time() - t}')
return
@staticmethod @staticmethod
def process_has_sdxl_refiner(p): def process_has_sdxl_refiner(p):
@ -1156,6 +1162,11 @@ class Script(scripts.Script, metaclass=(
gc.collect() gc.collect()
devices.torch_gc() devices.torch_gc()
if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False):
logger.info("After generation:")
for stat in tracemalloc.take_snapshot().compare_to(self.malloc_begin, "lineno")[:10]:
logger.info(stat)
tracemalloc.stop()
def batch_tab_process(self, p, batches, *args, **kwargs): def batch_tab_process(self, p, batches, *args, **kwargs):
self.enabled_units = Script.get_enabled_units(p) self.enabled_units = Script.get_enabled_units(p)