Merge pull request #4668 from resonantsky/seedvr-fix-distributed

Fixes for torch distributed errors on SeedVR 3B
Vladimir Mandic 2026-03-04 17:10:03 +01:00 committed by GitHub
commit 045df139c0
6 changed files with 59 additions and 32 deletions

View File

@ -118,9 +118,9 @@ def log_runtime(func: Callable) -> Callable:
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        torch.distributed.barrier()
        barrier_if_distributed()
        result = func(*args, **kwargs)
        torch.distributed.barrier()
        barrier_if_distributed()
        return result
    return wrapped
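For orientation, a minimal self-contained sketch of the pattern this hunk adopts: only call into torch.distributed when a process group is actually live, so single-process SeedVR runs no longer trip over uninitialized-process-group errors. The timing and logging body of log_runtime below is illustrative, not copied from the repo.

import functools
import logging
import time

logger = logging.getLogger(__name__)

def barrier_if_distributed(*args, **kwargs):
    # Synchronize all ranks, but silently no-op when torch.distributed is not initialized.
    import torch.distributed as dist
    if dist.is_available() and dist.is_initialized():
        return dist.barrier(*args, **kwargs)

def log_runtime(func):
    # Time a function, syncing ranks before and after so per-rank timings are comparable.
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        barrier_if_distributed()
        start = time.perf_counter()
        result = func(*args, **kwargs)
        barrier_if_distributed()
        logger.info("%s took %.3fs", func.__name__, time.perf_counter() - start)
        return result
    return wrapped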

View File

@ -16,13 +16,15 @@
Advanced distributed functions for sequence parallel.
"""
from __future__ import annotations
import logging
from typing import Optional, List
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import ShardingStrategy
from .basic import get_global_rank, get_world_size
from .basic import _is_dist, get_global_rank, get_world_size
logger = logging.getLogger(__name__)
_DATA_PARALLEL_GROUP = None
@ -61,7 +63,10 @@ def get_data_parallel_rank() -> int:
    Get data parallel rank.
    """
    group = get_data_parallel_group()
    return dist.get_rank(group) if group else get_global_rank()
    if group and _is_dist():
        import torch.distributed as dist
        return dist.get_rank(group)
    return get_global_rank()
def get_data_parallel_world_size() -> int:
@ -69,7 +74,10 @@ def get_data_parallel_world_size() -> int:
    Get data parallel world size.
    """
    group = get_data_parallel_group()
    return dist.get_world_size(group) if group else get_world_size()
    if group and _is_dist():
        import torch.distributed as dist
        return dist.get_world_size(group)
    return get_world_size()
def get_sequence_parallel_rank() -> int:
@ -77,7 +85,10 @@ def get_sequence_parallel_rank() -> int:
    Get sequence parallel rank.
    """
    group = get_sequence_parallel_group()
    return dist.get_rank(group) if group else 0
    if group and _is_dist():
        import torch.distributed as dist
        return dist.get_rank(group)
    return 0
def get_sequence_parallel_world_size() -> int:
@ -85,7 +96,10 @@ def get_sequence_parallel_world_size() -> int:
    Get sequence parallel world size.
    """
    group = get_sequence_parallel_group()
    return dist.get_world_size(group) if group else 1
    if group and _is_dist():
        import torch.distributed as dist
        return dist.get_world_size(group)
    return 1
def get_model_shard_cpu_intra_group() -> Optional[dist.ProcessGroup]:
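The four getters above share one guard: consult the process group only when torch.distributed is initialized, otherwise fall back to rank 0 / world size 1. A sketch of how downstream code can then stay oblivious to whether it runs under torchrun; split_for_sequence_parallel is a hypothetical helper, not part of the repo, and the top-level torch.distributed import is kept here for brevity where the repo defers it into the guarded branch.

import torch
import torch.distributed as dist

_SEQUENCE_PARALLEL_GROUP = None  # populated by init_sequence_parallel() when distributed

def _is_dist() -> bool:
    return dist.is_available() and dist.is_initialized()

def get_sequence_parallel_rank() -> int:
    group = _SEQUENCE_PARALLEL_GROUP
    return dist.get_rank(group) if group and _is_dist() else 0

def get_sequence_parallel_world_size() -> int:
    group = _SEQUENCE_PARALLEL_GROUP
    return dist.get_world_size(group) if group and _is_dist() else 1

def split_for_sequence_parallel(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Keep only this rank's slice along the sequence dimension; with world size 1
    # (single GPU or CPU run) the tensor passes through unchanged.
    world_size = get_sequence_parallel_world_size()
    if world_size == 1:
        return x
    return x.chunk(world_size, dim=dim)[get_sequence_parallel_rank()]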
@ -120,11 +134,14 @@ def init_sequence_parallel(sequence_parallel_size: int):
"""
Initialize sequence parallel.
"""
if not _is_dist():
logger.debug("Skipping init_sequence_parallel: distributed not initialized")
return
import torch.distributed as dist
global _DATA_PARALLEL_GROUP
global _SEQUENCE_PARALLEL_GROUP
global _SEQUENCE_PARALLEL_CPU_GROUP
global _SEQUENCE_PARALLEL_GLOBAL_RANKS
assert dist.is_initialized()
world_size = dist.get_world_size()
rank = dist.get_rank()
data_parallel_size = world_size // sequence_parallel_size
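For context, a rough sketch of what initializing sequence-parallel groups typically involves once the guard has passed; the repo's version also creates the matching data-parallel and CPU/gloo groups, and its exact rank partitioning may differ.

import torch.distributed as dist

def init_sequence_parallel_sketch(sequence_parallel_size: int):
    # Carve contiguous rank blocks into sequence-parallel groups, e.g. with
    # world_size=8 and sequence_parallel_size=4: [0,1,2,3] and [4,5,6,7].
    if not (dist.is_available() and dist.is_initialized()):
        return None  # plain single-process run: nothing to set up
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % sequence_parallel_size == 0
    my_group = None
    for start in range(0, world_size, sequence_parallel_size):
        ranks = list(range(start, start + sequence_parallel_size))
        group = dist.new_group(ranks)  # new_group() must be called on every rank
        if rank in ranks:
            my_group = group
    return my_group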
@ -142,17 +159,22 @@ def init_sequence_parallel(sequence_parallel_size: int):
def init_model_shard_group(
    *,
    sharding_strategy: ShardingStrategy,
    device_mesh: Optional[DeviceMesh] = None,
    sharding_strategy=None,
    device_mesh=None,
):
    """
    Initialize process group of model sharding.
    """
    if not _is_dist():
        logger.debug("Skipping init_model_shard_group: distributed not initialized")
        return
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import ShardingStrategy
    global _MODEL_SHARD_INTER_GROUP
    global _MODEL_SHARD_INTRA_GROUP
    global _MODEL_SHARD_CPU_INTER_GROUP
    global _MODEL_SHARD_CPU_INTRA_GROUP
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    if device_mesh is not None:
        num_shards_per_group = device_mesh.shape[1]
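device_mesh.shape[1] above reads the number of shards per replica group off a 2-D mesh. A hedged sketch of how such a mesh can be built with init_device_mesh; the actual mesh construction and the CPU/gloo groups the repo also sets up are outside this diff.

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

def build_shard_mesh(num_shards_per_group: int):
    # 2-D (replicate, shard) mesh for HYBRID_SHARD-style FSDP: with 8 ranks and
    # num_shards_per_group=4 the mesh has shape (2, 4), so mesh.shape[1] == 4.
    if not (dist.is_available() and dist.is_initialized()):
        return None
    world_size = dist.get_world_size()
    assert world_size % num_shards_per_group == 0
    return init_device_mesh(
        "cuda",
        (world_size // num_shards_per_group, num_shards_per_group),
        mesh_dim_names=("replicate", "shard"),
    )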
@ -182,7 +204,10 @@ def get_sequence_parallel_global_ranks() -> List[int]:
    that the caller rank belongs to.
    """
    if _SEQUENCE_PARALLEL_GLOBAL_RANKS is None:
        return [dist.get_rank()]
        if _is_dist():
            import torch.distributed as dist
            return [dist.get_rank()]
        return [0]
    return _SEQUENCE_PARALLEL_GLOBAL_RANKS

View File

@ -20,6 +20,14 @@ import os
import torch
def _is_dist() -> bool:
"""Check if torch.distributed is available and initialized.
Follows the same guard convention as flash_attn (utils/generation.py).
"""
import torch.distributed as dist
return dist.is_available() and dist.is_initialized()
def get_global_rank() -> int:
"""
Get the global rank, the global index of the GPU.
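The _is_dist guard added above makes any collective optional. An illustrative example, not from the repo: average a metric across ranks under torchrun, or leave it untouched in a single-process run.

import torch
import torch.distributed as dist

def _is_dist() -> bool:
    # Same guard as above: the backend must be built into torch AND initialized.
    return dist.is_available() and dist.is_initialized()

def all_reduce_mean(x: torch.Tensor) -> torch.Tensor:
    # x is expected to be a floating-point tensor.
    if not _is_dist():
        return x
    dist.all_reduce(x, op=dist.ReduceOp.SUM)
    x /= dist.get_world_size()
    return x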
@ -45,15 +53,17 @@ def get_device() -> torch.device:
"""
Get current rank device.
"""
return torch.device("cuda", get_local_rank())
if torch.cuda.is_available():
return torch.device("cuda", get_local_rank())
return torch.device("cpu")
def barrier_if_distributed(*args, **kwargs):
"""
Synchronizes all processes if under distributed context.
"""
import torch.distributed as dist
if dist.is_initialized():
if _is_dist():
import torch.distributed as dist
return dist.barrier(*args, **kwargs)
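With these two guards, the same entry point works under torchrun and under plain python on a CPU-only machine. The LOCAL_RANK lookup below is the usual torchrun convention and is an assumption here; the repo's get_local_rank lives elsewhere in this file.

import os
import torch

def get_local_rank() -> int:
    # torchrun exports LOCAL_RANK per process; default to 0 for plain runs.
    return int(os.environ.get("LOCAL_RANK", 0))

def get_device() -> torch.device:
    # Prefer this rank's GPU, fall back to CPU so CPU-only machines still work.
    if torch.cuda.is_available():
        return torch.device("cuda", get_local_rank())
    return torch.device("cpu")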

View File

@ -144,8 +144,8 @@ class Slice(torch.autograd.Function):
        dim_size = list(grad_output.size())
        split_size = dim_size[0]
        dim_size[0] = dim_size[0] * ctx.seq_world_size
        output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device())
        dist._all_gather_base(output, grad_output, group=ctx.group)
        output = torch.empty(dim_size, dtype=grad_output.dtype, device=grad_output.device)
        dist.all_gather_into_tensor(output, grad_output, group=ctx.group)
        return (None, torch.cat(output.split(split_size), dim=ctx.dim), None)
@ -168,8 +168,8 @@ class Gather(torch.autograd.Function):
        split_size = dim_size[0]
        ctx.part_size = dim_size[dim]
        dim_size[0] = dim_size[0] * seq_world_size
        output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device())
        dist._all_gather_base(output, local_input.contiguous(), group=ctx.group)
        output = torch.empty(dim_size, dtype=local_input.dtype, device=local_input.device)
        dist.all_gather_into_tensor(output, local_input.contiguous(), group=ctx.group)
        return torch.cat(output.split(split_size), dim=dim)
    @staticmethod
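The replacement collective, dist.all_gather_into_tensor, is the public successor of the private dist._all_gather_base. It always concatenates each rank's contribution along dim 0 of a pre-allocated output, which is why both hunks then split and re-cat to move the gathered chunks onto the target dimension, and why the buffer is now allocated on the input's own device rather than torch.cuda.current_device(). A standalone sketch of that pattern (shapes are illustrative; run under torchrun):

import torch
import torch.distributed as dist

def gather_along_dim(local_input: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    # All-gather local_input from every rank and concatenate along `dim`.
    world_size = dist.get_world_size(group)
    split_size = local_input.size(0)

    # all_gather_into_tensor stacks contributions along dim 0, so the output
    # buffer is world_size times larger in that dimension.
    out_shape = list(local_input.size())
    out_shape[0] *= world_size
    output = torch.empty(out_shape, dtype=local_input.dtype, device=local_input.device)

    dist.all_gather_into_tensor(output, local_input.contiguous(), group=group)

    # Re-split the dim-0 concatenation into per-rank chunks and join them on `dim`.
    return torch.cat(output.split(split_size, dim=0), dim=dim)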

View File

@ -87,8 +87,4 @@ class FlashAttentionVarlen(nn.Module):
    def forward(self, *args, **kwargs):
        kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled()
        try:
            from flash_attn import flash_attn_varlen_func
            return flash_attn_varlen_func(*args, **kwargs)
        except ImportError:
            return pytorch_varlen_attention(*args, **kwargs)
        return pytorch_varlen_attention(*args, **kwargs)
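This hunk drops the flash_attn fast path entirely and always routes through the pure-PyTorch fallback. pytorch_varlen_attention itself is outside this diff; for orientation, a minimal sketch of what such a fallback generally looks like, using the flash_attn-style packed layout (total_tokens, num_heads, head_dim) plus cu_seqlens offsets and scaled_dot_product_attention per sequence. Illustrative only, not the repo's implementation.

import torch
import torch.nn.functional as F

def pytorch_varlen_attention_sketch(
    q: torch.Tensor,             # (total_q, num_heads, head_dim), packed across sequences
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,  # (batch + 1,) prefix sums of per-sequence lengths
    cu_seqlens_k: torch.Tensor,
    causal: bool = False,
    **_unused,                   # extra flash_attn kwargs (deterministic, max_seqlen, ...) ignored here
) -> torch.Tensor:
    out = torch.empty_like(q)
    for i in range(cu_seqlens_q.numel() - 1):
        qs, qe = int(cu_seqlens_q[i]), int(cu_seqlens_q[i + 1])
        ks, ke = int(cu_seqlens_k[i]), int(cu_seqlens_k[i + 1])
        # SDPA expects (batch, heads, seq, head_dim); slice one sequence and transpose.
        qi = q[qs:qe].transpose(0, 1).unsqueeze(0)
        ki = k[ks:ke].transpose(0, 1).unsqueeze(0)
        vi = v[ks:ke].transpose(0, 1).unsqueeze(0)
        oi = F.scaled_dot_product_attention(qi, ki, vi, is_causal=causal)
        out[qs:qe] = oi.squeeze(0).transpose(0, 1)
    return out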

View File

@ -85,8 +85,4 @@ class FlashAttentionVarlen(nn.Module):
    def forward(self, *args, **kwargs):
        kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled()
        try:
            from flash_attn import flash_attn_varlen_func
            return flash_attn_varlen_func(*args, **kwargs)
        except ImportError:
            return pytorch_varlen_attention(*args, **kwargs)
        return pytorch_varlen_attention(*args, **kwargs)