mirror of https://github.com/vladmandic/automatic
# Copyright 2025 The RES4LYF Team (Clybius) and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
from mpmath import exp as mp_exp
from mpmath import factorial as mp_factorial
from mpmath import mp, mpf

# Set precision for mpmath (number of significant decimal digits)
mp.dps = 80


def calculate_gamma(c2: float, c3: float) -> float:
    """Calculates the gamma parameter for RES 3s samplers."""
    return (3 * (c3**3) - 2 * c3) / (c2 * (2 - 3 * c2))
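

# Illustrative check (not part of the original module): with sample nodes
# c2 = 0.5 and c3 = 1.0, the formula reduces to
#   (3 * 1**3 - 2 * 1) / (0.5 * (2 - 3 * 0.5)) = 1 / 0.25 = 4.0
# so calculate_gamma(0.5, 1.0) == 4.0.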


def _torch_factorial(n: int) -> float:
    """Return n! as a float, suitable for torch tensor arithmetic."""
    return float(math.factorial(n))


def phi_standard_torch(j: int, neg_h: torch.Tensor) -> torch.Tensor:
    r"""
    Standard implementation of the phi functions using torch.
    \phi_j(-h) = (e^{-h} - \sum_{k=0}^{j-1} (-h)^k / k!) / (-h)^j
    For h=0, \phi_j(0) = 1/j!
    """
    assert j > 0

    # Handle the h=0 case
    if torch.all(neg_h == 0):
        return torch.full_like(neg_h, 1.0 / _torch_factorial(j))

    # We use double precision for the series to avoid early overflow/precision loss
    orig_dtype = neg_h.dtype
    neg_h = neg_h.to(torch.float64)

    # For very small h, use the series expansion to avoid 0/0 in the closed form
    # (applied to the whole tensor if any element is small; typically neg_h
    # carries a single step size)
    if torch.any(torch.abs(neg_h) < 1e-4):
        # phi_j(z) = 1/j! + z/(j+1)! + z^2/(j+2)! + ...
        result = torch.full_like(neg_h, 1.0 / _torch_factorial(j))
        term = torch.full_like(neg_h, 1.0 / _torch_factorial(j))
        for k in range(1, 5):
            term = term * neg_h / (j + k)
            result += term
        return result.to(orig_dtype)

    # Closed form: strip the first j Taylor terms of e^{-h}, divide by (-h)^j
    remainder = torch.zeros_like(neg_h)
    for k in range(j):
        remainder += (neg_h**k) / _torch_factorial(k)

    phi_val = (torch.exp(neg_h) - remainder) / (neg_h**j)
    return phi_val.to(orig_dtype)
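

# Illustrative check (not part of the original module): for j = 1 the closed
# form above reduces to the familiar phi_1(z) = (exp(z) - 1) / z, e.g.
#
#     z = torch.tensor(-0.5, dtype=torch.float64)
#     torch.testing.assert_close(phi_standard_torch(1, z), (torch.exp(z) - 1) / z)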


def phi_mpmath_series(j: int, neg_h: float) -> float:
    """Arbitrary-precision phi_j(-h) via the series definition."""
    j = int(j)
    z = mpf(float(neg_h))

    # Handle the h=0 case: phi_j(0) = 1/j!
    if z == 0:
        return float(1.0 / mp_factorial(j))

    # Evaluate the closed form at mp.dps digits; the extra precision absorbs
    # the catastrophic cancellation that hits float64 for small |z|
    s_val = mp.mpf("0")
    for k in range(j):
        s_val += (z**k) / mp_factorial(k)
    phi_val = (mp_exp(z) - s_val) / (z**j)
    return float(phi_val)
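

# Illustrative check (not part of the original module): the mpmath path keeps
# its accuracy where float64 starts to cancel, e.g. for a tiny step
#
#     phi_mpmath_series(1, -1e-8)   # ~0.999999995, i.e. 1 - 1e-8/2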


class Phi:
    """
    Class to manage phi function calculations and caching.
    Supports both standard torch-based and high-precision mpmath-based solutions.
    """

    def __init__(self, h: torch.Tensor, c: list[float | mpf], analytic_solution: bool = True):
        self.h = h
        self.c = c
        self.cache: dict[tuple[int, int], float | torch.Tensor] = {}
        self.analytic_solution = analytic_solution

        if analytic_solution:
            self.phi_f = phi_mpmath_series
            self.h_mpf = mpf(float(h))
            self.c_mpf = [mpf(float(c_val)) for c_val in c]
        else:
            self.phi_f = phi_standard_torch

    def __call__(self, j: int, i: int = -1) -> float | torch.Tensor:
        if (j, i) in self.cache:
            return self.cache[(j, i)]

        # i selects the stage node c_i (1-indexed); i < 0 means the full step (c = 1)
        if i < 0:
            c_val = 1.0
        else:
            c_val = self.c[i - 1]
            if c_val == 0:
                self.cache[(j, i)] = 0.0
                return 0.0

        if self.analytic_solution:
            h_val = self.h_mpf
            c_mapped = self.c_mpf[i - 1] if i >= 0 else 1.0

            if j == 0:
                # phi_0(z) = e^z
                result = float(mp_exp(-h_val * c_mapped))
            else:
                # Evaluate the definition directly in mpmath for higher precision
                z = -h_val * c_mapped
                if z == 0:
                    result = float(1.0 / mp_factorial(j))
                else:
                    s_val = mp.mpf("0")
                    for k in range(j):
                        s_val += (z**k) / mp_factorial(k)
                    result = float((mp_exp(z) - s_val) / (z**j))
        else:
            h_val = self.h
            c_mapped = float(c_val)

            if j == 0:
                # phi_0(z) = e^z
                result = torch.exp(-h_val * c_mapped)
            else:
                result = self.phi_f(j, -h_val * c_mapped)

        self.cache[(j, i)] = result
        return result
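

# Minimal usage sketch (not part of the original module; values are
# illustrative). A RES-style sampler builds one Phi per step from the step
# size h and the stage nodes c, then queries phi_j at each node:
#
#     h = torch.tensor(0.3)
#     phi = Phi(h, c=[0.5, 1.0], analytic_solution=True)
#     phi(1, 1)  # phi_1(-h * c[0]) as a float, with c[0] = 0.5
#     phi(2, 2)  # phi_2(-h * c[1]), cached on repeat calls
#     phi(1)     # i = -1 selects the full step, phi_1(-h)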