"""
|
|
This file includes public APIs for FSDP such as the classes used for the
|
|
constructor arguments.
|
|
"""
|
|
|
|
from dataclasses import dataclass
|
|
from enum import auto, Enum
|
|
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
__all__ = [
    "ShardingStrategy",
    "BackwardPrefetch",
    "MixedPrecision",
    "CPUOffload",
    "StateDictType",
    "StateDictConfig",
    "FullStateDictConfig",
    "LocalStateDictConfig",
    "ShardedStateDictConfig",
    "OptimStateDictConfig",
    "FullOptimStateDictConfig",
    "LocalOptimStateDictConfig",
    "ShardedOptimStateDictConfig",
    "StateDictSettings",
]


class ShardingStrategy(Enum):
    """
    This specifies the sharding strategy to be used for distributed training by
    :class:`FullyShardedDataParallel`.

    - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded.
      For the parameters, this strategy unshards (via all-gather) before the
      forward, reshards after the forward, unshards before the backward
      computation, and reshards after the backward computation. For gradients,
      it synchronizes and shards them (via reduce-scatter) after the backward
      computation. The sharded optimizer states are updated locally per rank.
    - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during
      computation, and additionally, parameters are sharded outside
      computation. For the parameters, this strategy unshards before the
      forward, does not reshard them after the forward, and only reshards them
      after the backward computation. The sharded optimizer states are updated
      locally per rank. Inside ``no_sync()``, the parameters are not resharded
      after the backward computation.
    - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded
      but instead replicated across ranks, similar to PyTorch's
      :class:`DistributedDataParallel` API. For gradients, this strategy
      synchronizes them (via all-reduce) after the backward computation. The
      unsharded optimizer states are updated locally per rank.
    - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate
      parameters across nodes. This results in reduced communication volume as
      the expensive all-gathers and reduce-scatters are only done within a
      node, which can be more performant for medium-sized models.
    - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and
      replicate parameters across nodes. This is like ``HYBRID_SHARD``, except
      this may provide even higher throughput since the unsharded parameters
      are not freed after the forward pass, saving the all-gathers in the
      pre-backward.
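
    Example (a minimal, illustrative sketch; ``model`` and the process group
    initialization are assumed to exist already)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.distributed.fsdp import ShardingStrategy
        >>> # Fully shard parameters, gradients, and optimizer states (ZeRO-3-like).
        >>> fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)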
    """

    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    NO_SHARD = auto()
    HYBRID_SHARD = auto()
    _HYBRID_SHARD_ZERO2 = auto()


class BackwardPrefetch(Enum):
    """
    This configures explicit backward prefetching, which can improve throughput
    but may slightly increase peak memory usage.

    For the NCCL backend, any collectives, even if issued in different streams,
    contend for the same per-device NCCL stream, which is why the relative
    order in which the collectives are issued matters for overlapping. The
    different backward prefetching settings correspond to different orderings.

    - ``BACKWARD_PRE``: This prefetches the next set of parameters before the
      current set of parameters' gradient computation. This improves backward
      pass throughput by overlapping communication (next all-gather) and
      computation (current gradient computation).
    - ``BACKWARD_POST``: This prefetches the next set of parameters after the
      current set of parameters' gradient computation. This may improve
      backward pass throughput by overlapping communication (current
      reduce-scatter) and computation (next gradient computation).
      Specifically, the next all-gather is reordered to be before the current
      reduce-scatter.

    .. note:: If the increase in peak memory usage from prefetching is an
        issue, you may consider passing ``limit_all_gathers=True`` to the FSDP
        constructor, which may help reduce peak memory usage in some cases.
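
    Example (illustrative sketch; ``model`` is assumed to be defined and the
    process group initialized)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import BackwardPrefetch
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> # Prefetch the next all-gather before the current gradient computation.
        >>> fsdp_model = FSDP(model, backward_prefetch=BackwardPrefetch.BACKWARD_PRE)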
    """

    # NOTE: For both modes, the ordering that defines "current" and "next" is
    # not always correct in the current implementation, so this may cause some
    # performance regression for some models.
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()


@dataclass
class MixedPrecision:
    """
    This configures FSDP-native mixed precision training.

    Attributes:
        param_dtype (torch.dtype): This specifies the dtype for model
            parameters, inputs (when ``cast_forward_inputs`` or
            ``cast_root_forward_inputs`` is set to ``True``), and therefore the
            dtype for computation. However, outside the forward and backward
            passes, parameters are in full precision. Model checkpointing
            always happens in full precision.
        reduce_dtype (torch.dtype): This specifies the dtype for gradient
            reduction, which is permitted to differ from ``param_dtype``.
        buffer_dtype (torch.dtype): This specifies the dtype for buffers. FSDP
            does not shard buffers, casts them to ``buffer_dtype`` in the first
            forward pass, and keeps them in that dtype thereafter. Model
            checkpointing always happens in full precision.
        keep_low_precision_grads (bool): If ``False``, FSDP upcasts gradients
            back to the full parameter precision after the backward pass in
            preparation for the optimizer step. If ``True``, FSDP keeps the
            gradients in the dtype used for gradient reduction, which can save
            memory when using a custom optimizer that can perform the step in
            ``reduce_dtype``. (Default: ``False``)
        cast_forward_inputs (bool): Cast floating point tensors in the forward
            arguments and keyword arguments to ``param_dtype``.
            (Default: ``False``)
        cast_root_forward_inputs (bool): Cast floating point tensors in the
            forward arguments and keyword arguments to ``param_dtype`` for the
            root FSDP instance. It takes precedence over
            ``cast_forward_inputs`` for the root FSDP instance.
            (Default: ``True``)

    .. note:: This API is experimental and subject to change.

    .. note:: Only floating point tensors are cast to their specified dtypes.

    .. note:: In ``summon_full_params``, parameters are forced to full
        precision, but buffers are not.

    .. note:: ``state_dict`` checkpoints parameters and buffers in full
        precision. For buffers, this is only supported for
        ``StateDictType.FULL_STATE_DICT``.

    .. note:: Each low precision dtype must be specified explicitly. For
        example, ``MixedPrecision(reduce_dtype=torch.float16)`` only specifies
        the reduction dtype to be low precision, and FSDP will not cast
        parameters or buffers.

    .. note:: If a ``reduce_dtype`` is not specified, then gradient reduction
        happens in ``param_dtype`` if specified or the original parameter
        dtype otherwise.

    .. note:: If the user passes a model with ``BatchNorm`` modules and an
        ``auto_wrap_policy`` to the FSDP constructor, then FSDP will disable
        mixed precision for ``BatchNorm`` modules by wrapping them separately
        in their own FSDP instance with mixed precision disabled. This is due
        to some missing low precision ``BatchNorm`` kernels. If the user does
        not use an ``auto_wrap_policy``, then the user must take care to not
        use mixed precision for FSDP instances containing ``BatchNorm``
        modules.

    .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
        ``cast_forward_inputs=False`` by default. For the root FSDP instance,
        its ``cast_root_forward_inputs`` takes precedence over its
        ``cast_forward_inputs``. For non-root FSDP instances, their
        ``cast_root_forward_inputs`` values are ignored. The default setting
        is sufficient for the typical case where each FSDP instance has the
        same ``MixedPrecision`` configuration and only needs to cast inputs to
        the ``param_dtype`` at the beginning of the model's forward pass.

    .. note:: For nested FSDP instances with different ``MixedPrecision``
        configurations, we recommend setting individual ``cast_forward_inputs``
        values to configure casting inputs or not before each instance's
        forward. In such a case, since the casts happen before each FSDP
        instance's forward, a parent FSDP instance should have its non-FSDP
        submodules run before its FSDP submodules to avoid the activation
        dtype being changed due to a different ``MixedPrecision``
        configuration.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
        >>> model[1] = FSDP(
        >>>     model[1],
        >>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
        >>> )
        >>> model = FSDP(
        >>>     model,
        >>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
        >>> )

    The above shows a working example. On the other hand, if ``model[1]``
    were replaced with ``model[0]``, meaning that the submodule using a
    different ``MixedPrecision`` ran its forward first, then ``model[1]``
    would incorrectly see ``float16`` activations instead of ``bfloat16``
    ones.
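
    A simpler, single-configuration sketch (the dtype choices here are
    illustrative; ``model`` is assumed to be defined)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> bf16_policy = MixedPrecision(
        >>>     param_dtype=torch.bfloat16,
        >>>     reduce_dtype=torch.float32,
        >>>     buffer_dtype=torch.bfloat16,
        >>> )
        >>> fsdp_model = FSDP(model, mixed_precision=bf16_policy)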
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    buffer_dtype: Optional[torch.dtype] = None
    keep_low_precision_grads: bool = False
    cast_forward_inputs: bool = False
    cast_root_forward_inputs: bool = True


@dataclass
class CPUOffload:
    """
    This configures CPU offloading.

    Attributes:
        offload_params (bool): This specifies whether to offload parameters to
            CPU when not involved in computation. If enabled, this implicitly
            offloads gradients to CPU as well. This is to support the
            optimizer step, which requires parameters and gradients to be on
            the same device.
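
    Example (illustrative sketch; ``model`` is assumed to be defined and the
    process group initialized)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import CPUOffload
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> # Keep sharded parameters (and hence gradients) on CPU between uses.
        >>> fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))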
    """

    offload_params: bool = False


class StateDictType(Enum):
    """
    This enum indicates which type of ``state_dict`` the FSDP module is
    currently processing (returning or loading).
    The default value is ``FULL_STATE_DICT`` to comply with the PyTorch
    convention.

    .. note::
        FSDP currently supports three types of ``state_dict``:

        1. ``state_dict``/``load_state_dict``: this pair of APIs returns and
           loads the non-sharded, unflattened parameters. The semantics are
           the same as using DDP.
        2. ``_local_state_dict``/``_load_local_state_dict``: this pair of APIs
           returns and loads the local sharded, flattened parameters. The
           values returned by ``_local_state_dict`` can be directly used by
           FSDP and are only meaningful to FSDP (because the parameters are
           flattened). Note that these APIs are meant for use via the
           :func:`state_dict_type` context manager as follows:

           >>> # xdoctest: +SKIP("undefined variables")
           >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT):
           ...     state = fsdp.state_dict()  # gets the local state dict
        3. ``_sharded_state_dict``/``_load_sharded_state_dict``: this pair of
           APIs returns and loads sharded, unflattened parameters. The
           ``state_dict`` returned by ``_sharded_state_dict`` can be used by
           all other parallel schemes (resharding may be required).
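
    Example (illustrative sketch; ``fsdp_model`` is assumed to be an
    FSDP-wrapped module)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.distributed.fsdp import StateDictType
        >>> with FSDP.state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT):
        >>>     sharded_state = fsdp_model.state_dict()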
    """

    FULL_STATE_DICT = auto()
    LOCAL_STATE_DICT = auto()
    SHARDED_STATE_DICT = auto()


@dataclass
class StateDictConfig:
    """
    ``StateDictConfig`` is the base class for all ``state_dict`` configuration
    classes. Users should instantiate a child class (e.g.
    ``FullStateDictConfig``) in order to configure settings for the particular
    type of ``state_dict`` implementation FSDP will use.
    """

    offload_to_cpu: bool = False


@dataclass
class FullStateDictConfig(StateDictConfig):
    """
    ``FullStateDictConfig`` is a config class meant to be used with
    ``StateDictType.FULL_STATE_DICT``. Currently, it accepts two parameters,
    ``offload_to_cpu`` and ``rank0_only``, which can be configured to offload
    the full ``state_dict`` to CPU and to materialize the ``state_dict`` on
    rank 0 only. When used, it is recommended to enable both of these flags
    together to optimize memory savings when taking checkpoints. Note that
    this config class is meant to be used via the :func:`state_dict_type`
    context manager as follows:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> fsdp = FSDP(model, auto_wrap_policy=...)
        >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        >>> with FullyShardedDataParallel.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
        >>>     state = fsdp.state_dict()
        >>>     # state will be empty on non rank 0 and contain CPU tensors on rank 0.
        >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
        >>> model = model_fn()  # Initialize model on CPU in preparation for wrapping with FSDP
        >>> if dist.get_rank() == 0:
        >>>     # Load checkpoint only on rank 0 to avoid memory redundancy
        >>>     state_dict = torch.load("my_checkpoint.pt")
        >>>     model.load_state_dict(state_dict)
        >>> # All ranks initialize FSDP module as usual. The ``sync_module_states`` argument
        >>> # communicates loaded checkpoint states from rank 0 to the rest of the world.
        >>> fsdp = FSDP(
        >>>     model,
        >>>     device_id=torch.cuda.current_device(),
        >>>     auto_wrap_policy=...,
        >>>     sync_module_states=True,
        >>> )
        >>> # After this point, all ranks have the FSDP model with the loaded checkpoint.
    """

    rank0_only: bool = False


@dataclass
class LocalStateDictConfig(StateDictConfig):
    pass


@dataclass
class ShardedStateDictConfig(StateDictConfig):
    pass


@dataclass
class OptimStateDictConfig:
    """
    ``OptimStateDictConfig`` is the base class for all optimizer
    ``state_dict`` configuration classes. Users should instantiate a child
    class (e.g. ``FullOptimStateDictConfig``) in order to configure settings
    for the particular type of ``optim_state_dict`` implementation FSDP will
    use.
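
    Example (illustrative sketch; ``model`` is an FSDP-wrapped module and
    ``optim`` its optimizer, both assumed to be defined)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.distributed.fsdp import (
        >>>     FullOptimStateDictConfig,
        >>>     FullStateDictConfig,
        >>>     StateDictType,
        >>> )
        >>> FSDP.set_state_dict_type(
        >>>     model,
        >>>     StateDictType.FULL_STATE_DICT,
        >>>     FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        >>>     FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
        >>> )
        >>> optim_state = FSDP.optim_state_dict(model, optim)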
    """

    # TODO: actually use this flag in _optim_utils.py
    offload_to_cpu: bool = True


@dataclass
class FullOptimStateDictConfig(OptimStateDictConfig):
    rank0_only: bool = False


@dataclass
class LocalOptimStateDictConfig(OptimStateDictConfig):
    offload_to_cpu: bool = False


@dataclass
class ShardedOptimStateDictConfig(OptimStateDictConfig):
    pass


@dataclass
class StateDictSettings:
    state_dict_type: StateDictType
    state_dict_config: StateDictConfig
    optim_state_dict_config: OptimStateDictConfig