diff --git a/modules/seedvr/src/common/decorators.py b/modules/seedvr/src/common/decorators.py
index cf504b3bf..52ab8ac03 100644
--- a/modules/seedvr/src/common/decorators.py
+++ b/modules/seedvr/src/common/decorators.py
@@ -118,9 +118,9 @@ def log_runtime(func: Callable) -> Callable:
 
     @functools.wraps(func)
     def wrapped(*args, **kwargs):
-        torch.distributed.barrier()
+        barrier_if_distributed()
         result = func(*args, **kwargs)
-        torch.distributed.barrier()
+        barrier_if_distributed()
         return result
 
     return wrapped
diff --git a/modules/seedvr/src/common/distributed/advanced.py b/modules/seedvr/src/common/distributed/advanced.py
index d2479bebf..5da21a8ad 100644
--- a/modules/seedvr/src/common/distributed/advanced.py
+++ b/modules/seedvr/src/common/distributed/advanced.py
@@ -16,13 +16,15 @@
 Advanced distributed functions for sequence parallel.
 """
 
+from __future__ import annotations
+
+import logging
 from typing import Optional, List
 
 import torch
-import torch.distributed as dist
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.fsdp import ShardingStrategy
-from .basic import get_global_rank, get_world_size
+from .basic import _is_dist, get_global_rank, get_world_size
+
+logger = logging.getLogger(__name__)
 
 
 _DATA_PARALLEL_GROUP = None
@@ -61,7 +63,10 @@ def get_data_parallel_rank() -> int:
     """
     Get data parallel rank.
     """
     group = get_data_parallel_group()
-    return dist.get_rank(group) if group else get_global_rank()
+    if group and _is_dist():
+        import torch.distributed as dist
+        return dist.get_rank(group)
+    return get_global_rank()
 
 
@@ -69,7 +74,10 @@ def get_data_parallel_world_size() -> int:
     """
     Get data parallel world size.
     """
     group = get_data_parallel_group()
-    return dist.get_world_size(group) if group else get_world_size()
+    if group and _is_dist():
+        import torch.distributed as dist
+        return dist.get_world_size(group)
+    return get_world_size()
 
 
@@ -77,7 +85,10 @@ def get_sequence_parallel_rank() -> int:
     """
     Get sequence parallel rank.
     """
     group = get_sequence_parallel_group()
-    return dist.get_rank(group) if group else 0
+    if group and _is_dist():
+        import torch.distributed as dist
+        return dist.get_rank(group)
+    return 0
 
 
@@ -85,7 +96,10 @@ def get_sequence_parallel_world_size() -> int:
     """
     Get sequence parallel world size.
     """
     group = get_sequence_parallel_group()
-    return dist.get_world_size(group) if group else 1
+    if group and _is_dist():
+        import torch.distributed as dist
+        return dist.get_world_size(group)
+    return 1
 
 
@@ -120,11 +134,14 @@ def get_model_shard_cpu_intra_group() -> Optional[dist.ProcessGroup]:
     """
     Initialize sequence parallel.
     """
+    if not _is_dist():
+        logger.debug("Skipping init_sequence_parallel: distributed not initialized")
+        return
+    import torch.distributed as dist
     global _DATA_PARALLEL_GROUP
     global _SEQUENCE_PARALLEL_GROUP
     global _SEQUENCE_PARALLEL_CPU_GROUP
     global _SEQUENCE_PARALLEL_GLOBAL_RANKS
-    assert dist.is_initialized()
     world_size = dist.get_world_size()
     rank = dist.get_rank()
     data_parallel_size = world_size // sequence_parallel_size
@@ -142,17 +159,22 @@
 
 def init_model_shard_group(
     *,
-    sharding_strategy: ShardingStrategy,
-    device_mesh: Optional[DeviceMesh] = None,
+    sharding_strategy=None,
+    device_mesh=None,
 ):
     """
     Initialize process group of model sharding.
     """
+    if not _is_dist():
+        logger.debug("Skipping init_model_shard_group: distributed not initialized")
+        return
+    import torch.distributed as dist
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.fsdp import ShardingStrategy
     global _MODEL_SHARD_INTER_GROUP
     global _MODEL_SHARD_INTRA_GROUP
     global _MODEL_SHARD_CPU_INTER_GROUP
     global _MODEL_SHARD_CPU_INTRA_GROUP
-    assert dist.is_initialized()
     world_size = dist.get_world_size()
     if device_mesh is not None:
         num_shards_per_group = device_mesh.shape[1]
@@ -182,7 +204,10 @@ def get_sequence_parallel_global_ranks() -> List[int]:
     that the caller rank belongs to.
     """
     if _SEQUENCE_PARALLEL_GLOBAL_RANKS is None:
