pull/4/head
AUTOMATIC 2023-05-27 17:14:35 +03:00
commit c7fae24ca9
8 changed files with 448 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
__pycache__
/TensorRT-*

36
README.md Normal file
View File

@ -0,0 +1,36 @@
# TensorRT support for webui
Adds the ability to convert a loaded model's Unet module into TensorRT. Requires a version at least after commit [339b5315](https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/339b5315700a469f4a9f0d5afc08ca2aca60c579) (currently, it's the `dev` branch after 2023-05-27). Only tested to work on Windows.
Loras are baked into the converted model. Hypernetwork support is not tested. Controlnet is not supported. Textual inversion works normally.
NVIDIA is also working on releasing their version of TensorRT for webui, which might be more performant, but they can't release it yet.
There seems to be support for quickly replacing weight of a TensorRT engine without rebuilding it, and this extension does not offer this option yet.
## How to install
Apart from installing the extension normally, you also need to download zip with TensorRT from [NVIDIA](https://developer.nvidia.com/nvidia-tensorrt-8x-download).
You need to choose the same version of CUDA as python's torch library is using. For torch 2.0.1 it is CUDA 11.8.
Extract the zip into extension directory, so that `TensorRT-8.6.1.6` (or similarly named dir) exists in the same place as `scripts` directory and `trt_path.py` file. Restart webui afterwards.
You don't need to install CUDA separately.
## How to use
1. Select the model you want to optimize and make a picture with it, including needed loras and hypernetworks.
2. Go to a `TensorRT` tab that appears if the extension loads properly.
3. In `Convert to ONNX` tab, press `Convert Unet to ONNX`.
* This takes a short while.
* After the conversion has finished, you will find an `.onnx` file with model in `models/Unet-onnx` directory.
4. In `Convert ONNX to TensorRT` tab, configure the necessary parameters (including writing full path to onnx model) and press `Convert ONNX to TensorRT`.
* This takes very long - from 15 minutes to an hour.
* This takes up a lot of VRAM: you might want to press "Show command for conversion" and run the command yourself after shutting down webui.
* After the conversion has finished, you will find a `.trt` file with model in `models/Unet-trt` directory.
5. In settings, in `Stable Diffusion` page, use `SD Unet` option to select newly generated TensorRT model.
6. Generate pictures.
## Stable Diffusion 2.0 support
Stable diffusion 2.0 conversion should fail for both ONNX and TensorRT because of incompatible shapes, but you may be able to remedy this by changing instances of 768 to 1024 in the code.

45
export_onnx.py Normal file
View File

@ -0,0 +1,45 @@
import os
from modules import sd_hijack, sd_unet
from modules import shared, devices
import torch
def export_current_unet_to_onnx(filename, opset_version=17):
    """Export the currently loaded checkpoint's Unet to an ONNX file.

    Parameters:
        filename: destination path for the .onnx file; parent directories are created.
        opset_version: ONNX opset to target (17 by default).

    Traces the model with small dummy tensors; the dynamic_axes declaration lets
    the exported graph accept arbitrary batch size, resolution and token count.
    NOTE(review): the context tensor hard-codes a 768 embedding dim (SD 1.x);
    SD 2.x models use 1024 — confirm before exporting a 2.x checkpoint.
    """
    x = torch.randn(1, 4, 16, 16).to(devices.device, devices.dtype)
    timesteps = torch.zeros((1,)).to(devices.device, devices.dtype) + 500
    context = torch.randn(1, 77, 768).to(devices.device, devices.dtype)

    def disable_checkpoint(self):
        # gradient checkpointing breaks ONNX tracing; turn it off on every submodule
        if getattr(self, 'use_checkpoint', False) == True:
            self.use_checkpoint = False
        if getattr(self, 'checkpoint', False) == True:
            self.checkpoint = False

    shared.sd_model.model.diffusion_model.apply(disable_checkpoint)

    # make sure no replacement unet / optimization hooks are active during export
    sd_unet.apply_unet("None")
    sd_hijack.model_hijack.apply_optimizations('None')

    os.makedirs(os.path.dirname(filename), exist_ok=True)

    try:
        with devices.autocast():
            torch.onnx.export(
                shared.sd_model.model.diffusion_model,
                (x, timesteps, context),
                filename,
                export_params=True,
                opset_version=opset_version,
                do_constant_folding=True,
                input_names=['x', 'timesteps', 'context'],
                output_names=['output'],
                dynamic_axes={
                    'x': {0: 'batch_size', 2: 'height', 3: 'width'},
                    'timesteps': {0: 'batch_size'},
                    'context': {0: 'batch_size', 1: 'sequence_length'},
                    'output': {0: 'batch_size'},
                },
            )
    finally:
        # bug fix: restore optimizations and the unet even if the export raises,
        # otherwise the webui is left running without its optimizations
        sd_hijack.model_hijack.apply_optimizations()
        sd_unet.apply_unet()

32
export_trt.py Normal file
View File

@ -0,0 +1,32 @@
import os.path
def get_trt_command(trt_filename, onnx_filename, min_bs, max_bs, min_token_count, max_token_count, min_width, max_width, min_height, max_height, use_fp16, trt_extra_args):
    """Return the trtexec command line that converts onnx_filename into trt_filename.

    Validates the dimension/token-count constraints, locates the trtexec binary
    inside the bundled TensorRT directory, and encodes min/max dynamic shapes for
    the x, context and timesteps inputs (batch is doubled for cond+uncond).
    """
    for name, val in (("min_width", min_width), ("max_width", max_width), ("min_height", min_height), ("max_height", max_height)):
        assert val % 64 == 0, name + ' must be a multiple of 64'

    for name, val in (("min_token_count", min_token_count), ("max_token_count", max_token_count)):
        assert val % 75 == 0, name + ' must be a multiple of 75'

    assert os.path.isfile(onnx_filename), 'onnx model not found: ' + onnx_filename

    import trt_paths

    trt_exec_candidates = [os.path.join(trt_paths.trt_path, "bin", exe) for exe in ("trtexec", "trtexec.exe")]

    trt_exec = None
    for candidate in trt_exec_candidates:
        if os.path.isfile(candidate):
            trt_exec = candidate
            break

    assert trt_exec, f"could not find trtexec; searched in: {', '.join(trt_exec_candidates)}"

    cond_dim = 768  # XXX should be detected for SD2.0

    # two latents per image: conditional + unconditional pass
    batch_min = min_bs * 2
    batch_max = max_bs * 2

    x_min = f"{batch_min}x4x{min_height // 8}x{min_width // 8}"
    x_max = f"{batch_max}x4x{max_height // 8}x{max_width // 8}"
    context_min = f"{batch_min}x{min_token_count // 75 * 77}x{cond_dim}"
    context_max = f"{batch_max}x{max_token_count // 75 * 77}x{cond_dim}"
    timestamps_min = f"{batch_min}"
    timestamps_max = f"{batch_max}"

    os.makedirs(os.path.dirname(trt_filename), exist_ok=True)

    return f""""{trt_exec}" --onnx="{onnx_filename}" --saveEngine="{trt_filename}" --minShapes=x:{x_min},context:{context_min},timesteps:{timestamps_min} --maxShapes=x:{x_max},context:{context_max},timesteps:{timestamps_max}{' --fp16' if use_fp16 else ''} {trt_extra_args}"""

30
install.py Normal file
View File

@ -0,0 +1,30 @@
import os
import launch
try:
    import trt_paths
except Exception as e:
    # bug fix: bind the name so the `if trt_paths:` check below doesn't raise
    # NameError when the TensorRT directory is missing
    trt_paths = None
    print("Could not find TensorRT directory; skipping install", e)


def install():
    """Install the bundled TensorRT wheel plus pycuda/onnx into webui's environment."""
    if not launch.is_installed("tensorrt"):
        trt_whl_path = os.path.join(trt_paths.trt_path, "python")
        matching_files = [os.path.join(trt_whl_path, x) for x in os.listdir(trt_whl_path)]
        matching_files = [x for x in matching_files if "tensorrt-" in x and "cp310" in x]

        if len(matching_files) == 0:
            print(f"Could not find TensorRT .whl installer; looked in {trt_whl_path}")
            # bug fix: previously fell through and raised IndexError on matching_files[0]
            return

        whl = matching_files[0]
        launch.run_pip(f'install "{whl}"', "TensorRT wheel")

    if not launch.is_installed("pycuda"):
        launch.run_pip('install pycuda', "pycuda")

    if not launch.is_installed("onnx"):
        launch.run_pip('install onnx', "onnx")


if trt_paths:
    install()

123
scripts/trt.py Normal file
View File

@ -0,0 +1,123 @@
import os
import numpy as np
import torch
from modules import script_callbacks, sd_unet, devices, shared, paths_internal
import trt_paths
import ui_trt
import pycuda.driver as cuda
class TrtUnetOption(sd_unet.SdUnetOption):
    """Makes a single .trt engine file selectable in webui's SD Unet dropdown."""

    def __init__(self, filename, name):
        self.filename = filename
        self.model_name = name
        self.label = f"[TRT] {name}"

    def create_unet(self):
        """Instantiate the TensorRT-backed unet for this engine file."""
        return TrtUnet(self.filename)
# translate numpy scalar types (as reported for TensorRT bindings) into torch dtypes
np_to_torch = dict([
    (np.float32, torch.float32),
    (np.float16, torch.float16),
    (np.int8, torch.int8),
    (np.uint8, torch.uint8),
    (np.int32, torch.int32),
])
class TrtUnet(sd_unet.SdUnet):
    """SdUnet implementation that runs the denoiser through a TensorRT engine.

    The engine is deserialized from self.filename in activate(); device-side
    buffers are lazily (re)allocated whenever the input shapes change.
    """

    def __init__(self, filename, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.filename = filename  # path to the serialized .trt engine file
        self.engine = None  # deserialized engine, created in activate()
        self.trtcontext = None  # execution context for the engine
        self.buffers = None  # dict: binding name -> torch tensor on the webui device
        self.buffers_shape = ()  # concatenated input shapes; cache key for allocate_buffers
        self.nptype = None  # trt.nptype converter, stored so tensorrt need not be imported here

    def allocate_buffers(self, feed_dict):
        # The concatenation of all input shapes acts as a cheap cache key:
        # if nothing changed since the previous call, keep the existing buffers.
        buffers_shape = sum([x.shape for x in feed_dict.values()], ())
        if self.buffers_shape == buffers_shape:
            return

        self.buffers_shape = buffers_shape
        self.buffers = {}

        for binding in self.engine:
            binding_idx = self.engine.get_binding_index(binding)
            dtype = self.nptype(self.engine.get_binding_dtype(binding))

            # inputs take the caller's shape; other bindings' shapes are queried from the context
            if binding in feed_dict:
                shape = feed_dict[binding].shape
            else:
                shape = self.trtcontext.get_binding_shape(binding_idx)

            if self.engine.binding_is_input(binding):
                # the dynamic shape must be registered with the context before execution
                if not self.trtcontext.set_binding_shape(binding_idx, shape):
                    raise Exception(f'bad shape for TensorRT input {binding}: {tuple(shape)}')

            tensor = torch.empty(tuple(shape), dtype=np_to_torch[dtype], device=devices.device)
            self.buffers[binding] = tensor

    def infer(self, feed_dict):
        """Copy feed_dict tensors into the engine buffers and execute; results remain in self.buffers."""
        self.allocate_buffers(feed_dict)

        for name, tensor in feed_dict.items():
            self.buffers[name].copy_(tensor)

        # point every binding at the device memory backing our torch tensors
        for name, tensor in self.buffers.items():
            self.trtcontext.set_tensor_address(name, tensor.data_ptr())

        ctx = cuda.Context.attach()
        stream = cuda.Stream()
        self.trtcontext.execute_async_v3(stream.handle)
        stream.synchronize()  # wait for the async execution to finish before reading buffers
        ctx.detach()

    def forward(self, x, timesteps, context, *args, **kwargs):
        # run the engine, then hand back the "output" binding in the caller's dtype
        self.infer({"x": x, "timesteps": timesteps, "context": context})

        return self.buffers["output"].to(dtype=x.dtype, device=devices.device)

    def activate(self):
        import tensorrt as trt  # we import this late because it breaks torch onnx export
        TRT_LOGGER = trt.Logger()
        trt.init_libnvinfer_plugins(None, "")
        self.nptype = trt.nptype

        with open(self.filename, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            self.engine = runtime.deserialize_cuda_engine(f.read())

        self.trtcontext = self.engine.create_execution_context()

    def deactivate(self):
        # drop all engine state and reclaim VRAM
        self.engine = None
        self.trtcontext = None
        self.buffers = None
        self.buffers_shape = ()

        devices.torch_gc()
def list_unets(l):
    """Append a TrtUnetOption to l for every .trt engine found in models/Unet-trt."""
    trt_dir = os.path.join(paths_internal.models_path, 'Unet-trt')

    engine_files = sorted(shared.walk_files(trt_dir, allowed_extensions=[".trt"]), key=str.lower)
    for filename in engine_files:
        model_name = os.path.splitext(os.path.basename(filename))[0]
        l.append(TrtUnetOption(filename, model_name))
# register with webui: populate the SD Unet dropdown and add the TensorRT tab
script_callbacks.on_list_unets(list_unets)
script_callbacks.on_ui_tabs(ui_trt.on_ui_tabs)

42
trt_paths.py Normal file
View File

@ -0,0 +1,42 @@
import os
import torch
script_path = os.path.dirname(os.path.realpath(__file__))  # the extension's root directory

# filled in by set_paths(); other modules read these after importing trt_paths
trt_path = None
cuda_path = None
def set_paths():
    """Locate the bundled TensorRT directory and torch's CUDA libs, then extend PATH.

    Sets the module globals trt_path and cuda_path; raises AssertionError when
    either directory cannot be found.
    """
    global trt_path, cuda_path

    cuda_path = os.path.dirname(torch.__file__)
    cuda_lib_path = os.path.join(cuda_path, "lib")
    assert os.path.exists(cuda_lib_path), "CUDA lib directory not found: " + cuda_lib_path

    def looks_like_trt_dir(candidate):
        # a TensorRT distribution has a lib dir plus a trtexec binary
        if not os.path.exists(os.path.join(candidate, 'lib')):
            return False
        return os.path.exists(os.path.join(candidate, 'bin', 'trtexec.exe')) or os.path.exists(os.path.join(candidate, 'bin', 'trtexec'))

    looked_in = []
    trt_path = None

    for entry in os.listdir(script_path):
        candidate = os.path.join(script_path, entry)
        if not os.path.isdir(candidate):
            continue

        if looks_like_trt_dir(candidate):
            trt_path = candidate
            break

        looked_in.append(candidate)

    assert trt_path is not None, "Was not able to find TensorRT directory. Looked in: " + ", ".join(looked_in)

    # make the TensorRT and CUDA DLLs discoverable by the dynamic loader
    for lib_path in (os.path.join(trt_path, "lib"), cuda_lib_path):
        if lib_path not in os.environ['PATH']:
            os.environ['PATH'] = os.environ['PATH'] + os.pathsep + lib_path

    os.environ['CUDA_PATH'] = cuda_path  # use same cuda as torch is using
set_paths()  # resolve paths at import time so modules importing trt_paths see them populated

138
ui_trt.py Normal file
View File

@ -0,0 +1,138 @@
import html
import os
import launch
import trt_paths
from modules import script_callbacks, paths_internal, shared
import gradio as gr
import export_onnx
import export_trt
from modules.call_queue import wrap_gradio_gpu_call
from modules.shared import cmd_opts
from modules.ui_components import FormRow
def export_unet_to_onnx(filename, opset):
    """Convert the currently loaded checkpoint's Unet to an ONNX file.

    An empty filename defaults to <model name>.onnx inside models/Unet-onnx.
    Returns a (status message, extra HTML) pair for the gradio outputs.
    """
    if not filename:
        modelname = shared.sd_model.sd_checkpoint_info.model_name + ".onnx"
        filename = os.path.join(paths_internal.models_path, "Unet-onnx", modelname)

    export_onnx.export_current_unet_to_onnx(filename, opset)

    # bug fix: the f-string had no placeholder and always reported "(unknown)"
    return f'Saved as {filename}', ''
def get_trt_filename(filename, onnx_filename):
    """Return filename if given; otherwise derive models/Unet-trt/<onnx basename>.trt."""
    if filename:
        return filename

    base, _ = os.path.splitext(os.path.basename(onnx_filename))
    return os.path.join(paths_internal.models_path, "Unet-trt", base + ".trt")
def get_trt_command(filename, onnx_filename, *args):
    """Show the trtexec command for manual conversion instead of running it.

    Returns a (status, html) pair for the gradio outputs; the HTML lists the PATH
    setup lines and the full trtexec invocation so the user can run it outside webui.
    """
    filename = get_trt_filename(filename, onnx_filename)
    conversion_command = export_trt.get_trt_command(filename, onnx_filename, *args)

    env_command = f"""
set PATH=%PATH%;{trt_paths.cuda_path}\\lib
set PATH=%PATH%;{trt_paths.trt_path}\\lib
"""

    run_command = f"""
{conversion_command}
"""

    info_html = f"""
<p>
Environment variables: <br>
<pre style="white-space: pre-line;">
{html.escape(env_command)}
</pre>
</p>
<p>
Command: <br>
<pre style="white-space: pre-line;">
{html.escape(run_command)}
</pre>
</p>
"""

    return "Command generated", info_html
def convert_onnx_to_trt(filename, onnx_filename, *args):
    """Run trtexec to convert an ONNX model into a TensorRT engine (slow, VRAM heavy).

    Refuses to run when extension access is disabled, since this executes an
    external command. Returns (status message, extra HTML) for the gradio outputs.
    """
    # bug fix: "dsabled" typo corrected in the user-facing error message
    assert not cmd_opts.disable_extension_access, "won't run the command to create TensorRT file because extension access is disabled (use --enable-insecure-extension-access)"

    filename = get_trt_filename(filename, onnx_filename)
    command = export_trt.get_trt_command(filename, onnx_filename, *args)

    launch.run(command, live=True)

    # bug fix: the f-string had no placeholder and always reported "(unknown)"
    return f'Saved as {filename}', ''
def on_ui_tabs():
    """Build the TensorRT tab: ONNX export controls, TRT conversion controls, result panel."""
    with gr.Blocks(analytics_enabled=False) as trt_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="trt_tabs"):
                    with gr.Tab(label="Convert to ONNX"):
                        gr.HTML(value="<p style='margin-bottom: 0.7em'>Convert currently loaded checkpoint into ONNX. The conversion will fail catastrophically if TensorRT was used at any point prior to conversion, so you might have to restart webui before doing the conversion.</p>")

                        onnx_filename = gr.Textbox(label='Filename', value="", elem_id="onnx_filename", info="Leave empty to use the same name as model and put results into models/Unet-onnx directory")
                        onnx_opset = gr.Number(label='ONNX opset version', precision=0, value=17, info="Leave this alone unless you know what you are doing")

                        button_export_unet = gr.Button(value="Convert Unet to ONNX", variant='primary', elem_id="onnx_export_unet")

                    with gr.Tab(label="Convert ONNX to TensorRT"):
                        trt_source_filename = gr.Textbox(label='Onnx model filename', value="", elem_id="trt_source_filename")
                        trt_filename = gr.Textbox(label='Output filename', value="", elem_id="trt_filename", info="Leave empty to use the same name as onnx and put results into models/Unet-trt directory")

                        # the min/max pairs below define the dynamic-shape ranges baked into the engine
                        with gr.Column(elem_id="trt_width"):
                            min_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Minimum width", value=512, elem_id="trt_min_width")
                            max_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Maximum width", value=512, elem_id="trt_max_width")

                        with gr.Column(elem_id="trt_height"):
                            min_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Minimum height", value=512, elem_id="trt_min_height")
                            max_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Maximum height", value=512, elem_id="trt_max_height")

                        with gr.Column(elem_id="trt_batch_size"):
                            min_bs = gr.Slider(minimum=1, maximum=16, step=1, label="Minimum batch size", value=1, elem_id="trt_min_bs")
                            max_bs = gr.Slider(minimum=1, maximum=16, step=1, label="Maximum batch size", value=1, elem_id="trt_max_bs")

                        with gr.Column(elem_id="trt_token_count"):
                            min_token_count = gr.Slider(minimum=75, maximum=750, step=75, label="Minimum prompt token count", value=75, elem_id="trt_min_token_count")
                            max_token_count = gr.Slider(minimum=75, maximum=750, step=75, label="Maximum prompt token count", value=75, elem_id="trt_max_token_count")

                        trt_extra_args = gr.Textbox(label='Extra arguments', value="", elem_id="trt_extra_args", info="Extra arguments for trtexec command in plain text form")

                        with FormRow(elem_classes="checkboxes-row", variant="compact"):
                            use_fp16 = gr.Checkbox(label='Use half floats', value=True, elem_id="trt_fp16")

                        button_export_trt = gr.Button(value="Convert ONNX to TensorRT", variant='primary', elem_id="trt_convert_from_onnx")
                        # bug fix: this button previously reused elem_id "trt_convert_from_onnx",
                        # producing duplicate element ids on the page
                        button_show_trt_command = gr.Button(value="Show command for conversion", variant='secondary', elem_id="trt_show_command_from_onnx")

            with gr.Column(variant='panel'):
                trt_result = gr.Label(elem_id="trt_result", value="", show_label=False)
                trt_info = gr.HTML(elem_id="trt_info", value="")

        button_export_unet.click(
            wrap_gradio_gpu_call(export_unet_to_onnx, extra_outputs=["Conversion failed"]),
            inputs=[onnx_filename, onnx_opset],
            outputs=[trt_result, trt_info],
        )

        button_export_trt.click(
            wrap_gradio_gpu_call(convert_onnx_to_trt, extra_outputs=[""]),
            inputs=[trt_filename, trt_source_filename, min_bs, max_bs, min_token_count, max_token_count, min_width, max_width, min_height, max_height, use_fp16, trt_extra_args],
            outputs=[trt_result, trt_info],
        )

        button_show_trt_command.click(
            wrap_gradio_gpu_call(get_trt_command, extra_outputs=[""]),
            inputs=[trt_filename, trt_source_filename, min_bs, max_bs, min_token_count, max_token_count, min_width, max_width, min_height, max_height, use_fp16, trt_extra_args],
            outputs=[trt_result, trt_info],
        )

    return [(trt_interface, "TensorRT", "tensorrt")]