"""Text-embedding and LR-scheduling helpers for a diffusers-based pipeline.

NOTE(review): the original file was mangled — collapsed onto a few physical
lines, with every span between a literal '<' and the next '>' deleted (as if
stripped as HTML tags).  The code below restores the readable parts verbatim
and reconstructs the destroyed spans from the standard diffusers/transformers
idioms they were clearly adapted from; every reconstructed span is flagged
with a NOTE(review) comment and must be confirmed against project history.
"""
import math
import os
from typing import List, Optional, Union

import numpy as np
import torch
import torch.distributed as dist
from diffusers.utils import logging
# NOTE(review): get_1d_rotary_pos_embed is used below but its import was lost
# in a corrupted span; in current diffusers it lives in models.embeddings.
from diffusers.models.embeddings import get_1d_rotary_pos_embed
from transformers import (
    AutoTokenizer,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    T5EncoderModel,
    T5TokenizerFast,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def get_text(caption):
    """Return the unique double-quoted substrings of ``caption``, in order of
    first appearance.

    A single symmetric pair of outer quotes (double, then single) wrapped
    around the whole caption is stripped before scanning.
    """
    # Strip one symmetric pair of outer quotes, if present.
    # Bug fix: the original sliced ``caption[1:-2]``, which removed the last
    # character of the caption *in addition to* the closing quote; [1:-1]
    # removes exactly the two quote characters.  Length guards avoid an
    # IndexError on empty or single-character input.
    if len(caption) >= 2 and caption[0] == '"' and caption[-1] == '"':
        caption = caption[1:-1]
    if len(caption) >= 2 and caption[0] == "'" and caption[-1] == "'":
        caption = caption[1:-1]

    seen = set()      # texts already emitted (deduplication)
    text_list = []    # unique quoted texts, in order of first appearance
    current = []      # characters of the quoted run being collected
    inside = False    # currently between an opening and a closing '"'?
    for ch in caption:
        if ch == '"':
            if inside:
                text = "".join(current)
                if text not in seen:
                    seen.add(text)
                    text_list.append(text)
                current = []
            inside = not inside
        elif inside:
            current.append(ch)
    return text_list


def get_by_t5_prompt_embeds(
    tokenizer: AutoTokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
):
    """Encode every double-quoted substring of ``prompt`` with the T5 encoder
    and concatenate the embeddings along the sequence axis.

    Returns a ``(max_sequence_length, dim)`` tensor, or ``None`` when the
    prompt contains no quoted text.
    """
    device = device or text_encoder.device

    # Only a single prompt is supported (a one-element list is unwrapped).
    if isinstance(prompt, list):
        assert len(prompt) == 1
        prompt = prompt[0]
    assert isinstance(prompt, str)

    captions_list = get_text(prompt)

    embeddings_list = []
    for inner_prompt in captions_list:
        text_inputs = tokenizer(
            [inner_prompt],
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_embeds = text_encoder(text_input_ids.to(device))[0]
        embeddings_list.append(prompt_embeds[0])  # drop the batch axis

    # No quoted text found.
    if len(embeddings_list) == 0:
        return None

    prompt_embeds = torch.cat(embeddings_list, dim=0)

    # Concat zeros up to max_sequence_length.
    # NOTE(review): this branch was destroyed in the corrupted source
    # ("if seq_len<..."); reconstructed as zero-padding / truncation to a
    # fixed (max_sequence_length, dim) shape — confirm against the original.
    seq_len, dim = prompt_embeds.shape
    if seq_len < max_sequence_length:
        pad = torch.zeros(
            max_sequence_length - seq_len,
            dim,
            dtype=prompt_embeds.dtype,
            device=prompt_embeds.device,
        )
        prompt_embeds = torch.cat([prompt_embeds, pad], dim=0)
    else:
        prompt_embeds = prompt_embeds[:max_sequence_length]
    return prompt_embeds


def get_t5_prompt_embeds(
    tokenizer: AutoTokenizer,
    text_encoder: T5EncoderModel,
    prompt: Union[str, List[str]],
    max_sequence_length: int = 128,
    device: Optional[torch.device] = None,
):
    """Encode ``prompt`` with the T5 encoder, warning when input is truncated.

    NOTE(review): the name, signature, and first half of this function were
    lost in the corrupted source; only the truncation-warning and the trailing
    zero-padding logic survived.  Reconstructed from the standard diffusers
    ``_get_t5_prompt_embeds`` idiom (which the surviving code matches
    verbatim) — confirm the original name and defaults.
    """
    device = device or text_encoder.device
    prompt = [prompt] if isinstance(prompt, str) else prompt

    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

    # Warn about the part of the prompt that was cut off by truncation.
    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
        text_input_ids, untruncated_ids
    ):
        removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = text_encoder(text_input_ids.to(device))[0]

    # Concat zeros to max_sequence (kept from the original; a no-op when the
    # tokenizer already padded to max_length).
    b, seq_len, dim = prompt_embeds.shape
    if seq_len < max_sequence_length:
        pad = torch.zeros(
            b,
            max_sequence_length - seq_len,
            dim,
            dtype=prompt_embeds.dtype,
            device=prompt_embeds.device,
        )
        prompt_embeds = torch.cat([prompt_embeds, pad], dim=1)
    return prompt_embeds


class EmbedND(torch.nn.Module):
    """Multi-axis rotary position embedding.

    NOTE(review): the class header and ``__init__`` were destroyed in the
    corrupted source; only the ``forward`` body survived.  Reconstructed to
    match diffusers' ``FluxPosEmbed`` (whose ``forward`` is identical to the
    surviving code) — confirm the original class and attribute names.
    """

    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta        # rotary base frequency
        self.axes_dim = axes_dim  # embedding dim allotted to each id axis

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        # float64 frequencies are more accurate, but MPS lacks float64.
        is_mps = ids.device.type == "mps"
        freqs_dtype = torch.float32 if is_mps else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin


from diffusers.optimization import get_scheduler
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


# Not really cosine but with decay
def get_cosine_schedule_with_warmup_and_decay(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
    constant_steps=-1,
    eps=1e-5,
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values
    of the cosine function between the initial lr set in the optimizer to 0,
    after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_periods (`float`, *optional*, defaults to 0.5):
            The number of periods of the cosine function in a schedule (the
            default is to just decrease from the max value to 0 following a
            half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
        constant_steps (`int`):
            The total number of constant lr steps following a warmup

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # A non-positive constant_steps means "constant until the end of training".
    if constant_steps <= 0:
        constant_steps = num_training_steps - num_warmup_steps

    def lr_lambda(current_step):
        # Accelerate sends current_step*num_processes
        if current_step < num_warmup_steps:
            # Linear warmup from 0 to the optimizer's base lr.
            return float(current_step) / float(max(1, num_warmup_steps))
        elif current_step < num_warmup_steps + constant_steps:
            # NOTE(review): the source file is truncated right after this
            # `elif`; the plateau-then-decay shape below is reconstructed from
            # the docstring and the `constant_steps`/`eps` parameters —
            # confirm against the original.
            return 1.0
        # Linear decay to `eps` over the remaining steps.
        progress = float(current_step - num_warmup_steps - constant_steps) / float(
            max(1, num_training_steps - num_warmup_steps - constant_steps)
        )
        return max(eps, 1.0 - progress)

    return LambdaLR(optimizer, lr_lambda, last_epoch)