diff --git a/extract_controlnet.py b/extract_controlnet.py index d8752b5..52e21c4 100644 --- a/extract_controlnet.py +++ b/extract_controlnet.py @@ -6,6 +6,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--src", default=None, type=str, required=True, help="Path to the model to convert.") parser.add_argument("--dst", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--half", action="store_true", help="Cast to FP16.") args = parser.parse_args() assert args.src is not None, "Must provide a model path!" @@ -15,9 +16,10 @@ if __name__ == "__main__": state_dict = load_file(args.src) else: state_dict = torch.load(args.src) - + if any([k.startswith("control_model.") for k, v in state_dict.items()]): - state_dict = {k.replace("control_model.", ""): v for k, v in state_dict.items() if k.startswith("control_model.")} + dtype = torch.float16 if args.half else torch.float32 + state_dict = {k.replace("control_model.", ""): v.to(dtype) for k, v in state_dict.items() if k.startswith("control_model.")} if args.dst.endswith(".safetensors"): save_file(state_dict, args.dst)