diff --git a/modules/images.py b/modules/images.py index 3399887..064849d 100644 --- a/modules/images.py +++ b/modules/images.py @@ -136,7 +136,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): color_active = (0, 0, 0) color_inactive = (153, 153, 153) - pad_left = width * 3 // 4 if len(ver_texts) > 0 else 0 + pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4 cols = im.width // width rows = im.height // height diff --git a/modules/img2img.py b/modules/img2img.py index 00bd626..54023df 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -11,7 +11,7 @@ from modules.ui import plaintext_to_html import modules.images as images import modules.scripts -def img2img(prompt: str, negative_prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, denoising_strength_change_factor: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args): +def img2img(prompt: str, negative_prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, denoising_strength_change_factor: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args): is_inpaint = mode == 1 is_loopback = mode == 2 is_upscale = mode == 3 @@ -34,6 +34,10 @@ def img2img(prompt: str, negative_prompt: str, init_img, init_img_with_mask, ste prompt=prompt, negative_prompt=negative_prompt, seed=seed, + subseed=subseed, + subseed_strength=subseed_strength, + seed_resize_from_h=seed_resize_from_h, + seed_resize_from_w=seed_resize_from_w, sampler_index=sampler_index, batch_size=batch_size, n_iter=n_iter, diff --git a/modules/processing.py b/modules/processing.py index d4c4cfa..b91ade1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -29,7 +29,7 @@ def torch_gc(): class StableDiffusionProcessing: - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None): self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids @@ -37,6 +37,10 @@ class StableDiffusionProcessing: self.prompt_for_display: str = None self.negative_prompt: str = (negative_prompt or "") self.seed: int = seed + self.subseed: int = subseed + self.subseed_strength: float = subseed_strength + self.seed_resize_from_h: int = seed_resize_from_h + self.seed_resize_from_w: int = seed_resize_from_w self.sampler_index: int = sampler_index self.batch_size: int = batch_size self.n_iter: int = n_iter @@ -84,23 +88,67 @@ class Processed: return json.dumps(obj) +# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 +def slerp(val, low, high): + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm*high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res -def create_random_tensors(shape, seeds): + +def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0): xs = [] - for seed in seeds: - torch.manual_seed(seed) + for i, seed in enumerate(seeds): + noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8) + + subnoise = None + if subseeds is not None: + subseed = 0 if i >= len(subseeds) else subseeds[i] + torch.manual_seed(subseed) + subnoise = torch.randn(noise_shape, device=shared.device) # randn results depend on device; gpu and cpu get different results for same seed; # the way I see it, it's better to do this on CPU, so that everyone gets same result; - # but the original script had it like this so I do not dare change it for now because + # but the original script had it like this, so I do not dare change it for now because # it will break everyone's seeds. - xs.append(torch.randn(shape, device=shared.device)) - x = torch.stack(xs) + torch.manual_seed(seed) + noise = torch.randn(noise_shape, device=shared.device) + + if subnoise is not None: + #noise = subnoise * subseed_strength + noise * (1 - subseed_strength) + noise = slerp(subseed_strength, noise, subnoise) + + if noise_shape != shape: + #noise = torch.nn.functional.interpolate(noise.unsqueeze(1), size=shape[1:], mode="bilinear").squeeze() + # noise_shape = (64, 80) + # shape = (64, 72) + + torch.manual_seed(seed) + x = torch.randn(shape, device=shared.device) + dx = (shape[2] - noise_shape[2]) // 2 # -4 + dy = (shape[1] - noise_shape[1]) // 2 + w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx + h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy + tx = 0 if dx < 0 else dx + ty = 0 if dy < 0 else dy + dx = max(-dx, 0) + dy = max(-dy, 0) + + x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w] + noise = x + + + + xs.append(noise) + x = torch.stack(xs).to(shared.device) return x -def set_seed(seed): - return int(random.randrange(4294967294)) if seed is None or seed == -1 else seed +def fix_seed(p): + p.seed = int(random.randrange(4294967294)) if p.seed is None or p.seed == -1 else p.seed + p.subseed = int(random.randrange(4294967294)) if p.subseed is None or p.subseed == -1 else p.subseed def process_images(p: StableDiffusionProcessing) -> Processed: @@ -111,7 +159,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: assert p.prompt is not None torch_gc() - seed = set_seed(p.seed) + fix_seed(p) os.makedirs(p.outpath_samples, exist_ok=True) os.makedirs(p.outpath_grids, exist_ok=True) @@ -125,20 +173,31 @@ def process_images(p: StableDiffusionProcessing) -> Processed: else: all_prompts = p.batch_size * p.n_iter * [prompt] - if type(seed) == list: - all_seeds = seed + if type(p.seed) == list: + all_seeds = int(p.seed) else: - all_seeds = [int(seed + x) for x in range(len(all_prompts))] + all_seeds = [int(p.seed + x) for x in range(len(all_prompts))] + + if type(p.subseed) == list: + all_subseeds = p.subseed + else: + all_subseeds = [int(p.subseed + x) for x in range(len(all_prompts))] def infotext(iteration=0, position_in_batch=0): + index = position_in_batch + iteration * p.batch_size + generation_params = { "Steps": p.steps, "Sampler": samplers[p.sampler_index].name, "CFG scale": p.cfg_scale, - "Seed": all_seeds[position_in_batch + iteration * p.batch_size], + "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), + "Size": f"{p.width}x{p.height}", "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), + "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), + "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), + "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), } if p.extra_generation_params is not None: @@ -174,7 +233,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: comments += model_hijack.comments # we manually generate all input noises because each one should have a specific seed - x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds) + x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds, subseeds=all_subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" @@ -231,10 +290,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: output_images.insert(0, grid) if opts.grid_save: - images.save_image(grid, p.outpath_grids, "grid", seed, all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename) + images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename) torch_gc() - return Processed(p, output_images, seed, infotext()) + return Processed(p, output_images, all_seeds[0], infotext()) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): diff --git a/modules/shared.py b/modules/shared.py index 280c07f..e577332 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -62,7 +62,6 @@ class State: current_image = None current_image_sampling_step = 0 - def interrupt(self): self.interrupted = True diff --git a/modules/txt2img.py b/modules/txt2img.py index 410a7a7..606421e 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -6,7 +6,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args): +def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -14,6 +14,10 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, r prompt=prompt, negative_prompt=negative_prompt, seed=seed, + subseed=subseed, + subseed_strength=subseed_strength, + seed_resize_from_h=seed_resize_from_h, + seed_resize_from_w=seed_resize_from_w, sampler_index=sampler_index, batch_size=batch_size, n_iter=n_iter, diff --git a/modules/ui.py b/modules/ui.py index a2ff660..6784de5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -192,6 +192,40 @@ def visit(x, func, path=""): func(path + "/" + str(x.label), x) +def create_seed_inputs(): + with gr.Row(): + seed = gr.Number(label='Seed', value=-1) + subseed = gr.Number(label='Variation seed', value=-1, visible=False) + seed_checkbox = gr.Checkbox(label="Extra", elem_id="subseed_show", value=False) + + with gr.Row(): + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, visible=False) + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0, visible=False) + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0, visible=False) + + def change_visiblity(show): + + return { + subseed: gr_show(show), + subseed_strength: gr_show(show), + seed_resize_from_h: gr_show(show), + seed_resize_from_w: gr_show(show), + } + + seed_checkbox.change( + change_visiblity, + inputs=[seed_checkbox], + outputs=[ + subseed, + subseed_strength, + seed_resize_from_h, + seed_resize_from_w + ] + ) + + return seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w + + def create_ui(txt2img, img2img, run_extras, run_pnginfo): with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Row(): @@ -220,7 +254,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - seed = gr.Number(label='Seed', value=-1) + seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w = create_seed_inputs() with gr.Group(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) @@ -260,6 +294,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): batch_size, cfg_scale, seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, height, width, ] + custom_inputs, @@ -357,7 +392,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) - seed = gr.Number(label='Seed', value=-1) + seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w = create_seed_inputs() with gr.Group(): custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) @@ -440,6 +475,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): denoising_strength, denoising_strength_change_factor, seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, height, width, resize_mode, diff --git a/script.js b/script.js index c1143a8..ed37650 100644 --- a/script.js +++ b/script.js @@ -46,6 +46,11 @@ titles = { "Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.", "Roll": "Add a random artist to the prompt.", + + "Variation seed": "Seed of a different picture to be mixed into the generation.", + "Variation strength": "How strong of a variation to produce. At 0, there will be no effect. At 1, you will get the complete picture with variation seed (except for ancestral samplers, where you will just get something).", + "Resize seed from height": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution", + "Resize seed from width": "Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution", } function gradioApp(){ diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index a61d118..8d4a4e7 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -50,7 +50,7 @@ class Script(scripts.Script): return [put_at_start] def run(self, p, put_at_start): - seed = modules.processing.set_seed(p.seed) + modules.processing.fix_seed(p) original_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index b53c829..4c3f0da 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -2,6 +2,8 @@ from collections import namedtuple from copy import copy import random +import numpy as np + import modules.scripts as scripts import gradio as gr @@ -46,18 +48,27 @@ def format_value_add_label(p, opt, x): def format_value(p, opt, x): return x +def do_nothing(p, x, xs): + pass + +def format_nothing(p, opt, x): + return "" + AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"]) AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"]) axis_options = [ + AxisOption("Nothing", str, do_nothing, format_nothing), AxisOption("Seed", int, apply_field("seed"), format_value_add_label), + AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label), + AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label), AxisOption("Steps", int, apply_field("steps"), format_value_add_label), AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), AxisOption("Prompt S/R", str, apply_prompt, format_value), AxisOption("Sampler", str, apply_sampler, format_value), - AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label) # as it is now all AxisOptionImg2Img items must go after AxisOption ones + AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -90,6 +101,7 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell): re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") +re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") class Script(scripts.Script): def title(self): @@ -99,17 +111,17 @@ class Script(scripts.Script): current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, visible=False, type="index", elem_id="x_type") + x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, visible=False, type="index", elem_id="x_type") x_values = gr.Textbox(label="X values", visible=False, lines=1) with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, visible=False, type="index", elem_id="y_type") + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[4].label, visible=False, type="index", elem_id="y_type") y_values = gr.Textbox(label="Y values", visible=False, lines=1) return [x_type, x_values, y_type, y_values] def run(self, p, x_type, x_values, y_type, y_values): - p.seed = modules.processing.set_seed(p.seed) + modules.processing.fix_seed(p) p.batch_size = 1 p.batch_count = 1 @@ -132,6 +144,21 @@ class Script(scripts.Script): valslist_ext.append(val) valslist = valslist_ext + elif opt.type == float: + valslist_ext = [] + + for val in valslist: + m = re_range_float.fullmatch(val) + if m is not None: + start = float(m.group(1)) + end = float(m.group(2)) + step = float(m.group(3)) if m.group(3) is not None else 1 + + valslist_ext += np.arange(start, end + step, step).tolist() + else: + valslist_ext.append(val) + + valslist = valslist_ext valslist = [opt.type(x) for x in valslist] diff --git a/style.css b/style.css index 0ae72cc..9f847e4 100644 --- a/style.css +++ b/style.css @@ -5,6 +5,15 @@ max-width: 13em; } +#subseed_show{ + min-width: 6em; + max-width: 6em; +} + +#subseed_show label{ + height: 100%; +} + #txt2img_roll{ min-width: 1em; max-width: 4em;