shift embedding logic out of textual_inversion

DepFA 2022-10-11 19:50:50 +01:00 committed by GitHub
parent e5fbf5c755
commit 61788c0538

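The hunks below move the image-embedding helpers out of textual_inversion.py: a JSON encoder/decoder pair that round-trips torch tensors, base64 wrappers over them, and steganography routines that hide a zlib-compressed copy of the embedding in the low nibbles of the preview PNG. A minimal sketch of the tensor round-trip these classes implement, using the definitions as they appear in the removed block (the fake_embedding dict is a hypothetical stand-in for a real trained embedding):

    import base64
    import json
    import numpy as np
    import torch

    class EmbeddingEncoder(json.JSONEncoder):
        # Tensors are not JSON-serializable, so wrap them in a tagged list-of-lists.
        def default(self, obj):
            if isinstance(obj, torch.Tensor):
                return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
            return json.JSONEncoder.default(self, obj)

    class EmbeddingDecoder(json.JSONDecoder):
        # Mirror image: any dict carrying the TORCHTENSOR tag becomes a tensor again.
        def __init__(self, *args, **kwargs):
            json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)

        def object_hook(self, d):
            if 'TORCHTENSOR' in d:
                return torch.from_numpy(np.array(d['TORCHTENSOR']))
            return d

    # 'fake_embedding' is a hypothetical stand-in for a real trained embedding dict.
    fake_embedding = {'name': 'my-embedding', 'string_to_param': {'*': torch.zeros(2, 768)}}
    b64 = base64.b64encode(json.dumps(fake_embedding, cls=EmbeddingEncoder).encode())
    restored = json.loads(base64.b64decode(b64), cls=EmbeddingDecoder)
    # tolist()/np.array round-trips as float64, hence the cast before comparing.
    assert torch.equal(restored['string_to_param']['*'].float(), torch.zeros(2, 768))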
@@ -7,124 +7,11 @@ import tqdm
 import html
 import datetime
 
-from PIL import Image,PngImagePlugin,ImageDraw
-from ..images import captionImageOverlay
-import numpy as np
-import base64
-import json
-import zlib
+from PIL import Image,PngImagePlugin
 
 from modules import shared, devices, sd_hijack, processing, sd_models
 import modules.textual_inversion.dataset
 
-class EmbeddingEncoder(json.JSONEncoder):
-    def default(self, obj):
-        if isinstance(obj, torch.Tensor):
-            return {'TORCHTENSOR':obj.cpu().detach().numpy().tolist()}
-        return json.JSONEncoder.default(self, obj)
-
-class EmbeddingDecoder(json.JSONDecoder):
-    def __init__(self, *args, **kwargs):
-        json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
-
-    def object_hook(self, d):
-        if 'TORCHTENSOR' in d:
-            return torch.from_numpy(np.array(d['TORCHTENSOR']))
-        return d
-
-def embeddingToB64(data):
-    d = json.dumps(data,cls=EmbeddingEncoder)
-    return base64.b64encode(d.encode())
-
-def embeddingFromB64(data):
-    d = base64.b64decode(data)
-    return json.loads(d,cls=EmbeddingDecoder)
-
-def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
-    while True:
-        seed = (a * seed + c) % m
-        yield seed
-
-def xorBlock(block):
-    g = lcg()
-    randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
-    return np.bitwise_xor(block.astype(np.uint8),randblock & 0x0F)
-
-def styleBlock(block,sequence):
-    im = Image.new('RGB',(block.shape[1],block.shape[0]))
-    draw = ImageDraw.Draw(im)
-    i=0
-    for x in range(-6,im.size[0],8):
-        for yi,y in enumerate(range(-6,im.size[1],8)):
-            offset=0
-            if yi%2==0:
-                offset=4
-            shade = sequence[i%len(sequence)]
-            i+=1
-            draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade,shade,shade))
-    fg = np.array(im).astype(np.uint8) & 0xF0
-    return block ^ fg
-
-def insertImageDataEmbed(image,data):
-    d = 3
-    data_compressed = zlib.compress(json.dumps(data,cls=EmbeddingEncoder).encode(),level=9)
-    dnp = np.frombuffer(data_compressed,np.uint8).copy()
-    dnphigh = dnp >> 4
-    dnplow = dnp & 0x0F
-
-    h = image.size[1]
-    next_size = dnplow.shape[0] + (h-(dnplow.shape[0]%h))
-    next_size = next_size + ((h*d)-(next_size%(h*d)))
-
-    dnplow.resize(next_size)
-    dnplow = dnplow.reshape((h,-1,d))
-
-    dnphigh.resize(next_size)
-    dnphigh = dnphigh.reshape((h,-1,d))
-
-    edgeStyleWeights = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
-    edgeStyleWeights = (np.abs(edgeStyleWeights)/np.max(np.abs(edgeStyleWeights))*255).astype(np.uint8)
-
-    dnplow = styleBlock(dnplow,sequence=edgeStyleWeights)
-    dnplow = xorBlock(dnplow)
-    dnphigh = styleBlock(dnphigh,sequence=edgeStyleWeights[::-1])
-    dnphigh = xorBlock(dnphigh)
-
-    imlow = Image.fromarray(dnplow,mode='RGB')
-    imhigh = Image.fromarray(dnphigh,mode='RGB')
-
-    background = Image.new('RGB',(image.size[0]+imlow.size[0]+imhigh.size[0]+2,image.size[1]),(0,0,0))
-    background.paste(imlow,(0,0))
-    background.paste(image,(imlow.size[0]+1,0))
-    background.paste(imhigh,(imlow.size[0]+1+image.size[0]+1,0))
-
-    return background
-
-def crop_black(img,tol=0):
-    mask = (img>tol).all(2)
-    mask0,mask1 = mask.any(0),mask.any(1)
-    col_start,col_end = mask0.argmax(),mask.shape[1]-mask0[::-1].argmax()
-    row_start,row_end = mask1.argmax(),mask.shape[0]-mask1[::-1].argmax()
-    return img[row_start:row_end,col_start:col_end]
-
-def extractImageDataEmbed(image):
-    d = 3
-    outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1],image.size[0],d).astype(np.uint8)) & 0x0F
-    blackCols = np.where(np.sum(outarr, axis=(0,2))==0)
-    if blackCols[0].shape[0] < 2:
-        print('No Image data blocks found.')
-        return None
-
-    dataBlocklower = outarr[:,:blackCols[0].min(),:].astype(np.uint8)
-    dataBlockupper = outarr[:,blackCols[0].max()+1:,:].astype(np.uint8)
-
-    dataBlocklower = xorBlock(dataBlocklower)
-    dataBlockupper = xorBlock(dataBlockupper)
-
-    dataBlock = (dataBlockupper << 4) | (dataBlocklower)
-    dataBlock = dataBlock.flatten().tobytes()
-    data = zlib.decompress(dataBlock)
-    return json.loads(data,cls=EmbeddingDecoder)
 
 class Embedding:
     def __init__(self, vec, name, step=None):
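The lcg/xorBlock pair removed above is a fixed-seed keystream: the generator always starts from seed 0, so XORing a block twice restores it exactly, which is what lets extractImageDataEmbed undo the scrambling applied at insert time. A standalone sketch of that property (the helper is written snake_case here, in line with the renamed call sites later in this diff):

    import numpy as np

    def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
        # Same constants and seed on every call, so the keystream is reproducible.
        while True:
            seed = (a * seed + c) % m
            yield seed

    def xor_block(block):
        # XOR each byte with the low nibble of the keystream, as xorBlock does.
        g = lcg()
        randblock = np.array([next(g) for _ in range(block.size)]).astype(np.uint8).reshape(block.shape)
        return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)

    # Because the stream is fixed, applying the XOR twice is the identity.
    payload = np.arange(48, dtype=np.uint8).reshape(4, 4, 3) & 0x0F
    assert np.array_equal(xor_block(xor_block(payload)), payload)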
@@ -199,10 +86,10 @@ class EmbeddingDatabase:
             if filename.upper().endswith('.PNG'):
                 embed_image = Image.open(path)
                 if 'sd-ti-embedding' in embed_image.text:
-                    data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
+                    data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                     name = data.get('name',name)
                 else:
-                    data = extractImageDataEmbed(embed_image)
+                    data = extract_image_data_embed(embed_image)
                     name = data.get('name',name)
             else:
                 data = torch.load(path, map_location="cpu")
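This loader prefers the cheap path, a sd-ti-embedding text chunk written at save time, and only falls back to decoding the pixel payload when the chunk is missing (for example after an image host re-encoded the file and stripped metadata). A sketch of that logic in isolation; the import path is an assumption, since this diff shows the new snake_case names but not the module they moved to:

    from PIL import Image
    # Assumed destination module; not shown in this diff.
    from modules.textual_inversion.image_embedding import (embedding_from_b64,
                                                           extract_image_data_embed)

    def load_embedding_from_png(path):
        image = Image.open(path)
        if 'sd-ti-embedding' in image.text:
            # Fast path: the embedding was stored losslessly in a PNG tEXt chunk.
            return embedding_from_b64(image.text['sd-ti-embedding'])
        # Fallback: recover the payload hidden in the pixels' low nibbles.
        return extract_image_data_embed(image)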
@@ -393,7 +280,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
                 info = PngImagePlugin.PngInfo()
                 data = torch.load(last_saved_file)
-                info.add_text("sd-ti-embedding", embeddingToB64(data))
+                info.add_text("sd-ti-embedding", embedding_to_b64(data))
 
                 title = "<{}>".format(data.get('name','???'))
 
                 checkpoint = sd_models.select_checkpoint()
@@ -401,8 +288,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
                 footer_mid = '[{}]'.format(checkpoint.hash)
                 footer_right = '{}'.format(embedding.step)
 
-                captioned_image = captionImageOverlay(image,title,footer_left,footer_mid,footer_right)
-                captioned_image = insertImageDataEmbed(captioned_image,data)
+                captioned_image = caption_image_overlay(image,title,footer_left,footer_mid,footer_right)
+                captioned_image = insert_image_data_embed(captioned_image,data)
 
                 captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
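End to end, the training loop now stores the embedding in the preview image twice: once via embedding_to_b64 in a PNG text chunk, and once via insert_image_data_embed in the pixel data itself, so the image stays loadable even when metadata is stripped. A sketch of the combined save, again with the import path assumed and placeholder footer strings:

    from PIL import PngImagePlugin
    # Assumed destination module, as above.
    from modules.textual_inversion.image_embedding import (embedding_to_b64,
                                                           insert_image_data_embed,
                                                           caption_image_overlay)

    def save_embedding_image(image, data, filename):
        info = PngImagePlugin.PngInfo()
        # Copy 1: lossless text chunk (lost if a host strips PNG metadata).
        info.add_text("sd-ti-embedding", embedding_to_b64(data))
        title = "<{}>".format(data.get('name', '???'))
        # Placeholder footers; train_embedding fills these with checkpoint/step info.
        captioned = caption_image_overlay(image, title, '', '', '')
        # Copy 2: zlib-compressed payload packed into the pixels' low nibbles.
        captioned = insert_image_data_embed(captioned, data)
        captioned.save(filename, "PNG", pnginfo=info)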