diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py new file mode 100644 index 0000000..9ed1eed --- /dev/null +++ b/modules/hypernetwork.py @@ -0,0 +1,55 @@ +import glob +import os +import torch +from modules import devices + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, state_dict): + super().__init__() + + self.linear1 = torch.nn.Linear(dim, dim * 2) + self.linear2 = torch.nn.Linear(dim * 2, dim) + + self.load_state_dict(state_dict, strict=True) + self.to(devices.device) + + def forward(self, x): + return x + (self.linear2(self.linear1(x))) + + +class Hypernetwork: + filename = None + name = None + + def __init__(self, filename): + self.filename = filename + self.name = os.path.splitext(os.path.basename(filename))[0] + self.layers = {} + + state_dict = torch.load(filename, map_location='cpu') + for size, sd in state_dict.items(): + self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + + +def load_hypernetworks(path): + res = {} + + for filename in glob.iglob(path + '**/*.pt', recursive=True): + hn = Hypernetwork(filename) + res[hn.name] = hn + + return res + +def apply(self, x, context=None, mask=None, original=None): + + + if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork: + if context.shape[1] == 77 and CrossAttention.noise_cond: + context = context + (torch.randn_like(context) * 0.1) + h_k, h_v = CrossAttention.hypernetwork[context.shape[2]] + k = self.to_k(h_k(context)) + v = self.to_v(h_v(context)) + else: + k = self.to_k(context) + v = self.to_v(context) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ea4cfdf..d9cca48 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -5,6 +5,8 @@ from torch import einsum from ldm.util import default from einops import rearrange +from modules import shared + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): @@ -42,8 +44,19 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - k_in = self.to_k(context) * self.scale - v_in = self.to_v(context) + + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k_in = self.to_k(hypernetwork_layers[0](context)) + v_in = self.to_v(hypernetwork_layers[1](context)) + else: + k_in = self.to_k(context) + v_in = self.to_v(context) + + k_in *= self.scale + del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) diff --git a/modules/shared.py b/modules/shared.py index 25bb6e6..879d842 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers +from modules import sd_samplers, hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -76,6 +76,12 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram config_filename = cmd_opts.ui_settings_file +hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks')) + + +def selected_hypernetwork(): + return hypernetworks.get(opts.sd_hypernetwork, None) + class State: interrupted = False @@ -206,6 +212,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), + "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 6344e61..c0c364d 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -77,6 +77,11 @@ def apply_checkpoint(p, x, xs): modules.sd_models.reload_model_weights(shared.sd_model, info) +def apply_hypernetwork(p, x, xs): + hn = shared.hypernetworks.get(x, None) + opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None' + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -122,6 +127,7 @@ axis_options = [ AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list), AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value), + AxisOption("Hypernetwork", str, apply_hypernetwork, format_value), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label), AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), @@ -193,6 +199,8 @@ class Script(scripts.Script): modules.processing.fix_seed(p) p.batch_size = 1 + initial_hn = opts.sd_hypernetwork + def process_axis(opt, vals): if opt.label == 'Nothing': return [0] @@ -300,4 +308,6 @@ class Script(scripts.Script): # restore checkpoint in case it was changed by axes modules.sd_models.reload_model_weights(shared.sd_model) + opts.data["sd_hypernetwork"] = initial_hn + return processed