support loading .yaml config with same name as model

support EMA weights in processing (????)
This commit is contained in:
AUTOMATIC 2022-10-08 23:26:48 +03:00
parent 432782163a
commit 050a6a798c
2 changed files with 24 additions and 8 deletions

@ -347,7 +347,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
with torch.no_grad():
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(all_prompts, all_seeds, all_subseeds)

@ -14,7 +14,7 @@ from modules.paths import models_path
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
checkpoints_list = {}
try:
@ -63,14 +63,20 @@ def list_models():
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list:
h = model_hash(filename)
title, short_model_name = modeltitle(filename, h)
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
basename, _ = os.path.splitext(filename)
config = basename + ".yaml"
if not os.path.exists(config):
config = shared.cmd_opts.config
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
def get_closet_checkpoint_match(searchString):
@ -116,7 +122,10 @@ def select_checkpoint():
return checkpoint_info
def load_model_weights(model, checkpoint_file, sd_model_hash):
def load_model_weights(model, checkpoint_info):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
pl_sd = torch.load(checkpoint_file, map_location="cpu")
@ -148,15 +157,19 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
def load_model():
from modules import lowvram, sd_hijack
checkpoint_info = select_checkpoint()
sd_config = OmegaConf.load(shared.cmd_opts.config)
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {shared.cmd_opts.config}")
sd_config = OmegaConf.load(checkpoint_info.config)
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
load_model_weights(sd_model, checkpoint_info)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
@ -178,6 +191,9 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
return load_model()
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
@ -185,7 +201,7 @@ def reload_model_weights(sd_model, info=None):
sd_hijack.model_hijack.undo_hijack(sd_model)
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
load_model_weights(sd_model, checkpoint_info)
sd_hijack.model_hijack.hijack(sd_model)