first!
commit
c7fae24ca9
|
|
@ -0,0 +1,2 @@
|
|||
__pycache__
|
||||
/TensorRT-*
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# TensorRT support for webui
|
||||
|
||||
Adds the ability to convert a loaded model's Unet module into TensorRT. Requires a version at least after commit [339b5315](https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/339b5315700a469f4a9f0d5afc08ca2aca60c579) (currently, it's the `dev` branch after 2023-05-27). Only tested to work on Windows.
|
||||
|
||||
Loras are baked in into the converted model. Hypernetwork support is not tested. Controlnet is not supported. Textual inversion works normally.
|
||||
|
||||
NVIDIA is also working on releasing their version of TensorRT for webui, which might be more performant, but they can't release it yet.
|
||||
|
||||
There seems to be support for quickly replacing weight of a TensorRT engine without rebuilding it, and this extension does not offer this option yet.
|
||||
|
||||
## How to install
|
||||
|
||||
Apart from installing the extension normally, you also need to download zip with TensorRT from [NVIDIA](https://developer.nvidia.com/nvidia-tensorrt-8x-download).
|
||||
|
||||
You need to choose the same version of CUDA as python's torch library is using. For torch 2.0.1 it is CUDA 11.8.
|
||||
|
||||
Extract the zip into extension directory, so that `TensorRT-8.6.1.6` (or similarly named dir) exists in the same place as `scripts` directory and `trt_path.py` file. Restart webui afterwards.
|
||||
|
||||
You don't need to install CUDA separately.
|
||||
|
||||
## How to use
|
||||
|
||||
1. Select the model you want to optimize and make a picture with it, including needed loras and hypernetworks.
|
||||
2. Go to a `TensorRT` tab that appears if the extension loads properly.
|
||||
3. In `Convert to ONNX` tab, press `Convert Unet to ONNX`.
|
||||
* This takes a short while.
|
||||
* After the conversion has finished, you will find an `.onnx` file with model in `models/Unet-onnx` directory.
|
||||
4. In `Convert ONNX to TensorRT` tab, configure the necessary parameters (including writing full path to onnx model) and press `Convert ONNX to TensorRT`.
|
||||
* This takes very long - from 15 minutes to an hour.
|
||||
* This takes up a lot of VRAM: you might want to press "Show command for conversion" and run the command yourself after shutting down webui.
|
||||
* After the conversion has finished, you will find a `.trt` file with model in `models/Unet-trt` directory.
|
||||
5. In settings, in `Stable Diffusion` page, use `SD Unet` option to select newly generated TensorRT model.
|
||||
6. Generate pictures.
|
||||
|
||||
## Stable Diffusion 2.0 support
|
||||
Stable diffusion 2.0 conversion should fail for both ONNX and TensorRT because of incompatible shapes, but you may be able to remedy this by changing instances of 768 to 1024 in the code.
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
import os
|
||||
|
||||
from modules import sd_hijack, sd_unet
|
||||
from modules import shared, devices
|
||||
import torch
|
||||
|
||||
|
||||
def export_current_unet_to_onnx(filename, opset_version=17):
    """Export the currently loaded checkpoint's Unet module to an ONNX file at *filename*."""

    device, dtype = devices.device, devices.dtype

    # Dummy inputs matching the Unet signature: latent, timestep, conditioning.
    dummy_x = torch.randn(1, 4, 16, 16).to(device, dtype)
    dummy_timesteps = torch.zeros((1,)).to(device, dtype) + 500
    dummy_context = torch.randn(1, 77, 768).to(device, dtype)

    def disable_checkpoint(module):
        # `== True` is deliberate: only flip attributes that are literally the
        # boolean True (on some modules `checkpoint` is a method, not a flag).
        if getattr(module, 'use_checkpoint', False) == True:
            module.use_checkpoint = False
        if getattr(module, 'checkpoint', False) == True:
            module.checkpoint = False

    # Gradient checkpointing interferes with ONNX tracing - switch it off everywhere.
    shared.sd_model.model.diffusion_model.apply(disable_checkpoint)

    # Disable any active SD-Unet replacement and attention optimizations before tracing.
    sd_unet.apply_unet("None")
    sd_hijack.model_hijack.apply_optimizations('None')

    os.makedirs(os.path.dirname(filename), exist_ok=True)

    with devices.autocast():
        torch.onnx.export(
            shared.sd_model.model.diffusion_model,
            (dummy_x, dummy_timesteps, dummy_context),
            filename,
            export_params=True,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=['x', 'timesteps', 'context'],
            output_names=['output'],
            dynamic_axes={
                'x': {0: 'batch_size', 2: 'height', 3: 'width'},
                'timesteps': {0: 'batch_size'},
                'context': {0: 'batch_size', 1: 'sequence_length'},
                'output': {0: 'batch_size'},
            },
        )

    # Restore what we disabled above.
    sd_hijack.model_hijack.apply_optimizations()
    sd_unet.apply_unet()
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
import os.path
|
||||
|
||||
|
||||
def get_trt_command(trt_filename, onnx_filename, min_bs, max_bs, min_token_count, max_token_count, min_width, max_width, min_height, max_height, use_fp16, trt_extra_args):
    """Build the trtexec command line that converts *onnx_filename* into the
    TensorRT engine *trt_filename*, with a dynamic-shape profile spanning the
    given batch size / token count / resolution ranges.

    Raises AssertionError on invalid dimensions, a missing onnx file, or a
    missing trtexec binary.
    """
    dimensions = {"min_width": min_width, "max_width": max_width, "min_height": min_height, "max_height": max_height}
    for name, value in dimensions.items():
        assert value % 64 == 0, name + ' must be a multiple of 64'

    token_counts = {"min_token_count": min_token_count, "max_token_count": max_token_count}
    for name, value in token_counts.items():
        assert value % 75 == 0, name + ' must be a multiple of 75'

    assert os.path.isfile(onnx_filename), 'onnx model not found: ' + onnx_filename

    import trt_paths
    trt_exec_candidates = [
        os.path.join(trt_paths.trt_path, "bin", "trtexec"),
        os.path.join(trt_paths.trt_path, "bin", "trtexec.exe"),
    ]

    # Pick the first candidate that actually exists (Linux binary vs Windows .exe).
    trt_exec = None
    for candidate in trt_exec_candidates:
        if os.path.isfile(candidate):
            trt_exec = candidate
            break

    assert trt_exec, f"could not find trtexec; searched in: {', '.join(trt_exec_candidates)}"

    cond_dim = 768  # XXX should be detected for SD2.0

    # Batch dimension is doubled: webui runs cond and uncond passes together.
    # Latents are 1/8 of the image resolution; every 75 tokens become 77 entries.
    x_min = f"{min_bs * 2}x4x{min_height // 8}x{min_width // 8}"
    x_max = f"{max_bs * 2}x4x{max_height // 8}x{max_width // 8}"
    context_min = f"{min_bs * 2}x{min_token_count // 75 * 77}x{cond_dim}"
    context_max = f"{max_bs * 2}x{max_token_count // 75 * 77}x{cond_dim}"
    timestamps_min = f"{min_bs * 2}"
    timestamps_max = f"{max_bs * 2}"

    os.makedirs(os.path.dirname(trt_filename), exist_ok=True)

    return f""""{trt_exec}" --onnx="{onnx_filename}" --saveEngine="{trt_filename}" --minShapes=x:{x_min},context:{context_min},timesteps:{timestamps_min} --maxShapes=x:{x_max},context:{context_max},timesteps:{timestamps_max}{' --fp16' if use_fp16 else ''} {trt_extra_args}"""
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
import os
|
||||
|
||||
import launch
|
||||
|
||||
# trt_paths locates the extracted TensorRT directory and extends PATH; if it
# cannot be imported, remember that instead of crashing at the guard below.
try:
    import trt_paths
except Exception as e:
    # was: `trt_paths` stayed unbound on failure, so the later
    # `if trt_paths:` check raised NameError instead of skipping install.
    trt_paths = None
    print("Could not find TensorRT directory; skipping install", e)
|
||||
|
||||
|
||||
def install():
    """Install the python packages the extension needs: the TensorRT wheel
    bundled inside the downloaded TensorRT directory, plus pycuda and onnx.

    Each package is only installed if webui reports it as missing.
    """
    if not launch.is_installed("tensorrt"):
        trt_whl_path = os.path.join(trt_paths.trt_path, "python")
        matching_files = [os.path.join(trt_whl_path, x) for x in os.listdir(trt_whl_path)]
        # Only the wheel built for webui's python 3.10 is usable.
        matching_files = [x for x in matching_files if "tensorrt-" in x and "cp310" in x]
        if len(matching_files) == 0:
            print(f"Could not find TensorRT .whl installer; looked in {trt_whl_path}")
            # was: fell through to matching_files[0] and crashed with IndexError
            return

        whl = matching_files[0]
        launch.run_pip(f'install "{whl}"', "TensorRT wheel")

    if not launch.is_installed("pycuda"):
        launch.run_pip('install pycuda', "pycuda")

    if not launch.is_installed("onnx"):
        launch.run_pip('install onnx', "onnx")
|
||||
|
||||
|
||||
# Only attempt installation when the TensorRT directory was located at import time.
if trt_paths:
    install()
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
import os
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
from modules import script_callbacks, sd_unet, devices, shared, paths_internal
|
||||
|
||||
import trt_paths
|
||||
import ui_trt
|
||||
|
||||
import pycuda.driver as cuda
|
||||
|
||||
|
||||
class TrtUnetOption(sd_unet.SdUnetOption):
    """One entry in webui's "SD Unet" dropdown, backed by a .trt engine file."""

    def __init__(self, filename, name):
        self.filename = filename
        self.model_name = name
        self.label = f"[TRT] {name}"

    def create_unet(self):
        """Instantiate the TensorRT-backed Unet for this engine file."""
        return TrtUnet(self.filename)
|
||||
|
||||
|
||||
# Maps numpy dtypes (as reported by tensorrt's trt.nptype for engine bindings)
# to the matching torch dtypes, used when allocating torch I/O buffers.
np_to_torch = {
    np.float32: torch.float32,
    np.float16: torch.float16,
    np.int8: torch.int8,
    np.uint8: torch.uint8,
    np.int32: torch.int32,
}
|
||||
|
||||
|
||||
class TrtUnet(sd_unet.SdUnet):
    """SD-Unet implementation that runs the diffusion Unet through a
    deserialized TensorRT engine instead of torch."""

    def __init__(self, filename, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.filename = filename    # path to the serialized .trt engine
        self.engine = None          # deserialized engine, set in activate()
        self.trtcontext = None      # execution context, set in activate()
        self.buffers = None         # binding name -> torch tensor used for engine I/O
        self.buffers_shape = ()     # flattened input shapes the buffers were allocated for
        self.nptype = None          # trt.nptype; stored late to avoid importing tensorrt early

    def allocate_buffers(self, feed_dict):
        """(Re)allocate torch tensors for all engine bindings to match the shapes
        in *feed_dict* (binding name -> input tensor). No-op if shapes are unchanged."""

        # All input shapes flattened into one tuple acts as a cache key.
        buffers_shape = sum([x.shape for x in feed_dict.values()], ())
        if self.buffers_shape == buffers_shape:
            return

        self.buffers_shape = buffers_shape
        self.buffers = {}

        for binding in self.engine:
            binding_idx = self.engine.get_binding_index(binding)
            dtype = self.nptype(self.engine.get_binding_dtype(binding))

            # Inputs take the caller's shape; for other bindings ask the context.
            # NOTE(review): assumes output bindings are enumerated after all
            # inputs have had their shapes set - confirm against engine binding order.
            if binding in feed_dict:
                shape = feed_dict[binding].shape
            else:
                shape = self.trtcontext.get_binding_shape(binding_idx)

            if self.engine.binding_is_input(binding):
                if not self.trtcontext.set_binding_shape(binding_idx, shape):
                    raise Exception(f'bad shape for TensorRT input {binding}: {tuple(shape)}')

            # Buffers live on the torch device so the engine can read/write them directly.
            tensor = torch.empty(tuple(shape), dtype=np_to_torch[dtype], device=devices.device)
            self.buffers[binding] = tensor

    def infer(self, feed_dict):
        """Run one engine execution; inputs come from *feed_dict*, outputs are
        left in self.buffers."""
        self.allocate_buffers(feed_dict)

        # Copy inputs into the preallocated device buffers.
        for name, tensor in feed_dict.items():
            self.buffers[name].copy_(tensor)

        # Point each TensorRT tensor at its torch buffer's device memory.
        for name, tensor in self.buffers.items():
            self.trtcontext.set_tensor_address(name, tensor.data_ptr())

        ctx = cuda.Context.attach()
        stream = cuda.Stream()

        self.trtcontext.execute_async_v3(stream.handle)

        # Block until the engine has finished writing its outputs.
        stream.synchronize()
        ctx.detach()

    def forward(self, x, timesteps, context, *args, **kwargs):
        """Drop-in replacement for the torch Unet forward pass."""
        self.infer({"x": x, "timesteps": timesteps, "context": context})

        # The engine may produce a different dtype (e.g. fp16); convert back
        # to whatever the caller used.
        return self.buffers["output"].to(dtype=x.dtype, device=devices.device)

    def activate(self):
        """Deserialize the engine from disk and create its execution context."""
        import tensorrt as trt  # we import this late because it breaks torch onnx export

        TRT_LOGGER = trt.Logger()
        trt.init_libnvinfer_plugins(None, "")
        self.nptype = trt.nptype

        with open(self.filename, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())

        self.trtcontext = self.engine.create_execution_context()

    def deactivate(self):
        """Drop the engine, context and buffers so their VRAM can be reclaimed."""
        self.engine = None
        self.trtcontext = None
        self.buffers = None
        self.buffers_shape = ()
        devices.torch_gc()
|
||||
|
||||
|
||||
def list_unets(l):
    """Append a TrtUnetOption to *l* for every .trt engine found under
    models/Unet-trt (callback for script_callbacks.on_list_unets)."""

    trt_dir = os.path.join(paths_internal.models_path, 'Unet-trt')
    engine_files = shared.walk_files(trt_dir, allowed_extensions=[".trt"])

    for engine_path in sorted(engine_files, key=str.lower):
        base = os.path.basename(engine_path)
        model_name, _ = os.path.splitext(base)
        l.append(TrtUnetOption(engine_path, model_name))
|
||||
|
||||
|
||||
# Register our .trt engines in the SD Unet dropdown and add the TensorRT UI tab.
script_callbacks.on_list_unets(list_unets)

script_callbacks.on_ui_tabs(ui_trt.on_ui_tabs)
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
import os
|
||||
import torch
|
||||
|
||||
# Directory this file lives in; the TensorRT-* zip is expected to be extracted here.
script_path = os.path.dirname(os.path.realpath(__file__))
# Filled in by set_paths() below.
trt_path = None
cuda_path = None
|
||||
|
||||
|
||||
def set_paths():
    """Locate the extracted TensorRT directory and torch's bundled CUDA libs,
    publish them as the module-level trt_path / cuda_path, and extend PATH so
    their shared libraries can be loaded.

    Raises AssertionError if either directory cannot be found.
    """
    global trt_path, cuda_path

    # NOTE(review): this is torch's install directory, not a standalone CUDA
    # toolkit - torch ships the CUDA runtime libraries in its own lib/ dir.
    cuda_path = os.path.dirname(torch.__file__)
    cuda_lib_path = os.path.join(cuda_path, "lib")

    assert os.path.exists(cuda_lib_path), "CUDA lib directory not found: " + cuda_lib_path

    looked_in = []
    trt_path = None
    # A valid TensorRT directory has a lib/ subdir and a trtexec binary
    # (either the Windows .exe or the Linux executable).
    for dirname in os.listdir(script_path):
        path = os.path.join(script_path, dirname)
        if not os.path.isdir(path):
            continue

        if os.path.exists(os.path.join(path, 'lib')) and (os.path.exists(os.path.join(path, 'bin', 'trtexec.exe')) or os.path.exists(os.path.join(path, 'bin', 'trtexec'))):
            trt_path = path
            break

        looked_in.append(path)

    assert trt_path is not None, "Was not able to find TensorRT directory. Looked in: " + ", ".join(looked_in)

    trt_lib_path = os.path.join(trt_path, "lib")
    # Plain substring check against PATH - cheap guard against appending twice.
    if trt_lib_path not in os.environ['PATH']:
        os.environ['PATH'] = os.environ['PATH'] + os.pathsep + trt_lib_path

    if cuda_lib_path not in os.environ['PATH']:
        os.environ['PATH'] = os.environ['PATH'] + os.pathsep + cuda_lib_path

    os.environ['CUDA_PATH'] = cuda_path  # use same cuda as torch is using


# Runs at import time so PATH is set before anything imports tensorrt.
set_paths()
|
||||
|
|
@ -0,0 +1,138 @@
|
|||
import html
|
||||
import os
|
||||
|
||||
import launch
|
||||
import trt_paths
|
||||
from modules import script_callbacks, paths_internal, shared
|
||||
import gradio as gr
|
||||
|
||||
import export_onnx
|
||||
import export_trt
|
||||
from modules.call_queue import wrap_gradio_gpu_call
|
||||
from modules.shared import cmd_opts
|
||||
from modules.ui_components import FormRow
|
||||
|
||||
|
||||
def export_unet_to_onnx(filename, opset):
    """Export the currently loaded model's Unet to ONNX (gradio handler).

    If *filename* is empty, derives the name from the loaded checkpoint and
    places the result in models/Unet-onnx. Returns a (status message, html)
    pair for the two gradio outputs.
    """
    if not filename:
        modelname = shared.sd_model.sd_checkpoint_info.model_name + ".onnx"
        filename = os.path.join(paths_internal.models_path, "Unet-onnx", modelname)

    export_onnx.export_current_unet_to_onnx(filename, opset)

    # was: f'Saved as (unknown)' - an f-string with no placeholder, so the
    # actual destination path was never shown to the user
    return f'Saved as {filename}', ''
|
||||
|
||||
|
||||
def get_trt_filename(filename, onnx_filename):
    """Return *filename* if provided; otherwise derive a .trt path in
    models/Unet-trt from the onnx model's base name."""
    if filename:
        return filename

    base, _ = os.path.splitext(os.path.basename(onnx_filename))
    return os.path.join(paths_internal.models_path, "Unet-trt", base + ".trt")
|
||||
|
||||
|
||||
def get_trt_command(filename, onnx_filename, *args):
    """Generate (but do not run) the trtexec conversion command and return it
    wrapped in HTML, so the user can run it manually outside webui.

    *args* are forwarded unchanged to export_trt.get_trt_command. Returns a
    (status message, html) pair for the two gradio outputs.
    """
    filename = get_trt_filename(filename, onnx_filename)
    command = export_trt.get_trt_command(filename, onnx_filename, *args)

    # Windows cmd-style environment setup so trtexec can find the CUDA and TRT DLLs.
    env_command = f"""
set PATH=%PATH%;{trt_paths.cuda_path}\\lib
set PATH=%PATH%;{trt_paths.trt_path}\\lib
"""

    run_command = f"""
{command}
"""

    return "Command generated", f"""
<p>
Environment variables: <br>
<pre style="white-space: pre-line;">
{html.escape(env_command)}
</pre>
</p>
<p>
Command: <br>
<pre style="white-space: pre-line;">
{html.escape(run_command)}
</pre>
</p>
"""
|
||||
|
||||
|
||||
def convert_onnx_to_trt(filename, onnx_filename, *args):
    """Run trtexec to convert *onnx_filename* into a TensorRT engine (gradio handler).

    *args* are forwarded unchanged to export_trt.get_trt_command. Returns a
    (status message, html) pair for the two gradio outputs.
    """
    # was: "...extension access is dsabled..." - typo in the user-facing message
    assert not cmd_opts.disable_extension_access, "won't run the command to create TensorRT file because extension access is disabled (use --enable-insecure-extension-access)"

    filename = get_trt_filename(filename, onnx_filename)
    command = export_trt.get_trt_command(filename, onnx_filename, *args)

    launch.run(command, live=True)

    # was: f'Saved as (unknown)' - report the actual output path to the user
    return f'Saved as {filename}', ''
|
||||
|
||||
|
||||
def on_ui_tabs():
    """Build the "TensorRT" tab: a left panel with "Convert to ONNX" and
    "Convert ONNX to TensorRT" sub-tabs, and a right panel showing results.

    Returns the (interface, title, id) list expected by script_callbacks.on_ui_tabs.
    """
    with gr.Blocks(analytics_enabled=False) as trt_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="trt_tabs"):
                    with gr.Tab(label="Convert to ONNX"):
                        gr.HTML(value="<p style='margin-bottom: 0.7em'>Convert currently loaded checkpoint into ONNX. The conversion will fail catastrophically if TensorRT was used at any point prior to conversion, so you might have to restart webui before doing the conversion.</p>")

                        onnx_filename = gr.Textbox(label='Filename', value="", elem_id="onnx_filename", info="Leave empty to use the same name as model and put results into models/Unet-onnx directory")
                        onnx_opset = gr.Number(label='ONNX opset version', precision=0, value=17, info="Leave this alone unless you know what you are doing")

                        button_export_unet = gr.Button(value="Convert Unet to ONNX", variant='primary', elem_id="onnx_export_unet")

                    with gr.Tab(label="Convert ONNX to TensorRT"):
                        trt_source_filename = gr.Textbox(label='Onnx model filename', value="", elem_id="trt_source_filename")
                        trt_filename = gr.Textbox(label='Output filename', value="", elem_id="trt_filename", info="Leave empty to use the same name as onnx and put results into models/Unet-trt directory")

                        # The min/max pairs below define the dynamic-shape profile
                        # passed to trtexec (--minShapes/--maxShapes).
                        with gr.Column(elem_id="trt_width"):
                            min_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Minimum width", value=512, elem_id="trt_min_width")
                            max_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Maximum width", value=512, elem_id="trt_max_width")

                        with gr.Column(elem_id="trt_height"):
                            min_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Minimum height", value=512, elem_id="trt_min_height")
                            max_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Maximum height", value=512, elem_id="trt_max_height")

                        with gr.Column(elem_id="trt_batch_size"):
                            min_bs = gr.Slider(minimum=1, maximum=16, step=1, label="Minimum batch size", value=1, elem_id="trt_min_bs")
                            max_bs = gr.Slider(minimum=1, maximum=16, step=1, label="Maximum batch size", value=1, elem_id="trt_max_bs")

                        with gr.Column(elem_id="trt_token_count"):
                            min_token_count = gr.Slider(minimum=75, maximum=750, step=75, label="Minimum prompt token count", value=75, elem_id="trt_min_token_count")
                            max_token_count = gr.Slider(minimum=75, maximum=750, step=75, label="Maximum prompt token count", value=75, elem_id="trt_max_token_count")

                        trt_extra_args = gr.Textbox(label='Extra arguments', value="", elem_id="trt_extra_args", info="Extra arguments for trtexec command in plain text form")

                        with FormRow(elem_classes="checkboxes-row", variant="compact"):
                            use_fp16 = gr.Checkbox(label='Use half floats', value=True, elem_id="trt_fp16")

                        button_export_trt = gr.Button(value="Convert ONNX to TensorRT", variant='primary', elem_id="trt_convert_from_onnx")
                        # NOTE(review): same elem_id as button_export_trt above -
                        # looks like a copy/paste oversight; confirm before changing.
                        button_show_trt_command = gr.Button(value="Show command for conversion", variant='secondary', elem_id="trt_convert_from_onnx")

            with gr.Column(variant='panel'):
                trt_result = gr.Label(elem_id="trt_result", value="", show_label=False)
                trt_info = gr.HTML(elem_id="trt_info", value="")

        # wrap_gradio_gpu_call serializes these handlers with generation jobs.
        button_export_unet.click(
            wrap_gradio_gpu_call(export_unet_to_onnx, extra_outputs=["Conversion failed"]),
            inputs=[onnx_filename, onnx_opset],
            outputs=[trt_result, trt_info],
        )

        button_export_trt.click(
            wrap_gradio_gpu_call(convert_onnx_to_trt, extra_outputs=[""]),
            inputs=[trt_filename, trt_source_filename, min_bs, max_bs, min_token_count, max_token_count, min_width, max_width, min_height, max_height, use_fp16, trt_extra_args],
            outputs=[trt_result, trt_info],
        )

        button_show_trt_command.click(
            wrap_gradio_gpu_call(get_trt_command, extra_outputs=[""]),
            inputs=[trt_filename, trt_source_filename, min_bs, max_bs, min_token_count, max_token_count, min_width, max_width, min_height, max_height, use_fp16, trt_extra_args],
            outputs=[trt_result, trt_info],
        )

    return [(trt_interface, "TensorRT", "tensorrt")]
|
||||
|
||||
Loading…
Reference in New Issue