Add images-tensor-to-samples conversion method

pull/2228/head
Vladimir Mandic 2023-09-18 15:49:04 -04:00
parent 7e298b2039
commit 8797a34e19
3 changed files with 26 additions and 4 deletions

View File

@ -220,7 +220,7 @@ def esrgan_upscale(model, img):
with Progress(TextColumn('[cyan]{task.description}'), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), TimeElapsedColumn(), console=console) as progress:
total = 0
for y, h, row in grid.tiles:
for _y, _h, row in grid.tiles:
total += len(row)
task = progress.add_task(description="Upscaling", total=total)
for y, h, row in grid.tiles:

View File

@ -155,7 +155,7 @@ class RealESRGANer():
# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
tile_idx = y * tiles_x + x + 1
tile_idx = y * tiles_x + x + 1 # noqa
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
# upscale tile
@ -315,7 +315,7 @@ class IOConsumer(threading.Thread):
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import functional as F # noqa
class SRVGGNetCompact(nn.Module):
@ -381,4 +381,3 @@ class SRVGGNetCompact(nn.Module):
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
out += base
return out

View File

@ -55,6 +55,29 @@ def samples_to_image_grid(samples, approximation=None):
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
def images_tensor_to_samples(image, approximation=None, model=None):
    """Encode an image tensor with values in [0, 1] into VAE latent space.

    When `approximation` is None it is resolved from the user's
    show-progress setting via `approximation_indexes`. Approximation 3
    selects the lightweight TAESD encoder; any other value goes through
    the model's first-stage VAE. Batches larger than one are encoded
    one image at a time and re-stacked (presumably to work around
    batch-size limits in the first-stage encoder — confirm with callers).
    """
    if approximation is None:
        approximation = approximation_indexes.get(shared.opts.show_progress_type, 0)

    if approximation == 3:
        # Tiny-autoencoder path: cheap approximate encoding.
        return sd_vae_taesd.encode(image.to(devices.device, devices.dtype))

    model = model if model is not None else shared.sd_model
    model.first_stage_model.to(devices.dtype_vae)
    # The full VAE expects inputs in [-1, 1].
    image = image.to(shared.device, dtype=devices.dtype_vae) * 2 - 1

    if len(image) <= 1:
        return model.get_first_stage_encoding(model.encode_first_stage(image))

    # Encode each image separately, then stack the results into one batch.
    encoded = [
        model.get_first_stage_encoding(model.encode_first_stage(img.unsqueeze(0)))[0]
        for img in image
    ]
    return torch.stack(encoded)
def store_latent(decoded):
shared.state.current_latent = decoded
if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % shared.opts.show_progress_every_n_steps == 0: