fix precision

pull/2443/head
AI-Casanova 2023-11-16 19:31:36 -06:00
parent 26f6e717f6
commit 6c34317a57
1 changed files with 6 additions and 6 deletions

View File

@ -126,7 +126,7 @@ def load_thetas(
def merge_models(
models: Dict[str, os.PathLike | str],
merge_mode: str,
precision: str = "full",
precision: str = "fp16",
weights_clip: bool = False,
re_basin: bool = False,
device: str = "cpu",
@ -242,7 +242,7 @@ def simple_merge(
continue
if "model" in key and key not in thetas["model_a"].keys():
thetas["model_a"].update({key: thetas["model_b"][key]})
if precision == 16:
if precision == "fp16":
thetas["model_a"].update({key: thetas["model_a"][key].half()})
log_vram("after stage 2")
@ -291,7 +291,7 @@ def rebasin_merge(
thetas["model_a"],
max_iter=it,
init_perm=None,
usefp16=precision == 16,
usefp16=precision == "fp16",
device=device,
)
@ -307,7 +307,7 @@ def rebasin_merge(
thetas["model_a"],
max_iter=it,
init_perm=None,
usefp16=precision == 16,
usefp16=precision == "fp16",
device=device,
)
@ -343,7 +343,7 @@ def merge_key(
thetas: Dict,
weight_matcher: WeightClass,
merge_mode: str,
precision: int = 16,
precision: str = "fp16",
weights_clip: bool = False,
device: str = "cpu",
work_device: Optional[str] = None,
@ -378,7 +378,7 @@ def merge_key(
if weights_clip:
merged_key = clip_weights_key(thetas, merged_key, key)
if precision == 16:
if precision == "fp16":
merged_key = merged_key.half()
return merged_key