mirror of https://github.com/vladmandic/automatic
add tensor to samples method
parent
7e298b2039
commit
8797a34e19
|
|
@ -220,7 +220,7 @@ def esrgan_upscale(model, img):
|
|||
|
||||
with Progress(TextColumn('[cyan]{task.description}'), BarColumn(), TaskProgressColumn(), TimeRemainingColumn(), TimeElapsedColumn(), console=console) as progress:
|
||||
total = 0
|
||||
for y, h, row in grid.tiles:
|
||||
for _y, _h, row in grid.tiles:
|
||||
total += len(row)
|
||||
task = progress.add_task(description="Upscaling", total=total)
|
||||
for y, h, row in grid.tiles:
|
||||
|
|
|
|||
|
|
@ -155,7 +155,7 @@ class RealESRGANer():
|
|||
# input tile dimensions
|
||||
input_tile_width = input_end_x - input_start_x
|
||||
input_tile_height = input_end_y - input_start_y
|
||||
tile_idx = y * tiles_x + x + 1
|
||||
tile_idx = y * tiles_x + x + 1 # noqa
|
||||
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
|
||||
|
||||
# upscale tile
|
||||
|
|
@ -315,7 +315,7 @@ class IOConsumer(threading.Thread):
|
|||
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn import functional as F # noqa
|
||||
|
||||
|
||||
class SRVGGNetCompact(nn.Module):
|
||||
|
|
@ -381,4 +381,3 @@ class SRVGGNetCompact(nn.Module):
|
|||
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
|
||||
out += base
|
||||
return out
|
||||
|
||||
|
|
@ -55,6 +55,29 @@ def samples_to_image_grid(samples, approximation=None):
|
|||
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
|
||||
|
||||
|
||||
def images_tensor_to_samples(image, approximation=None, model=None):
|
||||
'''image[0, 1] -> latent'''
|
||||
if approximation is None:
|
||||
approximation = approximation_indexes.get(shared.opts.show_progress_type, 0)
|
||||
if approximation == 3:
|
||||
image = image.to(devices.device, devices.dtype)
|
||||
x_latent = sd_vae_taesd.encode(image)
|
||||
else:
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
model.first_stage_model.to(devices.dtype_vae)
|
||||
image = image.to(shared.device, dtype=devices.dtype_vae)
|
||||
image = image * 2 - 1
|
||||
if len(image) > 1:
|
||||
x_latent = torch.stack([
|
||||
model.get_first_stage_encoding(model.encode_first_stage(torch.unsqueeze(img, 0)))[0]
|
||||
for img in image
|
||||
])
|
||||
else:
|
||||
x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
|
||||
return x_latent
|
||||
|
||||
|
||||
def store_latent(decoded):
|
||||
shared.state.current_latent = decoded
|
||||
if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % shared.opts.show_progress_every_n_steps == 0:
|
||||
|
|
|
|||
Loading…
Reference in New Issue