# Copyright (c) Meta Platforms, Inc. and affiliates

import copy
import dataclasses
from typing import Dict, List, Optional, Sequence, Tuple, Union, cast

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed._shard.api import _shard_tensor
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import (
    ChunkShardingSpec,
)
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadPlan
from torch.distributed.checkpoint.planner_helpers import (
    _create_read_items,
    _create_sharded_read_items,
)
from torch.distributed.checkpoint.utils import (
    _element_wise_add,
    _element_wise_sub,
)
from torch.distributed.remote_device import _remote_device

STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]
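# STATE_DICT_2D_LAYOUT maps each parameter FQN to (shard offset, shard size) for this
# rank's slice in the 2D (FSDP + TP) layout; the offset is None for values that are
# not nested/TP-sharded tensors.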


# TODO: Update docstrings for optimizer.py
__all__ = [
    "load_sharded_optimizer_state_dict",
]


def _gen_rank_device(global_rank: int) -> str:
    if torch.cuda.is_available():
        return f"cuda:{global_rank % torch.cuda.device_count()}"
    return "cpu"


def _create_colwise_spec(
    pg: Optional[dist.ProcessGroup] = None,
) -> ChunkShardingSpec:
    if pg is None:
        placements = [
            f"rank:{idx}/{_gen_rank_device(idx)}"
            for idx in range(dist.get_world_size())
        ]
    else:
        placements = [
            f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx))}"
            for idx in range(pg.size())
        ]
    return ChunkShardingSpec(
        dim=0,
        placements=cast(List[Union[_remote_device, str]], placements),
    )
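

# For illustration only (not executed): with a world size of 4 on a single host with
# 4 visible GPUs, _create_colwise_spec(None) yields placements
#   ["rank:0/cuda:0", "rank:1/cuda:1", "rank:2/cuda:2", "rank:3/cuda:3"],
# i.e. dim 0 of each tensor is chunked across all ranks.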


def _is_nested_tensor(val: torch.Tensor) -> bool:
    if type(val) is ShardedTensor:
        if len(val.local_shards()) == 0:
            return False
        if type(val.local_shards()[0].tensor) is ShardedTensor:
            return True
        if type(val.local_shards()[0].tensor) is DTensor:
            raise ValueError(
                "Cannot handle DTensor nested inside ShardedTensor"
            )
    elif type(val) is DTensor and (
        type(val._local_tensor) is DTensor
        or type(val._local_tensor) is ShardedTensor
    ):
        raise ValueError("Cannot handle nested DTensor")
    return False
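

# A "nested" value above is the ShardedTensor-inside-ShardedTensor layout produced by
# 2D (FSDP + TP) sharded state dicts: the outer shard's metadata carries this rank's
# TP offset, while the inner ShardedTensor holds the FSDP-sharded data.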


def _alloc_tensor(props: TensorProperties, size: Sequence[int]) -> torch.Tensor:
    return torch.empty(
        size=size,
        dtype=props.dtype,
        layout=props.layout,
        requires_grad=props.requires_grad,
        pin_memory=props.pin_memory,
        device=cast(torch.device, torch.cuda.current_device()),
    )


def _get_state_dict_2d_layout(
    state_dict: STATE_DICT_TYPE,
) -> Tuple[STATE_DICT_2D_LAYOUT, Optional[dist.ProcessGroup]]:
    """
    We have to load the right TP slice of the optimizer state.

    This is not easy since the per-tensor slicing can't be inferred from checkpoint
    metadata. We take advantage of the model state_dict producing a sliced ST to
    figure out what we need to load. This is pretty fragile and it might be easier
    for FSDP to compute this info for us.

    Returns a dictionary whose keys are the same as those of ``state_dict`` and whose
    values are ``(offset, size)`` tuples for the current rank's TP slice.

    N.B. The state_dict *MUST* come from FSDP.sharded_state_dict.
    """
    specs: STATE_DICT_2D_LAYOUT = {}
    dp_pg: Optional[dist.ProcessGroup] = None
    for key, value in state_dict.items():
        specs[key] = (None, value.size())
        if _is_nested_tensor(value):
            assert (
                len(value.local_shards()) == 1
            ), "Cannot handle ST with multiple shards"
            assert isinstance(
                value, ShardedTensor
            ), "Can only handle nested ShardedTensor"
            shard = value.local_shards()[0]
            specs[key] = (
                shard.metadata.shard_offsets,
                shard.metadata.shard_sizes,
            )
            dp_pg = shard.tensor._process_group  # type: ignore[attr-defined]

    return (
        specs,
        dp_pg,
    )


class _ReaderWithOffset(DefaultLoadPlanner):
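    """
    A ``DefaultLoadPlanner`` that shifts each rank's reads by a per-FQN offset.

    ``fqn_to_offset`` maps a checkpoint key to this rank's TP offset. Read items are
    created against the offset position in the checkpointed tensor, and their
    destination indices are translated back so the loaded data lands in the locally
    allocated, zero-based shard.
    """
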
    translation: Dict[MetadataIndex, MetadataIndex]
    state_dict: STATE_DICT_TYPE
    metadata: Metadata

    def __init__(self, fqn_to_offset: Dict[str, Sequence[int]]) -> None:
        super().__init__()
        self.fqn_to_offset = fqn_to_offset
        self.metadata = Metadata({})
        self.state_dict = {}
        self.translation = {}

    def create_local_plan(self) -> LoadPlan:
        requests = []
        self.translation = {}
        for fqn, obj in self.state_dict.items():
            md = self.metadata.state_dict_metadata[fqn]
            if not isinstance(obj, ShardedTensor):
                requests += _create_read_items(fqn, md, obj)
                continue

            if fqn not in self.fqn_to_offset:
                requests += _create_read_items(fqn, md, obj)
                continue

            offset = self.fqn_to_offset[fqn]

            assert len(obj.local_shards()) == 1
            original_shard = obj.local_shards()[0]
            shard_md = copy.deepcopy(original_shard.metadata)
            shard_md.shard_offsets = _element_wise_add(
                shard_md.shard_offsets, offset
            )
            local_shards = [Shard(original_shard.tensor, shard_md)]

            reqs = _create_sharded_read_items(
                fqn, cast(TensorStorageMetadata, md), local_shards
            )
            # TODO: The ReadItems will have a displaced MetadataIndex, fix it.
            # TODO: we should change _create_sharded_read_items to have a more ergonomic API
            for ri in reqs:
                assert ri.dest_index.offset is not None
                original_offset = _element_wise_sub(
                    ri.dest_index.offset, offset
                )
                original_index = dataclasses.replace(
                    ri.dest_index, offset=torch.Size(original_offset)
                )
                self.translation[ri.dest_index] = original_index

            requests += reqs
        return LoadPlan(requests)

    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
        return super().lookup_tensor(self.translation.get(index, index))


def load_sharded_optimizer_state_dict(
    model_state_dict: STATE_DICT_TYPE,
    optimizer_key: str,
    storage_reader: dist_cp.StorageReader,
) -> STATE_DICT_TYPE:
    """
    Loads a state_dict to be used in conjunction with FSDP sharded optimizer state.

    This is the current recommended way to checkpoint FSDP.

    >>> # xdoctest: +SKIP
    >>> import torch.distributed.checkpoint as dist_cp
    >>> # Save
    >>> model: torch.nn.Module
    >>> optim_params = model.parameters()
    >>> optim = torch.optim.SGD(optim_params, lr=0.01)
    >>>
    >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    >>>     state_dict = {
    >>>         "optimizer": FSDP.sharded_optim_state_dict(model, optim, optim_params),
    >>>         "model": model.state_dict()
    >>>     }
    >>>     dist_cp.save_state_dict(
    >>>         state_dict=state_dict,
    >>>         storage_writer=dist_cp.FileSystemWriter("checkpoint"),
    >>>         planner=dist_cp.DefaultSavePlanner(),
    >>>     )
    >>>
    >>> # Load
    >>> with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT):
    >>>     model_state_dict = model_tp.state_dict()
    >>>     checkpoint = {
    >>>         "model": model_state_dict
    >>>     }
    >>>     dist_cp.load_state_dict(
    >>>         state_dict=checkpoint,
    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
    >>>         planner=dist_cp.DefaultLoadPlanner(),
    >>>     )
    >>>     model_tp.load_state_dict(checkpoint["model"])
    >>>
    >>>     optim_state = dist_cp.load_sharded_optimizer_state_dict(
    >>>         model_state_dict,
    >>>         optimizer_key="optimizer",
    >>>         storage_reader=dist_cp.FileSystemReader("checkpoint"),
    >>>     )
    >>>
    >>>     flattened_osd = FSDP.flatten_sharded_optim_state_dict(
    >>>         optim_state["optimizer"], model_tp, optim
    >>>     )
    >>>
    >>>     optim.load_state_dict(flattened_osd)
    """
    metadata = storage_reader.read_metadata()

    layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
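
    # dp_pg is None when the model state_dict carries no nested (TP-sharded) tensors;
    # in that case the optimizer state is simply chunked across the whole world.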
    if dp_pg is None:
        sharding_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                f"rank:{i}/cuda:{i % torch.cuda.device_count()}"
                for i in range(dist.get_world_size())
            ],
        )
    else:
        sharding_spec = _create_colwise_spec(dp_pg)

    # Create a state_dict for optimizer state
    state_dict: STATE_DICT_TYPE = {}

    fqn_to_offset: Dict[str, Sequence[int]] = {}
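    # metadata.planner_data maps each flattened checkpoint key to its path in the
    # original nested state_dict, e.g. ("optimizer", "state", "<param_fqn>", "exp_avg")
    # (an illustrative path; the exact layout comes from the save-time planner).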
    for key, value in metadata.state_dict_metadata.items():
        key_path = metadata.planner_data[key]
        if key_path[0] != optimizer_key:
            continue

        if isinstance(value, BytesStorageMetadata):
            state_dict[key] = "<bytes_io>"
            continue

        # value: TensorStorageMetadata
        if value.size.numel() == 1:
            state_dict[key] = _alloc_tensor(value.properties, value.size)
        elif dp_pg is None:
            state_dict[key] = _shard_tensor(
                _alloc_tensor(value.properties, value.size), sharding_spec
            )
        else:
            spec_key = key_path[2]
            alloc_size = layout_specs.get(spec_key, (None, value.size))[1]

            st_md = sharding_spec.build_metadata(
                torch.Size(alloc_size), value.properties
            )
            local_shards = []
            current_rank = dist.get_rank(dp_pg)
            for shard_md in st_md.shards_metadata:
                if (
                    cast(_remote_device, shard_md.placement).rank()
                    != current_rank
                ):
                    continue
                local_shards.append(
                    Shard(
                        tensor=_alloc_tensor(
                            value.properties, shard_md.shard_sizes
                        ),
                        metadata=shard_md,
                    )
                )

            st = ShardedTensor._init_from_local_shards_and_global_metadata(
                local_shards, st_md, process_group=dp_pg
            )

            if (
                spec_key in layout_specs
                and layout_specs[spec_key][0] is not None
            ):
                fqn_to_offset[key] = cast(
                    Sequence[int], layout_specs[spec_key][0]
                )

            state_dict[key] = st
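
    # In the 2D (FSDP + TP) case, _ReaderWithOffset shifts each rank's reads by its
    # TP offset so the correct slice of the checkpointed optimizer state is loaded.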
    # Whether we unflatten before or after doesn't matter
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=storage_reader,
        # FIXME the type of planner is wrong in load_state_dict
        planner=_ReaderWithOffset(fqn_to_offset) if dp_pg is not None else None,
    )

    state_dict = unflatten_state_dict(state_dict, metadata.planner_data)

    return state_dict