feat: Pair text with random ckpt
Useful for inserting text into a prompt that Dreambooth models were trained on. Eventually should revisit this to allow for some sort of way to specify where to insert text into the prompt. (Or not. I personally always put it at the front with a ,)
This commit is contained in:
parent
f4372062fd
commit
9137aa7bc9
@ -8,6 +8,11 @@ from modules.shared import opts
|
|||||||
from scripts.xy_grid import build_samplers_dict
|
from scripts.xy_grid import build_samplers_dict
|
||||||
|
|
||||||
class RandomizeScript(scripts.Script):
|
class RandomizeScript(scripts.Script):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.randomize_prompt_word = ''
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
return 'Randomize'
|
return 'Randomize'
|
||||||
|
|
||||||
@ -61,6 +66,7 @@ class RandomizeScript(scripts.Script):
|
|||||||
if param == 'CLIP_stop_at_last_layers':
|
if param == 'CLIP_stop_at_last_layers':
|
||||||
opts.data[param] = int(self._opt({param: val}, p)) # type: ignore
|
opts.data[param] = int(self._opt({param: val}, p)) # type: ignore
|
||||||
if param == 'sd_model_checkpoint':
|
if param == 'sd_model_checkpoint':
|
||||||
|
# TODO (mmaker): Trigger changing the ckpt dropdown value in the UI
|
||||||
sd_model_checkpoint = self._opt({param: val}, p)
|
sd_model_checkpoint = self._opt({param: val}, p)
|
||||||
if sd_model_checkpoint:
|
if sd_model_checkpoint:
|
||||||
sd_models.reload_model_weights(shared.sd_model, sd_model_checkpoint)
|
sd_models.reload_model_weights(shared.sd_model, sd_model_checkpoint)
|
||||||
@ -80,6 +86,13 @@ class RandomizeScript(scripts.Script):
|
|||||||
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
|
||||||
except (TypeError, IndexError) as exception:
|
except (TypeError, IndexError) as exception:
|
||||||
print(f'Failed to utilize highres. fix -- incorrect value?', exception)
|
print(f'Failed to utilize highres. fix -- incorrect value?', exception)
|
||||||
|
|
||||||
|
# Checkpoint specific
|
||||||
|
# TODO (mmaker): Allow for some way to format how this is inserted into the prompt
|
||||||
|
if len(self.randomize_prompt_word) > 0:
|
||||||
|
p.prompt = self.randomize_prompt_word + ', ' + p.prompt
|
||||||
|
if p.all_prompts and len(p.all_prompts) > 0:
|
||||||
|
p.all_prompts = [self.randomize_prompt_word + ', ' + prompt for prompt in p.all_prompts] # type: ignore
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -107,7 +120,14 @@ class RandomizeScript(scripts.Script):
|
|||||||
elif opt_name == 'seed':
|
elif opt_name == 'seed':
|
||||||
return int(random.choice(opt_arr))
|
return int(random.choice(opt_arr))
|
||||||
elif opt_name == 'sd_model_checkpoint':
|
elif opt_name == 'sd_model_checkpoint':
|
||||||
return sd_models.get_closet_checkpoint_match(random.choice(opt_arr))
|
choice = random.choice(opt_arr)
|
||||||
|
if ':' in choice:
|
||||||
|
ckpt_name = choice.split(':')[0].strip()
|
||||||
|
self.randomize_prompt_word = choice.split(':')[1].strip()
|
||||||
|
else:
|
||||||
|
ckpt_name = choice
|
||||||
|
self.randomize_prompt_word = ''
|
||||||
|
return sd_models.get_closet_checkpoint_match(ckpt_name)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user