Improve FakeModel

main
natanjunges 2023-01-12 16:15:11 -03:00
parent 99e4fb87c2
commit 97b74e13fd
1 changed files with 10 additions and 5 deletions

View File

@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from modules import scripts, processing, shared, images, devices, ui
from modules import scripts, processing, shared, images, devices, ui, sd_models
import gradio
import requests
import time
@ -32,6 +32,9 @@ settings_file = os.path.join(scripts.basedir(), "settings.json")
class FakeModel:
sd_model_hash=""
def __init__(self, name):
self.sd_checkpoint_info = sd_models.CheckpointInfo("", "", "", name)
class StableHordeError(Exception):
pass
@ -210,6 +213,8 @@ class Main(scripts.Script):
def process_images_inner(self, p, model, nsfw, shared_laion, seed_variation, post_processing):
# Copyright (C) 2022 AUTOMATIC1111
fake_model = FakeModel(model)
if type(p.prompt) == list:
assert(len(p.prompt) > 0)
else:
@ -241,14 +246,14 @@ class Main(scripts.Script):
def infotext(iteration=0, position_in_batch=0):
old_model = shared.sd_model
shared.sd_model = FakeModel
shared.sd_model = fake_model
ret = processing.create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, {}, iteration, position_in_batch)
shared.sd_model = old_model
return ret
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
old_model = shared.sd_model
shared.sd_model = FakeModel
shared.sd_model = fake_model
processed = processing.Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
shared.sd_model = old_model
@ -349,7 +354,7 @@ class Main(scripts.Script):
devices.torch_gc()
old_model = shared.sd_model
shared.sd_model = FakeModel
shared.sd_model = fake_model
res = processing.Processed(p, output_images, p.all_seeds[0], infotext(), subseed=-1, index_of_first_image=index_of_first_image, infotexts=infotexts)
shared.sd_model = old_model
@ -447,7 +452,7 @@ class Main(scripts.Script):
time.sleep(1)
except requests.Timeout:
time.sleep(1)
except AssertionError as e:
except AssertionError:
id = id.json()
raise StableHordeError(id["message"])