feat: Pair text with random ckpt

Useful for inserting text into a prompt that Dreambooth models were
trained on. Eventually should revisit this to allow for some
sort of way to specify where to insert text into the prompt.
(Or not. I personally always put it at the front with a ,)
This commit is contained in:
MMaker 2022-11-17 18:28:53 -05:00
parent f4372062fd
commit 9137aa7bc9
No known key found for this signature in database
GPG Key ID: CCE79B8FEDA40FB2

@ -8,6 +8,11 @@ from modules.shared import opts
from scripts.xy_grid import build_samplers_dict
class RandomizeScript(scripts.Script):
def __init__(self) -> None:
super().__init__()
self.randomize_prompt_word = ''
def title(self):
return 'Randomize'
@ -61,6 +66,7 @@ class RandomizeScript(scripts.Script):
if param == 'CLIP_stop_at_last_layers':
opts.data[param] = int(self._opt({param: val}, p)) # type: ignore
if param == 'sd_model_checkpoint':
# TODO (mmaker): Trigger changing the ckpt dropdown value in the UI
sd_model_checkpoint = self._opt({param: val}, p)
if sd_model_checkpoint:
sd_models.reload_model_weights(shared.sd_model, sd_model_checkpoint)
@ -80,6 +86,13 @@ class RandomizeScript(scripts.Script):
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
except (TypeError, IndexError) as exception:
print(f'Failed to utilize highres. fix -- incorrect value?', exception)
# Checkpoint specific
# TODO (mmaker): Allow for some way to format how this is inserted into the prompt
if len(self.randomize_prompt_word) > 0:
p.prompt = self.randomize_prompt_word + ', ' + p.prompt
if p.all_prompts and len(p.all_prompts) > 0:
p.all_prompts = [self.randomize_prompt_word + ', ' + prompt for prompt in p.all_prompts] # type: ignore
else:
return
@ -107,7 +120,14 @@ class RandomizeScript(scripts.Script):
elif opt_name == 'seed':
return int(random.choice(opt_arr))
elif opt_name == 'sd_model_checkpoint':
return sd_models.get_closet_checkpoint_match(random.choice(opt_arr))
choice = random.choice(opt_arr)
if ':' in choice:
ckpt_name = choice.split(':')[0].strip()
self.randomize_prompt_word = choice.split(':')[1].strip()
else:
ckpt_name = choice
self.randomize_prompt_word = ''
return sd_models.get_closet_checkpoint_match(ckpt_name)
else:
return None