huggingface / nanotron

Minimalistic large language model 3D-parallelism training


New APIs

xrsrke opened this issue

- from nanotron.core.parallelism.tensor_parallelism.nn import (
-     TensorParallelColumnLinear,
-     TensorParallelEmbedding,
-     TensorParallelLinearMode,
-     TensorParallelRowLinear,
- )
- from nanotron.core.optimizer.zero import ZeroDistributedOptimizer
- from nanotron.dataloaders.nemo import get_nemo_dataloader


+ from nanotron.distributed import ParallelContext, ParallelMode
+ from nanotron.nn.tensor_parallel import ColumnParallelLinear, RowParallelLinear, ParallelEmbedding, ParallelCrossEntropy
+ from nanotron.nn.pipeline_parallel import PipelineBlock, PipelineParallel
+ from nanotron.nn.data_parallel import DataParallel
+ from nanotron.optim import ZeroDistributedOptimizer
+ from nanotron.utils.data import DistributedDataLoader

- dpg = get_process_groups(
-    data_parallel_size=self.config.parallelism.dp,
-    pipeline_parallel_size=self.config.parallelism.pp,
-    tensor_parallel_size=self.config.parallelism.tp,
- )

+ parallel_context = ParallelContext.from_torch(
+     tensor_parallel_size=2,
+     pipeline_parallel_size=4,
+     data_parallel_size=2
+ )
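For reference, the mesh above uses 2 * 4 * 2 = 16 ranks in total. A minimal sketch of how the context might then be queried; `get_local_rank` and `is_first_rank` are taken from the snippets further down, while `ParallelMode.TENSOR` and `ParallelMode.DATA` are my assumption by analogy with `ParallelMode.PIPELINE`:

# sketch only: per-rank coordinates in the 3D mesh
tp_rank = parallel_context.get_local_rank(ParallelMode.TENSOR)    # assumed mode name
pp_rank = parallel_context.get_local_rank(ParallelMode.PIPELINE)
dp_rank = parallel_context.get_local_rank(ParallelMode.DATA)      # assumed mode name
if parallel_context.is_first_rank(ParallelMode.PIPELINE):
    pass  # e.g. only the first pipeline stage builds the dataloader (see below)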

class LlamaModel(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
-       dpg: DistributedProcessGroups,
       parallel_config: Optional[ParallelismArgs],
+       parallel_context: ParallelContext
    ):
        super().__init__()

        # Declare all the nodes
-       self.p2p = P2P(dpg.pp_pg, device=torch.device("cuda"))
        self.config = config
       self.parallel_config = parallel_config
-       self.dpg = dpg
        self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
        tp_linear_async_communication = (
            parallel_config.tp_linear_async_communication if parallel_config is not None else False
        )

-        self.token_position_embeddings = PipelineBlock(
-            p2p=self.p2p,
-            module_builder=Embedding,
-            module_kwargs={
-                "tp_pg": dpg.tp_pg,
-                "config": config,
-                "parallel_config": parallel_config,
-            },
-            module_input_keys={"input_ids", "input_mask"},
-            module_output_keys={"input_embeds"},
-        )
+        token_position_embeddings = Embedding(config, parallel_config, parallel_context)
+        self.token_position_embeddings = PipelineBlock(
+            token_position_embeddings, parallel_context,
+            input_keys={"input_ids", "input_mask"}, output_keys={"input_embeds"}
+        )


-         self.decoder = nn.ModuleList(
-             [
-                 PipelineBlock(
-                     p2p=self.p2p,
-                     module_builder=LlamaDecoderLayer,
-                     module_kwargs={
-                         "config": config,
-                         "parallel_config": parallel_config,
-                         "tp_pg": dpg.tp_pg,
-                         "layer_idx": layer_idx,
-                     },
-                     module_input_keys={"hidden_states", "sequence_mask"},
-                     module_output_keys={"hidden_states", "sequence_mask"},
-                 )
-                 for layer_idx in range(config.num_hidden_layers)
-             ]
-         )
+        # the user specifies how many transformer blocks this rank holds (since this is quite simple)
+        num_local_pipeline_stages = ....
+        decoder = nn.ModuleList(
+            [LlamaDecoderLayer(config, layer_idx, parallel_config, parallel_context) for layer_idx in range(num_local_pipeline_stages)]
+        )
+        self.decoder = PipelineBlock(
+            decoder, parallel_context,
+            input_keys={"hidden_states", "sequence_mask"},
+            output_keys={"hidden_states", "sequence_mask"}
+        )


-         self.final_layer_norm = PipelineBlock(
-             p2p=self.p2p,
-             module_builder=RMSNorm,
-             module_kwargs={"normalized_shape": config.hidden_size, "eps": config.rms_norm_eps},
-             module_input_keys={"input"},
-             module_output_keys={"hidden_states"},
-         )
        
+        final_layer_norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+        self.final_layer_norm = PipelineBlock(
+           final_layer_norm, parallel_context,
+            input_keys={"input"}, output_keys={"hidden_states"}
+        )

-         self.lm_head = PipelineBlock(
-             p2p=self.p2p,
-             # Understand that this means that we return sharded logits that are going to need to be gathered
-             module_builder=TensorParallelColumnLinear,
-             module_kwargs={
-                 "in_features": config.hidden_size,
-                 "out_features": config.vocab_size,
-                 "pg": dpg.tp_pg,
-                 "bias": False,
-                 "mode": self.tp_mode,
-                 "async_communication": tp_linear_async_communication,
-             },
-             module_input_keys={"x"},
-             module_output_keys={"logits"},
-         )
+         lm_head = ColumnParallelLinear(
+              config.hidden_size, config.vocab_size,
+              bias=False, mode=self.tp_mode,
+              async_communication=tp_linear_async_communication
+         )
+         self.lm_head = PipelineBlock(
+             lm_head, parallel_context,
+             input_keys={"x"}, output_keys={"logits"}
+         )

-        self.cast_to_fp32 = PipelineBlock(
-            p2p=self.p2p,
-            module_builder=lambda: lambda x: x.float(),
-            module_kwargs={},
-            module_input_keys={"x"},
-            module_output_keys={"output"},
-        )

    def forward(
        self,
-        input_ids: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
-        input_mask: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
+         input_ids: torch.Tensor, # [batch_size, seq_length]
+         input_mask: torch.Tensor, # [batch_size, seq_length]
    ):
        # all tensors are optional as most ranks don't need anything from the dataloader.

        output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)

        hidden_encoder_states = {
            "hidden_states": output["input_embeds"],
            "sequence_mask": input_mask,
        }
        for encoder_block in self.decoder:
            hidden_encoder_states = encoder_block(**hidden_encoder_states)

        hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]

        sharded_logits = self.lm_head(x=hidden_states)["logits"]

        fp32_sharded_logits = sharded_logits.float()  # cast the sharded logits to fp32 directly

        return fp32_sharded_logits, hidden_states
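As a reading aid, here is a rough, stage-local stub of the contract the proposed `PipelineBlock` would have to satisfy, inferred only from how it is called above (keyword arguments named by `input_keys`, a dict result keyed by `output_keys`). The actual cross-stage p2p communication and `TensorPointer` routing are deliberately omitted, so treat this as an illustration rather than the intended implementation:

import torch
from torch import nn

class _PipelineBlockSketch(nn.Module):
    """Illustrative stub only: matches the call pattern used in LlamaModel above."""

    def __init__(self, module, parallel_context, input_keys, output_keys):
        super().__init__()
        self.module = module
        self.parallel_context = parallel_context
        self.input_keys = set(input_keys)
        self.output_keys = tuple(output_keys)

    def forward(self, **kwargs):
        # callers must pass exactly the declared inputs (e.g. input_ids / input_mask)
        assert set(kwargs) == self.input_keys, f"expected {self.input_keys}, got {set(kwargs)}"
        out = self.module(**kwargs)
        # normalize the wrapped module's output into a dict keyed by output_keys
        if isinstance(out, dict):
            return out
        if torch.is_tensor(out):
            out = (out,)
        return dict(zip(self.output_keys, out))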


-class LlamaForTraining(BRRRModel):
-    def __init__(
-        self,
-        config: LlamaConfig,
-        dpg: DistributedProcessGroups,
-        parallel_config: Optional[ParallelismArgs],
-        random_states: Optional[RandomStates] = None,
-    ):
-        super().__init__()
-        self.model = LlamaModel(config=config, dpg=dpg, parallel_config=parallel_config)
-        self.loss = PipelineBlock(
-            p2p=self.model.p2p,
-            module_builder=Loss,
-            module_kwargs={"tp_pg": dpg.tp_pg},
-            module_input_keys={
-                "sharded_logits",
-                "label_ids",
-                "label_mask",
-            },
-            module_output_keys={"loss"},
-        )
-        self.dpg = dpg
-        self.config = config
-        self.parallel_config = parallel_config

-    def forward(
-        self,
-        input_ids: Union[torch.Tensor, TensorPointer],
-        input_mask: Union[torch.Tensor, TensorPointer],
-        label_ids: Union[torch.Tensor, TensorPointer],
-        label_mask: Union[torch.Tensor, TensorPointer],
-    ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
-        sharded_logits = self.model(
-            input_ids=input_ids,
-            input_mask=input_mask,
-        )
-        loss = self.loss(
-            sharded_logits=sharded_logits,
-            label_ids=label_ids,
-            label_mask=label_mask,
-        )["loss"]
-        return {"loss": loss}


- model = init_model(
-    model_builder=lambda: LlamaForTraining(config=model_config, dpg=dpg, parallel_config=parallel_config),
-    model_config=model_config,
-    parallel_config=parallel_config,
-    dtype=dtype,
-    dpg=dpg,
-    make_ddp=False,
-    )

+ model = LlamaModel(config, parallel_config, parallel_context)
# we eliminate `sync_gradients_across_dp`; `DataParallel` automatically registers backward hooks
+ model = DataParallel(model, parallel_context)
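To make the comment above concrete, here is the standard hook-based idea as a sketch (my own illustration, not nanotron code): every parameter gets a gradient hook that averages its gradient over the data-parallel process group. `dp_group` and `dp_world_size` are assumed to come from the `ParallelContext`.

import torch
import torch.distributed as dist

def register_dp_grad_hooks(model: torch.nn.Module, dp_group, dp_world_size: int):
    """Sketch: average each parameter's gradient over the data-parallel group.

    `dp_group` / `dp_world_size` are assumed accessors on the ParallelContext;
    a real implementation would also bucket gradients and overlap the
    all-reduce with the backward pass.
    """
    def _average(grad):
        grad = grad / dp_world_size   # pre-divide so the all-reduce sums to the mean
        dist.all_reduce(grad, group=dp_group)
        return grad

    for param in model.parameters():
        if param.requires_grad:
            param.register_hook(_average)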

- outputs = pipeline_engine.train_batch_iter(
-    model=model,
-    pg=dpg.pp_pg,
-    batch=(next(data_iterator) for _ in range(n_micro_batches_per_batch)),
-    nb_microbatches=n_micro_batches_per_batch,
-    grad_accumulator=grad_accumulator,
-)

+ model = PipelineParallel(model, num_microbatches, parallel_context)

- optimizer, grad_accumulator = init_optimizer_and_grad_accumulator(
-   model=model, optimizer_args=optimizer_args, dpg=dpg
- )

+ named_parameters = ...
+ optimizer = ZeroDistributedOptimizer(named_parameters, parallel_context)
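The `...` above is left open on purpose; one natural reading (an assumption, not something this proposal pins down) is that it is simply the model's named parameters, with the optimizer then sharding its states across the data-parallel ranks, ZeRO-1 style:

# assumed: the ZeRO optimizer consumes standard PyTorch named parameters
named_parameters = list(model.named_parameters())
optimizer = ZeroDistributedOptimizer(named_parameters, parallel_context)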

- dataloader = get_nemo_dataloader(
-     dataset=train_dataset,
-     sequence_length=sequence_length,
-     micro_batch_size=micro_batch_size,
-     global_batch_size=global_batch_size,
-     num_workers=config.data.num_loading_workers,
-     cfg=config.data.dataset,
-     consumed_samples=consumed_train_samples,
-     dpg=dpg,
-     input_pp_rank=input_pp_rank,
-     output_pp_rank=output_pp_rank,
-     dataloader_drop_last=True
- )

# assume that only the first pipeline stage loads data;
# subsequent pipeline stages only receive activations
+ if parallel_context.is_first_rank(ParallelMode.PIPELINE):
+     dataloader = DistributedDataLoader(
+         dataset, sequence_length, microbatch_size, global_batch_size,
+         num_workers, consumed_samples, dataloader_drop_last,
+         parallel_context
+     )
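One relation that is implicit in these arguments (standard for this kind of loader, though not spelled out here): the global batch size equals the micro-batch size times the number of micro-batches per step times the data-parallel size. `data_parallel_size` below is assumed by analogy with the `pipeline_parallel_size` attribute used in the loop that follows:

# sanity check under the usual definition of global batch size
assert global_batch_size == microbatch_size * num_microbatches * parallel_context.data_parallel_size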

+ for _ in range(epochs):
+     for batch in dataloader:
+         outputs = model(batch)
+
+         # assume that only the last pipeline stage has the loss
+         if parallel_context.get_local_rank(ParallelMode.PIPELINE) == parallel_context.pipeline_parallel_size - 1:
+             loss = ParallelCrossEntropy(outputs["logits"], targets)  # the logits here are sharded
+
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()