-        return [dist.get_rank()]
+        if _is_dist():
+            import torch.distributed as dist
+            return [dist.get_rank()]
+        return [0]
     return _SEQUENCE_PARALLEL_GLOBAL_RANKS
 
 
diff --git a/modules/seedvr/src/common/distributed/basic.py b/modules/seedvr/src/common/distributed/basic.py
index d880009c9..2ff8d2524 100644
--- a/modules/seedvr/src/common/distributed/basic.py
+++ b/modules/seedvr/src/common/distributed/basic.py
@@ -20,6 +20,14 @@
 import os
 import torch
 
+def _is_dist() -> bool:
+    """Check if torch.distributed is available and initialized.
+    Follows the same guard convention as flash_attn (utils/generation.py).
+    """
+    import torch.distributed as dist
+    return dist.is_available() and dist.is_initialized()
+
+
 def get_global_rank() -> int:
     """
     Get the global rank, the global index of the GPU.
@@ -45,15 +53,17 @@ def get_device() -> torch.device:
     """
     Get current rank device.
     """
-    return torch.device("cuda", get_local_rank())
+    if torch.cuda.is_available():
+        return torch.device("cuda", get_local_rank())
+    return torch.device("cpu")
 
 
 def barrier_if_distributed(*args, **kwargs):
     """
     Synchronizes all processes if under distributed context.
     """
-    import torch.distributed as dist
-    if dist.is_initialized():
+    if _is_dist():
+        import torch.distributed as dist
         return dist.barrier(*args, **kwargs)
 
 
diff --git a/modules/seedvr/src/common/distributed/ops.py b/modules/seedvr/src/common/distributed/ops.py
index bba121b85..f3101b62c 100644
--- a/modules/seedvr/src/common/distributed/ops.py
+++ b/modules/seedvr/src/common/distributed/ops.py
@@ -144,8 +144,8 @@ class Slice(torch.autograd.Function):
         dim_size = list(grad_output.size())
         split_size = dim_size[0]
         dim_size[0] = dim_size[0] * ctx.seq_world_size
-        output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device())
-        dist._all_gather_base(output, grad_output, group=ctx.group)
+        output = torch.empty(dim_size, dtype=grad_output.dtype, device=grad_output.device)
+        dist.all_gather_into_tensor(output, grad_output, group=ctx.group)
         return (None, torch.cat(output.split(split_size), dim=ctx.dim), None)
 
 
@@ -168,8 +168,8 @@ class Gather(torch.autograd.Function):
         split_size = dim_size[0]
         ctx.part_size = dim_size[dim]
         dim_size[0] = dim_size[0] * seq_world_size
-        output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device())
-        dist._all_gather_base(output, local_input.contiguous(), group=ctx.group)
+        output = torch.empty(dim_size, dtype=local_input.dtype, device=local_input.device)
+        dist.all_gather_into_tensor(output, local_input.contiguous(), group=ctx.group)
         return torch.cat(output.split(split_size), dim=dim)
 
     @staticmethod
diff --git a/modules/seedvr/src/models/dit/attention.py b/modules/seedvr/src/models/dit/attention.py
index b0dda518a..3ee9ab565 100644
--- a/modules/seedvr/src/models/dit/attention.py
+++ b/modules/seedvr/src/models/dit/attention.py
@@ -87,8 +87,4 @@ class FlashAttentionVarlen(nn.Module):
 
     def forward(self, *args, **kwargs):
         kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled()
-        try:
-            from flash_attn import flash_attn_varlen_func
-            return flash_attn_varlen_func(*args, **kwargs)
-        except ImportError:
-            return pytorch_varlen_attention(*args, **kwargs)
+        return pytorch_varlen_attention(*args, **kwargs)
diff --git a/modules/seedvr/src/models/dit_v2/attention.py b/modules/seedvr/src/models/dit_v2/attention.py
index a88fd7d13..33d34428b 100644
--- a/modules/seedvr/src/models/dit_v2/attention.py
+++ b/modules/seedvr/src/models/dit_v2/attention.py
@@ -85,8 +85,4 @@ class FlashAttentionVarlen(nn.Module):
 
     def forward(self, *args, **kwargs):
         kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled()
-        try:
-            from flash_attn import flash_attn_varlen_func
-            return flash_attn_varlen_func(*args, **kwargs)
-        except ImportError:
-            return pytorch_varlen_attention(*args, **kwargs)
+        return pytorch_varlen_attention(*args, **kwargs)