automatic/modules/sd_te_remote.py

41 lines
1.3 KiB
Python

import os
import time
import json
import torch
import requests
from modules import devices, errors
from modules.logger import log
def get_t5_prompt_embeds(
    prompt: str | list[str] | None = None,
    num_images_per_prompt: int = 1, # pylint: disable=unused-argument
    max_sequence_length: int = 512, # pylint: disable=unused-argument
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> torch.Tensor | None:
    """Encode a prompt using a remote T5 text-encoder service.

    The endpoint URL is read from the ``SD_REMOTE_T5`` environment variable.
    The prompt is POSTed as the JSON request body. The response body is decoded
    as the raw bytes of a tensor whose shape is given as a JSON list in the
    ``shape`` response header.

    Args:
        prompt: Prompt text (or list of prompts), sent verbatim as the JSON body.
        num_images_per_prompt: Accepted for signature compatibility; not used here.
        max_sequence_length: Accepted for signature compatibility; not used here.
        device: Device for the returned tensor; defaults to ``devices.device``.
        dtype: Dtype for the returned tensor; defaults to ``devices.dtype``.
            It is also used to decode the response bytes unless the server
            supplies a ``dtype`` response header naming a torch dtype.

    Returns:
        The embeddings tensor on ``device`` with ``dtype``, or ``None`` when the
        URL is not configured or the request/decoding fails. Failures are logged
        and displayed, not raised.
    """
    device = device or devices.device
    dtype = dtype or devices.dtype
    url = os.environ.get('SD_REMOTE_T5', None)
    if not url:  # an empty env var is as unusable as an unset one
        log.error('Remote-TE: url is not set')
        return None
    try:
        t0 = time.time()
        response = requests.post(
            url=url,
            headers={ "Content-Type": "application/json" },
            json=prompt,
            timeout=300,
        )
        t1 = time.time()
        # Report HTTP failures directly instead of failing later on a missing "shape" header.
        response.raise_for_status()
        shape = json.loads(response.headers["shape"])
        # Decode with the server-declared dtype when one is provided (e.g. "bfloat16"
        # or "torch.float16"); otherwise assume the wire dtype equals the target dtype.
        # NOTE(review): the "dtype" header is optional and only honoured if present.
        wire_dtype = dtype
        header_dtype = response.headers.get("dtype")
        if header_dtype:
            candidate = getattr(torch, header_dtype.strip().removeprefix('torch.'), None)
            if isinstance(candidate, torch.dtype):
                wire_dtype = candidate
        # Writable copy: torch.frombuffer warns on read-only buffers such as bytes.
        buffer = bytearray(response.content)
        tensor = torch.frombuffer(buffer, dtype=wire_dtype).reshape(shape)
        log.debug(f'Remote-TE: url="{url}" prompt="{prompt}" shape={shape} time={t1-t0:.3f}')
        return tensor.to(device=device, dtype=dtype)
    except Exception as e:
        log.error(f'Remote-TE: {e}')
        errors.display(e, 'remote-te')
        return None