added cheap NN approximation for VAE

This commit is contained in:
AUTOMATIC 2022-12-24 22:39:00 +03:00
parent 5927d3fa95
commit 56e557c6ff
5 changed files with 81 additions and 17 deletions

View File

@ -97,7 +97,10 @@ titles = {
"Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc."
"Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.",
"Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.",
"Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality."
}

BIN
models/VAE-approx/model.pt Normal file

Binary file not shown.

View File

@ -9,7 +9,7 @@ import k_diffusion.sampling
import torchsde._brownian.brownian_interval
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from modules import prompt_parser, devices, processing, images
from modules import prompt_parser, devices, processing, images, sd_vae_approx
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@ -106,28 +106,31 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc
def single_sample_to_image(sample, approximation=False):
if approximation:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
coefs = torch.tensor(
[[ 0.298, 0.207, 0.208],
[ 0.187, 0.286, 0.173],
[-0.158, 0.189, 0.264],
[-0.184, -0.271, -0.473]]).to(sample.device)
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
def single_sample_to_image(sample, approximation=None):
if approximation is None:
approximation = approximation_indexes.get(opts.show_progress_type, 0)
if approximation == 2:
x_sample = sd_vae_approx.cheap_approximation(sample)
elif approximation == 1:
x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
else:
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample)
def sample_to_image(samples, index=0, approximation=False):
def sample_to_image(samples, index=0, approximation=None):
return single_sample_to_image(samples[index], approximation)
def samples_to_image_grid(samples, approximation=False):
def samples_to_image_grid(samples, approximation=None):
return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
@ -136,7 +139,7 @@ def store_latent(decoded):
if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
if not shared.parallel_processing_allowed:
shared.state.current_image = sample_to_image(decoded, approximation=opts.show_progress_approximate)
shared.state.current_image = sample_to_image(decoded)
class InterruptedException(BaseException):

58
modules/sd_vae_approx.py Normal file
View File

@ -0,0 +1,58 @@
import os
import torch
from torch import nn
from modules import devices, paths
sd_vae_approx_model = None
class VAEApprox(nn.Module):
def __init__(self):
super(VAEApprox, self).__init__()
self.conv1 = nn.Conv2d(4, 8, (7, 7))
self.conv2 = nn.Conv2d(8, 16, (5, 5))
self.conv3 = nn.Conv2d(16, 32, (3, 3))
self.conv4 = nn.Conv2d(32, 64, (3, 3))
self.conv5 = nn.Conv2d(64, 32, (3, 3))
self.conv6 = nn.Conv2d(32, 16, (3, 3))
self.conv7 = nn.Conv2d(16, 8, (3, 3))
self.conv8 = nn.Conv2d(8, 3, (3, 3))
def forward(self, x):
extra = 11
x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
x = nn.functional.pad(x, (extra, extra, extra, extra))
for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
x = layer(x)
x = nn.functional.leaky_relu(x, 0.1)
return x
def model():
global sd_vae_approx_model
if sd_vae_approx_model is None:
sd_vae_approx_model = VAEApprox()
sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
sd_vae_approx_model.eval()
sd_vae_approx_model.to(devices.device, devices.dtype)
return sd_vae_approx_model
def cheap_approximation(sample):
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
coefs = torch.tensor([
[0.298, 0.207, 0.208],
[0.187, 0.286, 0.173],
[-0.158, 0.189, 0.264],
[-0.184, -0.271, -0.473],
]).to(sample.device)
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
return x_sample

View File

@ -212,9 +212,9 @@ class State:
import modules.sd_samplers
if opts.show_progress_grid:
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent, approximation=opts.show_progress_approximate)
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
else:
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent, approximation=opts.show_progress_approximate)
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
self.current_image_sampling_step = self.sampling_step
@ -392,7 +392,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"show_progress_approximate": OptionInfo(False, "Calculate small previews using fast linear approximation instead of VAE"),
"show_progress_type": OptionInfo("Full", "Image creation progress mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),