Merge pull request #3970 from evshiron/fix/progress-api

fix broken progress api and current image compatibility
This commit is contained in:
AUTOMATIC1111 2022-11-02 12:06:12 +03:00 committed by GitHub
commit e526f6b378
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 2 deletions

@ -5,10 +5,9 @@ import uvicorn
from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared import modules.shared as shared
from modules import devices
from modules.api.models import * from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers from modules.sd_samplers import all_samplers, sample_to_image, samples_to_image_grid
from modules.extras import run_extras, run_pnginfo from modules.extras import run_extras, run_pnginfo
@ -179,6 +178,16 @@ class Api:
progress = min(progress, 1) progress = min(progress, 1)
# copy from check_progress_call of ui.py
if shared.parallel_processing_allowed:
if shared.state.sampling_step - shared.state.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.state.current_latent is not None:
if shared.opts.show_progress_grid:
shared.state.current_image = samples_to_image_grid(shared.state.current_latent)
else:
shared.state.current_image = sample_to_image(shared.state.current_latent)
shared.state.current_image_sampling_step = shared.state.sampling_step
current_image = None current_image = None
if shared.state.current_image and not req.skip_current_image: if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image) current_image = encode_pil_to_base64(shared.state.current_image)

@ -4,6 +4,7 @@ import json
import os import os
import sys import sys
from collections import OrderedDict from collections import OrderedDict
import time
import gradio as gr import gradio as gr
import tqdm import tqdm
@ -135,6 +136,7 @@ class State:
current_image = None current_image = None
current_image_sampling_step = 0 current_image_sampling_step = 0
textinfo = None textinfo = None
time_start = None
need_restart = False need_restart = False
def skip(self): def skip(self):
@ -172,6 +174,7 @@ class State:
self.skipped = False self.skipped = False
self.interrupted = False self.interrupted = False
self.textinfo = None self.textinfo = None
self.time_start = time.time()
devices.torch_gc() devices.torch_gc()