pytorch / pytorch.github.io

The website for PyTorch

Home Page: https://pytorch.org

Wrong Code in the FSDP blog

Luosuu opened this issue

📚 Documentation

In the blog post introducing the FSDP API, the example reads:

fsdp_model = FullyShardedDataParallel(
   model(),
   fsdp_auto_wrap_policy=default_auto_wrap_policy,
   cpu_offload=CPUOffload(offload_params=True),
)

It should be model instead of model() inside FullyShardedDataParallel, so the snippet should read:

fsdp_model = FullyShardedDataParallel(
   model,
   fsdp_auto_wrap_policy=default_auto_wrap_policy,
   cpu_offload=CPUOffload(offload_params=True),
)

It looks a little confusing and could probably be written more clearly, but I think it's actually correct. In the blog's snippet, model already refers to the DDP-wrapped model, so passing just model would try to FSDP-wrap that DDP model. Using model() instead FSDP-wraps a new model instance.
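To make the distinction in the reply concrete without a distributed setup, here is a minimal pure-Python sketch. ToyModel, Wrapper, and the wrapped attribute are hypothetical stand-ins (Wrapper plays the role of DistributedDataParallel or FullyShardedDataParallel and simply stores the module it receives); it is not the FSDP API. It only shows what ends up wrapped when the argument is an already-wrapped object versus a freshly constructed instance, and how distinct names remove the ambiguity.

```python
class ToyModel:
    """Hypothetical stand-in for an nn.Module subclass."""

    def __init__(self):
        self.weights = [0.0, 0.0]


class Wrapper:
    """Hypothetical stand-in for a wrapper such as DDP or FSDP:
    it only keeps a reference to the module it was given."""

    def __init__(self, module):
        self.wrapped = module


# Distinct names avoid reusing one name for the class and the wrapped instance.
ddp_model = Wrapper(ToyModel())

# Option A: pass the already-wrapped object, giving a wrapper nested in a wrapper.
nested = Wrapper(ddp_model)

# Option B: construct a fresh instance, giving an independent, unwrapped module.
fresh = Wrapper(ToyModel())

print(type(nested.wrapped).__name__)        # Wrapper
print(type(fresh.wrapped).__name__)         # ToyModel
print(fresh.wrapped is ddp_model.wrapped)   # False
```

With separate names such as ToyModel and ddp_model, it is clear at a glance which object each wrapper receives, which is the kind of clarity the reply suggests the blog snippet could use.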