diff --git a/modules/devices.py b/modules/devices.py index 30d30b9..f88e807 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -14,3 +14,9 @@ def get_optimal_device(): return torch.device("mps") return cpu + + +def torch_gc(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() diff --git a/modules/extras.py b/modules/extras.py index 6aeae6c..40935f9 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -1,7 +1,7 @@ import numpy as np from PIL import Image -from modules import processing, shared, images +from modules import processing, shared, images, devices from modules.shared import opts import modules.gfpgan_model from modules.ui import plaintext_to_html @@ -11,7 +11,7 @@ cached_images = {} def run_extras(image, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): - processing.torch_gc() + devices.torch_gc() image = image.convert("RGB") info = "" diff --git a/modules/images.py b/modules/images.py index 26c399b..69d149f 100644 --- a/modules/images.py +++ b/modules/images.py @@ -243,16 +243,32 @@ def sanitize_filename_part(text): return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128] -def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, pnginfo_section_name='parameters'): +def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, pnginfo_section_name='parameters', process_info=None): # would be better to add this as an argument in future, but will do for now is_a_grid = basename != "" if short_filename or prompt is None or seed is None: file_decoration = "" elif opts.save_to_dirs: - file_decoration = f"-{seed}" + file_decoration = opts.samples_filename_format or "[SEED]" else: - file_decoration = f"-{seed}-{sanitize_filename_part(prompt)[:128]}" + file_decoration = opts.samples_filename_format or "[SEED]-[PROMPT]" + #file_decoration = f"-{seed}-{sanitize_filename_part(prompt)[:128]}" + + #Add new filenames tags here + file_decoration = "-" + file_decoration + if seed is not None: + file_decoration = file_decoration.replace("[SEED]", str(seed)) + if prompt is not None: + file_decoration = file_decoration.replace("[PROMPT]", sanitize_filename_part(prompt)[:128]) + file_decoration = file_decoration.replace("[PROMPT_SPACES]", prompt.translate({ord(x): '' for x in invalid_filename_chars})[:128]) + if process_info is not None: + file_decoration = file_decoration.replace("[STEPS]", str(process_info.steps)) + file_decoration = file_decoration.replace("[CFG]", str(process_info.cfg_scale)) + file_decoration = file_decoration.replace("[WIDTH]", str(process_info.width)) + file_decoration = file_decoration.replace("[HEIGHT]", str(process_info.height)) + file_decoration = file_decoration.replace("[SAMPLER]", str(process_info.sampler)) + if extension == 'png' and opts.enable_pnginfo and info is not None: pnginfo = PngImagePlugin.PngInfo() diff --git a/modules/img2img.py b/modules/img2img.py index 779f620..7461bad 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -3,6 +3,7 @@ import cv2 import numpy as np from PIL import Image, ImageOps, ImageChops +from modules import devices from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state import modules.shared as shared @@ -131,7 +132,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init upscaler = shared.sd_upscalers[upscaler_index] img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2) - processing.torch_gc() + devices.torch_gc() grid = images.split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap) @@ -179,7 +180,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init result_images.append(combined_image) if opts.samples_save: - images.save_image(combined_image, p.outpath_samples, "", start_seed, prompt, opts.grid_format, info=initial_info) + images.save_image(combined_image, p.outpath_samples, "", start_seed, prompt, opts.samples_format, info=initial_info) processed = Processed(p, result_images, seed, initial_info) diff --git a/modules/interrogate.py b/modules/interrogate.py index ed97a58..7ebb79f 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -1,3 +1,4 @@ +import contextlib import os import sys import traceback @@ -6,7 +7,6 @@ import re import torch -from PIL import Image from torchvision import transforms from torchvision.transforms.functional import InterpolationMode @@ -26,6 +26,7 @@ class InterrogateModels: clip_model = None clip_preprocess = None categories = None + dtype = None def __init__(self, content_dir): self.categories = [] @@ -60,14 +61,20 @@ class InterrogateModels: def load(self): if self.blip_model is None: self.blip_model = self.load_blip_model() + if not shared.cmd_opts.no_half: + self.blip_model = self.blip_model.half() self.blip_model = self.blip_model.to(shared.device) if self.clip_model is None: self.clip_model, self.clip_preprocess = self.load_clip_model() + if not shared.cmd_opts.no_half: + self.clip_model = self.clip_model.half() self.clip_model = self.clip_model.to(shared.device) + self.dtype = next(self.clip_model.parameters()).dtype + def unload(self): if not shared.opts.interrogate_keep_models_in_memory: if self.clip_model is not None: @@ -76,14 +83,14 @@ class InterrogateModels: if self.blip_model is not None: self.blip_model = self.blip_model.to(devices.cpu) + devices.torch_gc() def rank(self, image_features, text_array, top_count=1): import clip top_count = min(top_count, len(text_array)) - text_tokens = clip.tokenize([text for text in text_array]).cuda() - with torch.no_grad(): - text_features = self.clip_model.encode_text(text_tokens).float() + text_tokens = clip.tokenize([text for text in text_array]).to(shared.device) + text_features = self.clip_model.encode_text(text_tokens).type(self.dtype) text_features /= text_features.norm(dim=-1, keepdim=True) similarity = torch.zeros((1, len(text_array))).to(shared.device) @@ -94,13 +101,12 @@ class InterrogateModels: top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1) return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)] - def generate_caption(self, pil_image): gpu_image = transforms.Compose([ transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) - ])(pil_image).unsqueeze(0).to(shared.device) + ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) with torch.no_grad(): caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length) @@ -116,22 +122,23 @@ class InterrogateModels: caption = self.generate_caption(pil_image) res = caption - images = self.clip_preprocess(pil_image).unsqueeze(0).to(shared.device) + images = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) - with torch.no_grad(): - image_features = self.clip_model.encode_image(images).float() + precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext + with torch.no_grad(), precision_scope("cuda"): + image_features = self.clip_model.encode_image(images).type(self.dtype) - image_features /= image_features.norm(dim=-1, keepdim=True) + image_features /= image_features.norm(dim=-1, keepdim=True) - if shared.opts.interrogate_use_builtin_artists: - artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0] + if shared.opts.interrogate_use_builtin_artists: + artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0] - res += ", " + artist[0] + res += ", " + artist[0] - for name, topn, items in self.categories: - matches = self.rank(image_features, items, top_count=topn) - for match, score in matches: - res += ", " + match + for name, topn, items in self.categories: + matches = self.rank(image_features, items, top_count=topn) + for match, score in matches: + res += ", " + match except Exception: print(f"Error interrogating", file=sys.stderr) diff --git a/modules/processing.py b/modules/processing.py index cf2e13d..7dc2b9a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -10,6 +10,7 @@ from PIL import Image, ImageFilter, ImageOps import random import modules.sd_hijack +from modules import devices from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img from modules.shared import opts, cmd_opts, state @@ -23,11 +24,6 @@ opt_C = 4 opt_f = 8 -def torch_gc(): - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - class StableDiffusionProcessing: def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None): @@ -157,7 +153,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" assert p.prompt is not None - torch_gc() + devices.torch_gc() fix_seed(p) @@ -258,7 +254,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: x_sample = x_sample.astype(np.uint8) if p.restore_faces: - torch_gc() + devices.torch_gc() x_sample = modules.face_restoration.restore_faces(x_sample) @@ -279,7 +275,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: image = image.convert('RGB') if opts.samples_save and not p.do_not_save_samples: - images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i)) + images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), process_info = Processed(p, output_images, all_seeds[0], infotext())) output_images.append(image) @@ -297,7 +293,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename) - torch_gc() + devices.torch_gc() return Processed(p, output_images, all_seeds[0], infotext()) diff --git a/modules/shared.py b/modules/shared.py index 0557cfe..f9509a7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -94,6 +94,7 @@ class Options: data = None hide_dirs = {"visible": False} if cmd_opts.hide_ui_dir_config else None data_labels = { + "samples_filename_format": OptionInfo("", "Samples filename format using following tags: [STEPS],[CFG],[PROMPT],[PROMPT_SPACES],[WIDTH],[HEIGHT],[SAMPLER],[SEED]. Leave blank for default."), "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to two directories below", component_args=hide_dirs), "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index c029c67..102aab0 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -4,7 +4,7 @@ import modules.scripts as scripts import gradio as gr from PIL import Image, ImageDraw -from modules import images, processing +from modules import images, processing, devices from modules.processing import Processed, process_images from modules.shared import opts, cmd_opts, state @@ -77,7 +77,7 @@ class Script(scripts.Script): mask.height - down - (mask_blur//2 if down > 0 else 0) ), fill="black") - processing.torch_gc() + devices.torch_gc() grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels) grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)