pull/2585/head
AI-Casanova 2023-12-03 17:20:44 -06:00
parent 323e2c142c
commit b59f79657d
2 changed files with 4 additions and 4 deletions

View File

@ -36,7 +36,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0 te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
te_multiplier = float(params.named.get("te", te_multiplier)) te_multiplier = float(params.named.get("te", te_multiplier))
unet_multiplier = [float(params.positional[2]) if len(params.positional) > 2 else te_multiplier] * 3 unet_multiplier = [float(params.positional[2]) if len(params.positional) > 2 else te_multiplier] * 3
unet_multiplier = [float(params.named.get("unet", unet_multiplier))] * 3 unet_multiplier = [float(params.named.get("unet", unet_multiplier[0]))] * 3
unet_multiplier[0] = float(params.named.get("in", unet_multiplier[0])) unet_multiplier[0] = float(params.named.get("in", unet_multiplier[0]))
unet_multiplier[1] = float(params.named.get("mid", unet_multiplier[1])) unet_multiplier[1] = float(params.named.get("mid", unet_multiplier[1]))
unet_multiplier[2] = float(params.named.get("out", unet_multiplier[2])) unet_multiplier[2] = float(params.named.get("out", unet_multiplier[2]))

View File

@ -112,11 +112,11 @@ class NetworkModule:
def multiplier(self): def multiplier(self):
if 'transformer' in self.sd_key[:20]: if 'transformer' in self.sd_key[:20]:
return self.network.te_multiplier return self.network.te_multiplier
if "input_blocks" in self.sd_key: if "down_blocks" in self.sd_key:
return self.network.unet_multiplier[0] return self.network.unet_multiplier[0]
if "middle_block" in self.sd_key: if "mid_block" in self.sd_key:
return self.network.unet_multiplier[1] return self.network.unet_multiplier[1]
if "output_blocks" in self.sd_key: if "up_blocks" in self.sd_key:
return self.network.unet_multiplier[2] return self.network.unet_multiplier[2]
else: else:
return self.network.unet_multiplier[0] return self.network.unet_multiplier[0]