From d63dbb3acc16a737711fddb3a954238f82deb2fa Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sun, 18 Sep 2022 01:05:31 +0300 Subject: [PATCH 1/2] Move scale multiplication to the front --- modules/sd_hijack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6541451..d819c01 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -20,7 +20,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - k = self.to_k(context) + k = self.to_k(context) * self.scale v = self.to_v(context) del context, x @@ -85,7 +85,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) s2 = s1.softmax(dim=-1, dtype=q.dtype) del s1 From 18d6fe4346e2543522cd2a64c71207e45632a46b Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sun, 18 Sep 2022 01:21:50 +0300 Subject: [PATCH 2/2] ..... --- modules/sd_hijack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index d819c01..c4450ce 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -20,7 +20,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - k = self.to_k(context) * self.scale + k = self.to_k(context) v = self.to_v(context) del context, x @@ -50,7 +50,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - k_in = self.to_k(context) + k_in = self.to_k(context) * self.scale v_in = self.to_v(context) del context, x