mirror of https://github.com/vladmandic/automatic
114 lines
3.8 KiB
Python
114 lines
3.8 KiB
Python
import torch
|
|
import transformers
|
|
import diffusers
|
|
from modules import shared, devices, sd_models, model_quant
|
|
|
|
|
|
class XOmniPipeline(diffusers.DiffusionPipeline):
    """Text-to-image pipeline wrapping the X-Omni causal language model.

    The pipeline holds two components: a tokenizer and the X-Omni model.
    Both may be passed to the constructor or populated later via `load`.
    """

    def __init__(
        self,
        tokenizer=None,
        model=None,
    ):
        """Create the pipeline and register its (possibly empty) components.

        Args:
            tokenizer: tokenizer instance, or None when `load` will be called later.
            model: X-Omni model instance, or None when `load` will be called later.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.register_modules(
            tokenizer=tokenizer,
            model=model,
        )

    def load(
        self,
        repo_id,
        load_config: "dict | None" = None,
    ):
        """Load the tokenizer, the language model and its vision decoder.

        Args:
            repo_id: repository id passed to `from_pretrained` for both the
                tokenizer and the model.
            load_config: base loader options forwarded to
                `model_quant.get_dit_args`; None is treated as an empty dict.
        """
        # imported lazily so the model code is only pulled in when this pipeline is used
        from pipelines.xomni import modeling_xomni

        # a fresh dict per call; a shared `{}` default could leak state between calls
        if load_config is None:
            load_config = {}
        load_args, quant_args = model_quant.get_dit_args(load_config, module='Model', device_map=True)

        shared.log.debug(f'Load model: cls=XOmniPipeline module=tokenizer repo_id="{repo_id}"')
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            repo_id,
            use_fast=True,
            cache_dir=shared.opts.hfcache_dir,  # same cache location as the model below
        )

        shared.log.debug(f'Load model: cls=XOmniPipeline module=transformer repo_id="{repo_id}" args={load_args}')
        # the bundled modeling_xomni implementation is used instead of
        # transformers.AutoModelForCausalLM with trust_remote_code=True
        self.model = modeling_xomni.XOmniForCausalLM.from_pretrained(
            repo_id,
            cache_dir=shared.opts.hfcache_dir,
            **load_args,
            **quant_args,
        )

        # the vision part is initialized from a fixed FLUX.1-dev repository
        flux_repo_id = "black-forest-labs/FLUX.1-dev"
        shared.log.debug(f'Load model: cls=XOmniPipeline module=vision repo_id="{flux_repo_id}"')
        self.model.init_vision(
            flux_repo_id,
            **quant_args,
        )
        self.model.set_generation_mode('image')

    def __call__(
        self,
        prompt: str = "",
        width: int = 1024,
        height: int = 1024,
        seed: int = -1,
        temperature: float = 1.0,
        downsample_size: int = 16,
        min_p: float = 0.03,
        top_p: float = 1.0,
        cfg_scale: float = 1.0,
    ):
        """Generate image(s) for a single text prompt.

        Args:
            prompt: text prompt; when a list is given only its first item is
                used (an empty list is treated as an empty prompt).
            width: requested image width in pixels.
            height: requested image height in pixels.
            seed: value passed to `torch.manual_seed` before generation and
                again before decoding.
            temperature: sampling temperature.
            downsample_size: divisor converting pixel size to the token grid
                size (`height // downsample_size`, `width // downsample_size`).
            min_p: min-p sampling threshold.
            top_p: top-p (nucleus) sampling threshold.
            cfg_scale: guidance scale used together with the negative prompt ids.

        Returns:
            The images produced by `self.model.mmdecode`
            (presumably a list of PIL images -- TODO confirm).
        """
        if isinstance(prompt, list):
            # only a single prompt is supported; guard against an empty list
            prompt = prompt[0] if prompt else ""

        token_h, token_w = height // downsample_size, width // downsample_size
        # prefix announcing the token grid of the image to be generated
        image_prefix = f'<SOM>{token_h} {token_w}<IMAGE>'
        generation_config = transformers.generation.GenerationConfig(
            max_new_tokens=token_h * token_w,  # one new token per grid cell
            do_sample=True,
            temperature=temperature,
            min_p=min_p,
            top_p=top_p,
            guidance_scale=cfg_scale,
            # the model's multimodal special tokens must not be sampled
            suppress_tokens=self.tokenizer.convert_tokens_to_ids(self.model.config.mm_special_tokens),
        )

        # Sample inputs:
        tokens = self.tokenizer(
            [prompt + image_prefix],
            return_tensors='pt',
            padding='longest',
            padding_side='left',
        )
        input_ids = tokens.input_ids.to(devices.device)
        attention_mask = tokens.attention_mask.to(devices.device)
        # negative (unconditional) branch: the image prefix without any prompt text
        negative_ids = self.tokenizer.encode(
            image_prefix,
            add_special_tokens=False,
            return_tensors='pt',
        ).to(devices.device).expand(1, -1)

        # NOTE(review): negative seeds are accepted by torch.manual_seed and remapped,
        # so the default seed=-1 is deterministic rather than random -- confirm intent
        torch.manual_seed(seed)
        tokens = self.model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            negative_prompt_ids=negative_ids,
        )

        # append one <EOM> token to the end of every generated sequence before decoding
        tokens = torch.nn.functional.pad(tokens, (0, 1), value=self.tokenizer.convert_tokens_to_ids('<EOM>'))
        # re-seed before decoding (presumably the decode step also draws random numbers)
        torch.manual_seed(seed)
        _, images = self.model.mmdecode(self.tokenizer, tokens[0], skip_special_tokens=False)
        return images
|
|
|
|
|
|
def load_xomni(checkpoint_info, diffusers_load_config=None):
    """Create an `XOmniPipeline` and load its components for a checkpoint.

    Args:
        checkpoint_info: checkpoint descriptor; resolved to a repository id
            via `sd_models.path_to_repo` and passed to `sd_models.hf_auth_check`.
        diffusers_load_config: loader options forwarded to
            `XOmniPipeline.load`; None is treated as an empty dict.

    Returns:
        The loaded `XOmniPipeline` instance.
    """
    # a fresh dict per call; a shared `{}` default could leak state between calls
    if diffusers_load_config is None:
        diffusers_load_config = {}
    repo_id = sd_models.path_to_repo(checkpoint_info)
    sd_models.hf_auth_check(checkpoint_info)

    pipe = XOmniPipeline()
    pipe.load(repo_id, load_config=diffusers_load_config)
    return pipe
|