import torch
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image

import modules.shared as shared

safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None


def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]

    return pil_images


# check and replace nsfw content
def check_safety(x_image):
    global safety_feature_extractor, safety_checker

    # lazily load the feature extractor and safety checker on first use
    if safety_feature_extractor is None:
        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)

    return x_checked_image, has_nsfw_concept


def censor_batch(x):
    # (batch, channels, height, width) tensor -> (batch, height, width, channels) numpy array
    x_samples_ddim_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
    x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim_numpy)
    x = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)

    return x
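

# Usage sketch (illustrative addition, not part of the original module): a quick
# way to exercise censor_batch outside a sampling loop. It assumes the webui
# environment is importable (so `modules.shared` resolves) and that the
# "CompVis/stable-diffusion-safety-checker" weights can be downloaded on first use.
if __name__ == "__main__":
    # Stand-in for decoded samples: a (batch, channels, height, width) float
    # tensor with values in [0, 1], as a sampler would produce after decoding.
    demo_batch = torch.rand(2, 3, 512, 512)

    censored = censor_batch(demo_batch)

    # Images the checker flags come back blacked out; shape and dtype are unchanged.
    print(censored.shape)  # torch.Size([2, 3, 512, 512])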