🔊 Add tracemalloc to track memory usage (#2373)

pull/2376/head
Chenlei Hu 2023-12-31 04:31:58 +00:00 committed by GitHub
parent 218f075023
commit cd0d5870f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 5 deletions

View File

@ -219,6 +219,7 @@ This extension adds these command line arguments to the webui:
--no-half-controlnet load controlnet models in full precision
--controlnet-preprocessor-cache-size Cache size for controlnet preprocessor results
--controlnet-loglevel Log level for the controlnet extension
--controlnet-tracemalloc Enable malloc memory tracing
```
# MacOS Support

View File

@ -31,3 +31,9 @@ def preload(parser):
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
)
parser.add_argument(
"--controlnet-tracemalloc",
action="store_true",
help="Enable memory tracing.",
default=None,
)

View File

@ -1,7 +1,7 @@
import gc
import tracemalloc
import os
import logging
import re
from collections import OrderedDict
from copy import copy
from typing import Dict, Optional, Tuple
@ -10,9 +10,8 @@ from modules import shared, devices, script_callbacks, processing, masking, imag
import gradio as gr
import time
from einops import rearrange
from scripts import global_state, hook, external_code, processor, batch_hijack, controlnet_version, utils
from scripts import global_state, hook, external_code, batch_hijack, controlnet_version, utils
from scripts.controlnet_lora import bind_control_lora, unbind_control_lora
from scripts.processor import *
from scripts.adapter import Adapter, StyleAdapter, Adapter_light
@ -32,7 +31,6 @@ import cv2
import numpy as np
import torch
from pathlib import Path
from PIL import Image, ImageFilter, ImageOps
from scripts.lvminthin import lvmin_thin, nake_nms
from scripts.processor import model_free_preprocessors
@ -1069,10 +1067,18 @@ class Script(scripts.Script, metaclass=(
def controlnet_hack(self, p):
t = time.time()
if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False):
tracemalloc.start()
setattr(self, "malloc_begin", tracemalloc.take_snapshot())
self.controlnet_main_entry(p)
if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False):
logger.info("After hook malloc:")
for stat in tracemalloc.take_snapshot().compare_to(self.malloc_begin, "lineno")[:10]:
logger.info(stat)
if len(self.enabled_units) > 0:
logger.info(f'ControlNet Hooked - Time = {time.time() - t}')
return
@staticmethod
def process_has_sdxl_refiner(p):
@ -1156,6 +1162,11 @@ class Script(scripts.Script, metaclass=(
gc.collect()
devices.torch_gc()
if getattr(shared.cmd_opts, 'controlnet_tracemalloc', False):
logger.info("After generation:")
for stat in tracemalloc.take_snapshot().compare_to(self.malloc_begin, "lineno")[:10]:
logger.info(stat)
tracemalloc.stop()
def batch_tab_process(self, p, batches, *args, **kwargs):
self.enabled_units = Script.get_enabled_units(p)