fixed lora

pull/704/head
Jingyi 2024-04-23 19:14:56 +08:00
parent 9247eaa942
commit 25b807eaec
2 changed files with 17 additions and 3 deletions

View File

@ -1,4 +1,6 @@
import json
import logging
import re
import requests
@ -13,16 +15,28 @@ logger.setLevel(utils.LOGGING_LEVEL)
class SimpleSagemakerInfer(InferManager):
def parse_lora(self, json_string: str, models):
prompt = json.loads(json_string)['prompt']
matches = re.findall(r"<lora:([^:>]+)", prompt)
lora_list = []
for match in matches:
lora_list.append(f"{match}.safetensors")
models['Lora'] = lora_list
return models
def run(self, userid, models, sd_param, is_txt2img, endpoint_type):
# finished construct api payload
sd_api_param_json = _parse_api_param_to_json(api_param=sd_param)
models = self.parse_lora(sd_api_param_json, models)
if logging.getLogger().getEffectiveLevel() == logging.DEBUG:
# debug only, may delete later
with open(f'api_{"txt2img" if is_txt2img else "img2img"}_param.json', 'w') as f:
f.write(sd_api_param_json)
print(sd_api_param_json)
# create an inference and upload to s3
# Start creating model on cloud.
url = get_variable_from_json('api_gateway_url')

View File

@ -1408,7 +1408,7 @@ def trainings_tab():
default_config = {
"output_name": "model_name",
"save_every_n_epochs": 1,
"max_train_epochs": 10,
"max_train_epochs": 5,
}
config_params = gr.TextArea(value=json.dumps(default_config, indent=4),
label="config_params",