mirror of https://github.com/vladmandic/automatic
add sampler api endpoints
Signed-off-by: vladmandic <mandic00@live.com>pull/4619/head
parent
d7ca4f63a7
commit
d9a2a21c8c
|
|
@ -39,6 +39,9 @@ For full list of changes, see full changelog.
|
||||||
- **API**
|
- **API**
|
||||||
- add `/sdapi/v1/xyz-grid` to enumerate xyz-grid axis options and their choices
|
- add `/sdapi/v1/xyz-grid` to enumerate xyz-grid axis options and their choices
|
||||||
see `/cli/api-xyzenum.py` for example usage
|
see `/cli/api-xyzenum.py` for example usage
|
||||||
|
- add `/sdapi/v1/sampler` to get current sampler config
|
||||||
|
- modify `/sdapi/v1/samplers` to enumerate available samplers possible options
|
||||||
|
see `/cli/api-samplers.py` for example usage
|
||||||
- **Internal**
|
- **Internal**
|
||||||
- tagged release history: <https://github.com/vladmandic/sdnext/tags>
|
- tagged release history: <https://github.com/vladmandic/sdnext/tags>
|
||||||
each major for the past year is now tagged for easier reference
|
each major for the past year is now tagged for easier reference
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,35 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
"""
|
||||||
|
get list of all samplers and details of current sampler
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import urllib3
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
url = "http://127.0.0.1:7860"
|
||||||
|
user = ""
|
||||||
|
password = ""
|
||||||
|
|
||||||
|
log_format = '%(asctime)s %(levelname)s: %(message)s'
|
||||||
|
logging.basicConfig(level = logging.INFO, format = log_format)
|
||||||
|
log = logging.getLogger("sd")
|
||||||
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
|
log.info('available samplers')
|
||||||
|
auth = requests.auth.HTTPBasicAuth(user, password) if len(user) > 0 and len(password) > 0 else None
|
||||||
|
req = requests.get(f'{url}/sdapi/v1/samplers', verify=False, auth=auth, timeout=60)
|
||||||
|
if req.status_code != 200:
|
||||||
|
log.error({ 'url': req.url, 'request': req.status_code, 'reason': req.reason })
|
||||||
|
exit(1)
|
||||||
|
res = req.json()
|
||||||
|
for item in res:
|
||||||
|
log.info(item)
|
||||||
|
|
||||||
|
log.info('current sampler')
|
||||||
|
req = requests.get(f'{url}/sdapi/v1/sampler', verify=False, auth=auth, timeout=60)
|
||||||
|
res = req.json()
|
||||||
|
log.info(res)
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import inspect
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Ensure we can import modules
|
# Ensure we can import modules
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||||
|
|
||||||
from modules.errors import log
|
from modules.errors import log
|
||||||
from modules.res4lyf import (
|
from modules.res4lyf import (
|
||||||
|
|
@ -20,12 +19,21 @@ from modules.res4lyf import (
|
||||||
BongTangentScheduler, CommonSigmaScheduler, RadauIIAScheduler,
|
BongTangentScheduler, CommonSigmaScheduler, RadauIIAScheduler,
|
||||||
LangevinDynamicsScheduler
|
LangevinDynamicsScheduler
|
||||||
)
|
)
|
||||||
|
from modules.schedulers.scheduler_vdm import VDMScheduler
|
||||||
|
from modules.schedulers.scheduler_unipc_flowmatch import FlowUniPCMultistepScheduler
|
||||||
|
from modules.schedulers.scheduler_ufogen import UFOGenScheduler
|
||||||
|
from modules.schedulers.scheduler_tdd import TDDScheduler
|
||||||
|
from modules.schedulers.scheduler_tcd import TCDScheduler
|
||||||
|
from modules.schedulers.scheduler_flashflow import FlashFlowMatchEulerDiscreteScheduler
|
||||||
|
from modules.schedulers.scheduler_dpm_flowmatch import FlowMatchDPMSolverMultistepScheduler
|
||||||
|
from modules.schedulers.scheduler_dc import DCSolverMultistepScheduler
|
||||||
|
from modules.schedulers.scheduler_bdia import BDIA_DDIMScheduler
|
||||||
|
|
||||||
def test_scheduler(name, scheduler_class, config):
|
def test_scheduler(name, scheduler_class, config):
|
||||||
try:
|
try:
|
||||||
scheduler = scheduler_class(**config)
|
scheduler = scheduler_class(**config)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} error="Init failed: {e}"')
|
log.error(f'scheduler="{name}" cls={scheduler_class} config={config} error="Init failed: {e}"')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
num_steps = 20
|
num_steps = 20
|
||||||
|
|
@ -42,12 +50,19 @@ def test_scheduler(name, scheduler_class, config):
|
||||||
model_output = torch.randn_like(sample)
|
model_output = torch.randn_like(sample)
|
||||||
|
|
||||||
# Scaling Check
|
# Scaling Check
|
||||||
sigma = scheduler.sigmas[scheduler.step_index] if scheduler.step_index is not None else scheduler.sigmas[0] # Handle potential index mismatch if step_index is updated differently, usually step_index matches i for these tests
|
step_idx = scheduler.step_index if hasattr(scheduler, "step_index") and scheduler.step_index is not None else i
|
||||||
|
# Clamp index
|
||||||
|
if hasattr(scheduler, 'sigmas'):
|
||||||
|
step_idx = min(step_idx, len(scheduler.sigmas) - 1)
|
||||||
|
sigma = scheduler.sigmas[step_idx]
|
||||||
|
else:
|
||||||
|
sigma = torch.tensor(1.0) # Dummy for non-sigma schedulers
|
||||||
|
|
||||||
# Re-introduce scaling calculation first
|
# Re-introduce scaling calculation first
|
||||||
scaled_sample = scheduler.scale_model_input(sample, t)
|
scaled_sample = scheduler.scale_model_input(sample, t)
|
||||||
|
|
||||||
if config.get("prediction_type") == "flow_prediction":
|
if config.get("prediction_type") == "flow_prediction" or name in ["UFOGenScheduler", "TDDScheduler", "TCDScheduler", "BDIA_DDIMScheduler", "DCSolverMultistepScheduler"]:
|
||||||
|
# Some new schedulers don't use K-diffusion scaling
|
||||||
expected_scale = 1.0
|
expected_scale = 1.0
|
||||||
else:
|
else:
|
||||||
expected_scale = 1.0 / ((sigma**2 + 1) ** 0.5)
|
expected_scale = 1.0 / ((sigma**2 + 1) ** 0.5)
|
||||||
|
|
@ -55,6 +70,10 @@ def test_scheduler(name, scheduler_class, config):
|
||||||
# Simple check with loose tolerance due to float precision
|
# Simple check with loose tolerance due to float precision
|
||||||
expected_scaled_sample = sample * expected_scale
|
expected_scaled_sample = sample * expected_scale
|
||||||
if not torch.allclose(scaled_sample, expected_scaled_sample, atol=1e-4):
|
if not torch.allclose(scaled_sample, expected_scaled_sample, atol=1e-4):
|
||||||
|
# If failed, double check if it's just 'sample' (no scaling)
|
||||||
|
if torch.allclose(scaled_sample, sample, atol=1e-4):
|
||||||
|
messages.append('warning="scaling is identity"')
|
||||||
|
else:
|
||||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} expected={expected_scale} error="scaling mismatch"')
|
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} step={i} expected={expected_scale} error="scaling mismatch"')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -121,9 +140,6 @@ def test_scheduler(name, scheduler_class, config):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
final_std = sample.std().item()
|
final_std = sample.std().item()
|
||||||
with open("std_log.txt", "a") as f:
|
|
||||||
f.write(f"STD_LOG: {name} config={config} std={final_std}\n")
|
|
||||||
|
|
||||||
if final_std > 50.0 or final_std < 0.1:
|
if final_std > 50.0 or final_std < 0.1:
|
||||||
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} std={final_std} error="variance drift"')
|
log.error(f'scheduler="{name}" cls={scheduler.__class__.__name__} config={config} std={final_std} error="variance drift"')
|
||||||
|
|
||||||
|
|
@ -212,6 +228,7 @@ def run_tests():
|
||||||
# Extra robustness check: Flow Prediction Type
|
# Extra robustness check: Flow Prediction Type
|
||||||
log.warning('type="flow"')
|
log.warning('type="flow"')
|
||||||
flow_schedulers = [
|
flow_schedulers = [
|
||||||
|
# res4lyf schedulers
|
||||||
RESUnifiedScheduler, RESMultistepScheduler, ABNorsettScheduler,
|
RESUnifiedScheduler, RESMultistepScheduler, ABNorsettScheduler,
|
||||||
RESSinglestepScheduler, RESSinglestepSDEScheduler, RESDEISMultistepScheduler,
|
RESSinglestepScheduler, RESSinglestepSDEScheduler, RESDEISMultistepScheduler,
|
||||||
RESMultistepSDEScheduler, ETDRKScheduler, LawsonScheduler, PECScheduler,
|
RESMultistepSDEScheduler, ETDRKScheduler, LawsonScheduler, PECScheduler,
|
||||||
|
|
@ -219,10 +236,27 @@ def run_tests():
|
||||||
GaussLegendreScheduler, RungeKutta44Scheduler, RungeKutta57Scheduler,
|
GaussLegendreScheduler, RungeKutta44Scheduler, RungeKutta57Scheduler,
|
||||||
RungeKutta67Scheduler, SpecializedRKScheduler, BongTangentScheduler,
|
RungeKutta67Scheduler, SpecializedRKScheduler, BongTangentScheduler,
|
||||||
CommonSigmaScheduler, RadauIIAScheduler, LangevinDynamicsScheduler,
|
CommonSigmaScheduler, RadauIIAScheduler, LangevinDynamicsScheduler,
|
||||||
RiemannianFlowScheduler
|
RiemannianFlowScheduler,
|
||||||
|
# sdnext schedulers
|
||||||
|
FlowUniPCMultistepScheduler, FlashFlowMatchEulerDiscreteScheduler, FlowMatchDPMSolverMultistepScheduler,
|
||||||
]
|
]
|
||||||
for cls in flow_schedulers:
|
for cls in flow_schedulers:
|
||||||
test_scheduler(cls.__name__, cls, {"prediction_type": "flow_prediction", "use_flow_sigmas": True})
|
test_scheduler(cls.__name__, cls, {"prediction_type": "flow_prediction", "use_flow_sigmas": True})
|
||||||
|
|
||||||
|
log.warning('type="sdnext"')
|
||||||
|
extended_schedulers = [
|
||||||
|
VDMScheduler,
|
||||||
|
UFOGenScheduler,
|
||||||
|
TDDScheduler,
|
||||||
|
TCDScheduler,
|
||||||
|
DCSolverMultistepScheduler,
|
||||||
|
BDIA_DDIMScheduler
|
||||||
|
]
|
||||||
|
for cls in extended_schedulers:
|
||||||
|
# Most of these support standard prediction types, try epsilon as default safest bet
|
||||||
|
# Some might be flow matching specific, we can try robust default list
|
||||||
|
# For now, just test default init
|
||||||
|
test_scheduler(cls.__name__, cls, {"prediction_type": "epsilon"})
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|
@ -103,6 +103,7 @@ class Api:
|
||||||
self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=List[str])
|
self.add_api_route("/sdapi/v1/latents", endpoints.get_latent_history, methods=["GET"], response_model=List[str])
|
||||||
self.add_api_route("/sdapi/v1/latents", endpoints.post_latent_history, methods=["POST"], response_model=int)
|
self.add_api_route("/sdapi/v1/latents", endpoints.post_latent_history, methods=["POST"], response_model=int)
|
||||||
self.add_api_route("/sdapi/v1/modules", endpoints.get_modules, methods=["GET"])
|
self.add_api_route("/sdapi/v1/modules", endpoints.get_modules, methods=["GET"])
|
||||||
|
self.add_api_route("/sdapi/v1/sampler", endpoints.get_sampler, methods=["GET"], response_model=dict)
|
||||||
|
|
||||||
# lora api
|
# lora api
|
||||||
from modules.api import loras
|
from modules.api import loras
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,28 @@ from modules.api import models, helpers
|
||||||
|
|
||||||
|
|
||||||
def get_samplers():
|
def get_samplers():
|
||||||
from modules import sd_samplers
|
from modules import sd_samplers_diffusers
|
||||||
return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
|
all_samplers = []
|
||||||
|
for k, v in sd_samplers_diffusers.config.items():
|
||||||
|
if k in ['All', 'Default', 'Res4Lyf']:
|
||||||
|
continue
|
||||||
|
all_samplers.append({
|
||||||
|
'name': k,
|
||||||
|
'options': v,
|
||||||
|
})
|
||||||
|
return all_samplers
|
||||||
|
|
||||||
|
def get_sampler():
|
||||||
|
if not shared.sd_loaded or shared.sd_model is None:
|
||||||
|
return {}
|
||||||
|
if hasattr(shared.sd_model, 'scheduler'):
|
||||||
|
scheduler = shared.sd_model.scheduler
|
||||||
|
config = {k: v for k, v in scheduler.config.items() if not k.startswith('_')}
|
||||||
|
return {
|
||||||
|
'name': scheduler.__class__.__name__,
|
||||||
|
'options': config
|
||||||
|
}
|
||||||
|
return {}
|
||||||
|
|
||||||
def get_sd_vaes():
|
def get_sd_vaes():
|
||||||
from modules.sd_vae import vae_dict
|
from modules.sd_vae import vae_dict
|
||||||
|
|
@ -75,6 +95,13 @@ def get_interrogate():
|
||||||
from modules.interrogate.openclip import refresh_clip_models
|
from modules.interrogate.openclip import refresh_clip_models
|
||||||
return ['deepdanbooru'] + refresh_clip_models()
|
return ['deepdanbooru'] + refresh_clip_models()
|
||||||
|
|
||||||
|
def get_schedulers():
|
||||||
|
from modules.sd_samplers import list_samplers
|
||||||
|
all_schedulers = list_samplers()
|
||||||
|
for s in all_schedulers:
|
||||||
|
shared.log.critical(s)
|
||||||
|
return all_schedulers
|
||||||
|
|
||||||
def post_interrogate(req: models.ReqInterrogate):
|
def post_interrogate(req: models.ReqInterrogate):
|
||||||
if req.image is None or len(req.image) < 64:
|
if req.image is None or len(req.image) < 64:
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
|
|
|
||||||
|
|
@ -86,8 +86,7 @@ class PydanticModelGenerator:
|
||||||
|
|
||||||
class ItemSampler(BaseModel):
|
class ItemSampler(BaseModel):
|
||||||
name: str = Field(title="Name")
|
name: str = Field(title="Name")
|
||||||
aliases: List[str] = Field(title="Aliases")
|
options: dict
|
||||||
options: Dict[str, str] = Field(title="Options")
|
|
||||||
|
|
||||||
class ItemVae(BaseModel):
|
class ItemVae(BaseModel):
|
||||||
model_name: str = Field(title="Model Name")
|
model_name: str = Field(title="Model Name")
|
||||||
|
|
@ -199,6 +198,11 @@ class ItemExtension(BaseModel):
|
||||||
commit_date: Union[str, int] = Field(title="Commit Date", description="Extension Repository Commit Date")
|
commit_date: Union[str, int] = Field(title="Commit Date", description="Extension Repository Commit Date")
|
||||||
enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
|
enabled: bool = Field(title="Enabled", description="Flag specifying whether this extension is enabled")
|
||||||
|
|
||||||
|
class ItemScheduler(BaseModel):
|
||||||
|
name: str = Field(title="Name", description="Scheduler name")
|
||||||
|
cls: str = Field(title="Class", description="Scheduler class name")
|
||||||
|
options: Dict[str, Any] = Field(title="Options", description="Dictionary of scheduler options")
|
||||||
|
|
||||||
### request/response classes
|
### request/response classes
|
||||||
|
|
||||||
ReqTxt2Img = PydanticModelGenerator(
|
ReqTxt2Img = PydanticModelGenerator(
|
||||||
|
|
|
||||||
|
|
@ -155,6 +155,8 @@ class FlowMatchDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||||
algorithm_type: str = "dpmsolver++2M",
|
algorithm_type: str = "dpmsolver++2M",
|
||||||
solver_type: str = "midpoint",
|
solver_type: str = "midpoint",
|
||||||
sigma_schedule: Optional[str] = None,
|
sigma_schedule: Optional[str] = None,
|
||||||
|
prediction_type: str = "flow_prediction",
|
||||||
|
use_flow_sigmas: bool = True,
|
||||||
shift: float = 3.0,
|
shift: float = 3.0,
|
||||||
midpoint_ratio: Optional[float] = 0.5,
|
midpoint_ratio: Optional[float] = 0.5,
|
||||||
s_noise: Optional[float] = 1.0,
|
s_noise: Optional[float] = 1.0,
|
||||||
|
|
|
||||||
|
|
@ -69,6 +69,8 @@ class FlashFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||||
num_train_timesteps: int = 1000,
|
num_train_timesteps: int = 1000,
|
||||||
shift: float = 1.0,
|
shift: float = 1.0,
|
||||||
use_dynamic_shifting=False,
|
use_dynamic_shifting=False,
|
||||||
|
prediction_type: str = "flow_prediction",
|
||||||
|
use_flow_sigmas: bool = True,
|
||||||
base_shift: Optional[float] = 0.5,
|
base_shift: Optional[float] = 0.5,
|
||||||
max_shift: Optional[float] = 1.15,
|
max_shift: Optional[float] = 1.15,
|
||||||
base_image_seq_len: Optional[int] = 256,
|
base_image_seq_len: Optional[int] = 256,
|
||||||
|
|
@ -261,6 +263,22 @@ class FlashFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||||
else:
|
else:
|
||||||
self._step_index = self._begin_index
|
self._step_index = self._begin_index
|
||||||
|
|
||||||
|
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
||||||
|
"""
|
||||||
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||||
|
current timestep.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`):
|
||||||
|
The input sample.
|
||||||
|
timestep (`int`, *optional*):
|
||||||
|
The current timestep in the diffusion chain.
|
||||||
|
Returns:
|
||||||
|
`torch.FloatTensor`:
|
||||||
|
A scaled input sample.
|
||||||
|
"""
|
||||||
|
return sample
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self,
|
self,
|
||||||
model_output: torch.FloatTensor,
|
model_output: torch.FloatTensor,
|
||||||
|
|
|
||||||
|
|
@ -497,7 +497,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
|
||||||
model_output: torch.FloatTensor,
|
model_output: torch.FloatTensor,
|
||||||
timestep: int,
|
timestep: int,
|
||||||
sample: torch.FloatTensor,
|
sample: torch.FloatTensor,
|
||||||
eta: float,
|
eta: float = 0.0,
|
||||||
generator: Optional[torch.Generator] = None,
|
generator: Optional[torch.Generator] = None,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[TCDSchedulerOutput, Tuple]:
|
) -> Union[TCDSchedulerOutput, Tuple]:
|
||||||
|
|
|
||||||
|
|
@ -224,7 +224,7 @@ class TDDScheduler(DPMSolverSinglestepScheduler):
|
||||||
model_output: torch.FloatTensor,
|
model_output: torch.FloatTensor,
|
||||||
timestep: int,
|
timestep: int,
|
||||||
sample: torch.FloatTensor,
|
sample: torch.FloatTensor,
|
||||||
eta: float,
|
eta: float = 0.0,
|
||||||
generator: Optional[torch.Generator] = None,
|
generator: Optional[torch.Generator] = None,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> Union[SchedulerOutput, Tuple]:
|
) -> Union[SchedulerOutput, Tuple]:
|
||||||
|
|
|
||||||
|
|
@ -86,6 +86,7 @@ class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||||
lower_order_final: bool = True,
|
lower_order_final: bool = True,
|
||||||
disable_corrector: List[int] = [],
|
disable_corrector: List[int] = [],
|
||||||
solver_p: SchedulerMixin = None,
|
solver_p: SchedulerMixin = None,
|
||||||
|
use_flow_sigmas: bool = True,
|
||||||
timestep_spacing: str = "linspace",
|
timestep_spacing: str = "linspace",
|
||||||
steps_offset: int = 0,
|
steps_offset: int = 0,
|
||||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||||
|
|
|
||||||
|
|
@ -141,7 +141,7 @@ class VDMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
# For linear beta schedule, equivalent to torch.exp(-1e-4 - 10 * t ** 2)
|
# For linear beta schedule, equivalent to torch.exp(-1e-4 - 10 * t ** 2)
|
||||||
self.alphas_cumprod = lambda t: torch.sigmoid(self.log_snr(t)) # Equivalent to 1 - self.sigmas
|
self.alphas_cumprod = lambda t: torch.sigmoid(self.log_snr(t)) # Equivalent to 1 - self.sigmas
|
||||||
self.sigmas = lambda t: torch.sigmoid(-self.log_snr(t)) # Equivalent to 1 - self.alphas_cumprod
|
self.sigmas = []
|
||||||
|
|
||||||
self.num_inference_steps = None
|
self.num_inference_steps = None
|
||||||
self.timesteps = torch.from_numpy(self.get_timesteps(len(self)))
|
self.timesteps = torch.from_numpy(self.get_timesteps(len(self)))
|
||||||
|
|
@ -240,6 +240,8 @@ class VDMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
timesteps += self.config.steps_offset
|
timesteps += self.config.steps_offset
|
||||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||||
|
self.sigmas = [torch.sigmoid(-self.log_snr(t)) for t in self.timesteps]
|
||||||
|
self.sigmas = torch.stack(self.sigmas)
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||||
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,7 @@ def list_samplers():
|
||||||
samplers = all_samplers
|
samplers = all_samplers
|
||||||
samplers_for_img2img = all_samplers
|
samplers_for_img2img = all_samplers
|
||||||
samplers_map = {}
|
samplers_map = {}
|
||||||
|
return all_samplers
|
||||||
# shared.log.debug(f'Available samplers: {[x.name for x in all_samplers]}')
|
# shared.log.debug(f'Available samplers: {[x.name for x in all_samplers]}')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
2
wiki
2
wiki
|
|
@ -1 +1 @@
|
||||||
Subproject commit da7620df144de8d2af259eff2b7a4522783f38cc
|
Subproject commit 850c155e238f369dd135d79e138470d7822ad5b6
|
||||||
Loading…
Reference in New Issue