fixed lora
parent
9247eaa942
commit
25b807eaec
|
|
@ -1,4 +1,6 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
import requests
|
||||
|
||||
|
|
@ -13,16 +15,28 @@ logger.setLevel(utils.LOGGING_LEVEL)
|
|||
|
||||
class SimpleSagemakerInfer(InferManager):
|
||||
|
||||
def parse_lora(self, json_string: str, models):
|
||||
|
||||
prompt = json.loads(json_string)['prompt']
|
||||
matches = re.findall(r"<lora:([^:>]+)", prompt)
|
||||
lora_list = []
|
||||
for match in matches:
|
||||
lora_list.append(f"{match}.safetensors")
|
||||
|
||||
models['Lora'] = lora_list
|
||||
|
||||
return models
|
||||
|
||||
def run(self, userid, models, sd_param, is_txt2img, endpoint_type):
|
||||
# finished construct api payload
|
||||
sd_api_param_json = _parse_api_param_to_json(api_param=sd_param)
|
||||
models = self.parse_lora(sd_api_param_json, models)
|
||||
|
||||
if logging.getLogger().getEffectiveLevel() == logging.DEBUG:
|
||||
# debug only, may delete later
|
||||
with open(f'api_{"txt2img" if is_txt2img else "img2img"}_param.json', 'w') as f:
|
||||
f.write(sd_api_param_json)
|
||||
|
||||
print(sd_api_param_json)
|
||||
|
||||
# create an inference and upload to s3
|
||||
# Start creating model on cloud.
|
||||
url = get_variable_from_json('api_gateway_url')
|
||||
|
|
|
|||
|
|
@ -1408,7 +1408,7 @@ def trainings_tab():
|
|||
default_config = {
|
||||
"output_name": "model_name",
|
||||
"save_every_n_epochs": 1,
|
||||
"max_train_epochs": 10,
|
||||
"max_train_epochs": 5,
|
||||
}
|
||||
config_params = gr.TextArea(value=json.dumps(default_config, indent=4),
|
||||
label="config_params",
|
||||
|
|
|
|||
Loading…
Reference in New Issue