# automatic/scripts/lbm/embedders/base/base_conditioner.py
#
# (source viewer metadata: 59 lines, 1.7 KiB, Python)

from typing import Any, Dict
from ...base.base_model import BaseModel
from .base_conditioner_config import BaseConditionerConfig
# Lookup table from a dimension count to a conditioning output key.
# NOTE(review): the keys presumably are the ndim of an embedding tensor
# (2 -> "vector", 3 -> "crossattn", 4 -> "concat"); confirm against the
# concrete conditioners that read ``self.dim2outputkey``.
DIM2CONDITIONING = dict(
    [
        (2, "vector"),
        (3, "crossattn"),
        (4, "concat"),
    ]
)
class BaseConditioner(BaseModel):
    """Base class for all conditioners; it abstracts the conditioning process.

    Subclasses must implement :meth:`forward`.

    Args:
        config (BaseConditionerConfig): The configuration of the conditioner.

    Examples
    ########

    To use the conditioner, you can import the class and use it as follows:

    .. code-block:: python

        # NOTE(review): ``cr.models.embedders`` looks like a stale import path;
        # in this tree the classes live in ``embedders/base/base_conditioner.py``
        # and ``embedders/base/base_conditioner_config.py``. Adjust accordingly.
        from cr.models.embedders import BaseConditioner, BaseConditionerConfig

        # Create the conditioner config
        config = BaseConditionerConfig(
            input_key="text",  # The key for the input
            unconditional_conditioning_rate=0.3,  # Drops the conditioning with 30% probability during training
        )

        # Create the conditioner
        conditioner = BaseConditioner(config)
    """

    def __init__(self, config: BaseConditionerConfig):
        BaseModel.__init__(self, config)
        # Keep a direct handle on the config for subclasses.
        self.config = config
        # Name of the entry to read from the input ``batch`` dict; nothing in
        # this base class reads it, presumably concrete ``forward``s do.
        self.input_key = config.input_key
        # Per-instance copy of the module-level table: assigning the shared
        # dict directly would let a mutation through one instance leak into
        # every other conditioner and into ``DIM2CONDITIONING`` itself.
        self.dim2outputkey = dict(DIM2CONDITIONING)
        # Probability of dropping the conditioning during training (see the
        # example above); not used in this base class, presumably applied by
        # subclasses.
        self.ucg_rate = config.unconditional_conditioning_rate

    def forward(
        self, batch: Dict[str, Any], force_zero_embedding: bool = False, *args, **kwargs
    ):
        """
        Forward pass of the embedder. Must be implemented by subclasses.

        Args:
            batch (Dict[str, Any]): A dictionary containing the input data.
            force_zero_embedding (bool): Whether to force zero embedding.
                This will return an embedding with all entries set to 0. Defaults to False.
            *args: Extra positional arguments accepted by subclass implementations.
            **kwargs: Extra keyword arguments accepted by subclass implementations.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Forward pass must be implemented in child class")