diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 17faeab..a118399 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -3,6 +3,7 @@ import numpy as np import torch import tqdm from PIL import Image +import inspect import k_diffusion.sampling import ldm.models.diffusion.ddim @@ -38,11 +39,11 @@ samplers = [ samplers_for_img2img = [x for x in samplers if x.name != 'PLMS'] sampler_extra_params = { - 'sample_euler':['s_churn','s_tmin','s_tmax','s_noise'], - 'sample_euler_ancestral':['eta'], - 'sample_heun' :['s_churn','s_tmin','s_tmax','s_noise'], - 'sample_dpm_2':['s_churn','s_tmin','s_tmax','s_noise'], - 'sample_dpm_2_ancestral':['eta'], + 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], + 'sample_euler_ancestral': ['eta'], + 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], + 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], + 'sample_dpm_2_ancestral': ['eta'], } def setup_img2img_steps(p, steps=None): @@ -231,7 +232,7 @@ class KDiffusionSampler: self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization) self.funcname = funcname self.func = getattr(k_diffusion.sampling, self.funcname) - self.extra_params = sampler_extra_params.get(funcname,[]) + self.extra_params = sampler_extra_params.get(funcname, []) self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.sampler_noises = None self.sampler_noise_index = 0 @@ -278,9 +279,9 @@ class KDiffusionSampler: k_diffusion.sampling.torch = TorchHijack(self) extra_params_kwargs = {} - for val in self.extra_params: - if hasattr(p,val): - extra_params_kwargs[val] = getattr(p,val) + for param_name in self.extra_params: + if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: + extra_params_kwargs[param_name] = getattr(p, param_name) return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) @@ -300,9 +301,9 @@ class KDiffusionSampler: k_diffusion.sampling.torch = TorchHijack(self) extra_params_kwargs = {} - for val in self.extra_params: - if hasattr(p,val): - extra_params_kwargs[val] = getattr(p,val) + for param_name in self.extra_params: + if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: + extra_params_kwargs[param_name] = getattr(p, param_name) samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)