diff --git a/modules/merging/merge.py b/modules/merging/merge.py index fe0097da9..ae68fba15 100644 --- a/modules/merging/merge.py +++ b/modules/merging/merge.py @@ -126,7 +126,7 @@ def load_thetas( def merge_models( models: Dict[str, os.PathLike | str], merge_mode: str, - precision: str = "full", + precision: str = "fp16", weights_clip: bool = False, re_basin: bool = False, device: str = "cpu", @@ -242,7 +242,7 @@ def simple_merge( continue if "model" in key and key not in thetas["model_a"].keys(): thetas["model_a"].update({key: thetas["model_b"][key]}) - if precision == 16: + if precision == "fp16": thetas["model_a"].update({key: thetas["model_a"][key].half()}) log_vram("after stage 2") @@ -291,7 +291,7 @@ def rebasin_merge( thetas["model_a"], max_iter=it, init_perm=None, - usefp16=precision == 16, + usefp16=precision == "fp16", device=device, ) @@ -307,7 +307,7 @@ def rebasin_merge( thetas["model_a"], max_iter=it, init_perm=None, - usefp16=precision == 16, + usefp16=precision == "fp16", device=device, ) @@ -343,7 +343,7 @@ def merge_key( thetas: Dict, weight_matcher: WeightClass, merge_mode: str, - precision: int = 16, + precision: str = "fp16", weights_clip: bool = False, device: str = "cpu", work_device: Optional[str] = None, @@ -378,7 +378,7 @@ def merge_key( if weights_clip: merged_key = clip_weights_key(thetas, merged_key, key) - if precision == 16: + if precision == "fp16": merged_key = merged_key.half() return merged_key