mirror of https://github.com/vladmandic/automatic
Merge pull request #4668 from resonantsky/seedvr-fix-distributed
Fixes for torch distributed errors on SeedVR 3Bpull/4671/head
commit
045df139c0
|
|
@ -118,9 +118,9 @@ def log_runtime(func: Callable) -> Callable:
|
|||
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
torch.distributed.barrier()
|
||||
barrier_if_distributed()
|
||||
result = func(*args, **kwargs)
|
||||
torch.distributed.barrier()
|
||||
barrier_if_distributed()
|
||||
return result
|
||||
|
||||
return wrapped
|
||||
|
|
|
|||
|
|
@ -16,13 +16,15 @@
|
|||
Advanced distributed functions for sequence parallel.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
||||
from torch.distributed.fsdp import ShardingStrategy
|
||||
|
||||
from .basic import get_global_rank, get_world_size
|
||||
from .basic import _is_dist, get_global_rank, get_world_size
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_DATA_PARALLEL_GROUP = None
|
||||
|
|
@ -61,7 +63,10 @@ def get_data_parallel_rank() -> int:
|
|||
Get data parallel rank.
|
||||
"""
|
||||
group = get_data_parallel_group()
|
||||
return dist.get_rank(group) if group else get_global_rank()
|
||||
if group and _is_dist():
|
||||
import torch.distributed as dist
|
||||
return dist.get_rank(group)
|
||||
return get_global_rank()
|
||||
|
||||
|
||||
def get_data_parallel_world_size() -> int:
|
||||
|
|
@ -69,7 +74,10 @@ def get_data_parallel_world_size() -> int:
|
|||
Get data parallel world size.
|
||||
"""
|
||||
group = get_data_parallel_group()
|
||||
return dist.get_world_size(group) if group else get_world_size()
|
||||
if group and _is_dist():
|
||||
import torch.distributed as dist
|
||||
return dist.get_world_size(group)
|
||||
return get_world_size()
|
||||
|
||||
|
||||
def get_sequence_parallel_rank() -> int:
|
||||
|
|
@ -77,7 +85,10 @@ def get_sequence_parallel_rank() -> int:
|
|||
Get sequence parallel rank.
|
||||
"""
|
||||
group = get_sequence_parallel_group()
|
||||
return dist.get_rank(group) if group else 0
|
||||
if group and _is_dist():
|
||||
import torch.distributed as dist
|
||||
return dist.get_rank(group)
|
||||
return 0
|
||||
|
||||
|
||||
def get_sequence_parallel_world_size() -> int:
|
||||
|
|
@ -85,7 +96,10 @@ def get_sequence_parallel_world_size() -> int:
|
|||
Get sequence parallel world size.
|
||||
"""
|
||||
group = get_sequence_parallel_group()
|
||||
return dist.get_world_size(group) if group else 1
|
||||
if group and _is_dist():
|
||||
import torch.distributed as dist
|
||||
return dist.get_world_size(group)
|
||||
return 1
|
||||
|
||||
|
||||
def get_model_shard_cpu_intra_group() -> Optional[dist.ProcessGroup]:
|
||||
|
|
@ -120,11 +134,14 @@ def init_sequence_parallel(sequence_parallel_size: int):
|
|||
"""
|
||||
Initialize sequence parallel.
|
||||
"""
|
||||
if not _is_dist():
|
||||
logger.debug("Skipping init_sequence_parallel: distributed not initialized")
|
||||
return
|
||||
import torch.distributed as dist
|
||||
global _DATA_PARALLEL_GROUP
|
||||
global _SEQUENCE_PARALLEL_GROUP
|
||||
global _SEQUENCE_PARALLEL_CPU_GROUP
|
||||
global _SEQUENCE_PARALLEL_GLOBAL_RANKS
|
||||
assert dist.is_initialized()
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
data_parallel_size = world_size // sequence_parallel_size
|
||||
|
|
@ -142,17 +159,22 @@ def init_sequence_parallel(sequence_parallel_size: int):
|
|||
|
||||
def init_model_shard_group(
|
||||
*,
|
||||
sharding_strategy: ShardingStrategy,
|
||||
device_mesh: Optional[DeviceMesh] = None,
|
||||
sharding_strategy=None,
|
||||
device_mesh=None,
|
||||
):
|
||||
"""
|
||||
Initialize process group of model sharding.
|
||||
"""
|
||||
if not _is_dist():
|
||||
logger.debug("Skipping init_model_shard_group: distributed not initialized")
|
||||
return
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from torch.distributed.fsdp import ShardingStrategy
|
||||
global _MODEL_SHARD_INTER_GROUP
|
||||
global _MODEL_SHARD_INTRA_GROUP
|
||||
global _MODEL_SHARD_CPU_INTER_GROUP
|
||||
global _MODEL_SHARD_CPU_INTRA_GROUP
|
||||
assert dist.is_initialized()
|
||||
world_size = dist.get_world_size()
|
||||
if device_mesh is not None:
|
||||
num_shards_per_group = device_mesh.shape[1]
|
||||
|
|
@ -182,7 +204,10 @@ def get_sequence_parallel_global_ranks() -> List[int]:
|
|||
that the caller rank belongs to.
|
||||
"""
|
||||
if _SEQUENCE_PARALLEL_GLOBAL_RANKS is None:
|
||||
return [dist.get_rank()]
|
||||
if _is_dist():
|
||||
import torch.distributed as dist
|
||||
return [dist.get_rank()]
|
||||
return [0]
|
||||
return _SEQUENCE_PARALLEL_GLOBAL_RANKS
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,14 @@ import os
|
|||
import torch
|
||||
|
||||
|
||||
def _is_dist() -> bool:
|
||||
"""Check if torch.distributed is available and initialized.
|
||||
Follows the same guard convention as flash_attn (utils/generation.py).
|
||||
"""
|
||||
import torch.distributed as dist
|
||||
return dist.is_available() and dist.is_initialized()
|
||||
|
||||
|
||||
def get_global_rank() -> int:
|
||||
"""
|
||||
Get the global rank, the global index of the GPU.
|
||||
|
|
@ -45,15 +53,17 @@ def get_device() -> torch.device:
|
|||
"""
|
||||
Get current rank device.
|
||||
"""
|
||||
return torch.device("cuda", get_local_rank())
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda", get_local_rank())
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def barrier_if_distributed(*args, **kwargs):
|
||||
"""
|
||||
Synchronizes all processes if under distributed context.
|
||||
"""
|
||||
import torch.distributed as dist
|
||||
if dist.is_initialized():
|
||||
if _is_dist():
|
||||
import torch.distributed as dist
|
||||
return dist.barrier(*args, **kwargs)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -144,8 +144,8 @@ class Slice(torch.autograd.Function):
|
|||
dim_size = list(grad_output.size())
|
||||
split_size = dim_size[0]
|
||||
dim_size[0] = dim_size[0] * ctx.seq_world_size
|
||||
output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device())
|
||||
dist._all_gather_base(output, grad_output, group=ctx.group)
|
||||
output = torch.empty(dim_size, dtype=grad_output.dtype, device=grad_output.device)
|
||||
dist.all_gather_into_tensor(output, grad_output, group=ctx.group)
|
||||
return (None, torch.cat(output.split(split_size), dim=ctx.dim), None)
|
||||
|
||||
|
||||
|
|
@ -168,8 +168,8 @@ class Gather(torch.autograd.Function):
|
|||
split_size = dim_size[0]
|
||||
ctx.part_size = dim_size[dim]
|
||||
dim_size[0] = dim_size[0] * seq_world_size
|
||||
output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device())
|
||||
dist._all_gather_base(output, local_input.contiguous(), group=ctx.group)
|
||||
output = torch.empty(dim_size, dtype=local_input.dtype, device=local_input.device)
|
||||
dist.all_gather_into_tensor(output, local_input.contiguous(), group=ctx.group)
|
||||
return torch.cat(output.split(split_size), dim=dim)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -87,8 +87,4 @@ class FlashAttentionVarlen(nn.Module):
|
|||
|
||||
def forward(self, *args, **kwargs):
|
||||
kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled()
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
return flash_attn_varlen_func(*args, **kwargs)
|
||||
except ImportError:
|
||||
return pytorch_varlen_attention(*args, **kwargs)
|
||||
return pytorch_varlen_attention(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -85,8 +85,4 @@ class FlashAttentionVarlen(nn.Module):
|
|||
|
||||
def forward(self, *args, **kwargs):
|
||||
kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled()
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
return flash_attn_varlen_func(*args, **kwargs)
|
||||
except ImportError:
|
||||
return pytorch_varlen_attention(*args, **kwargs)
|
||||
return pytorch_varlen_attention(*args, **kwargs)
|
||||
|
|
|
|||
Loading…
Reference in New Issue