New APIs
xrsrke opened this issue
XλRI-U5 commented
- from nanotron.core.parallelism.tensor_parallelism.nn import (
- TensorParallelColumnLinear,
- TensorParallelEmbedding,
- TensorParallelLinearMode,
- TensorParallelRowLinear,
- )
- from nanotron.core.optimizer.zero import ZeroDistributedOptimizer
- from nanotron.dataloaders.nemo import get_nemo_dataloader
+ from nanotron.distributed import ParallelContext, ParallelMode
+ from nanotron.nn.tensor_parallel import ColumnParallelLinear, RowParallelLinear, ParallelEmbedding, ParallelCrossEntropy, TensorParallelLinearMode
+ from nanotron.nn.pipeline_parallel import PipelineBlock, PipelineParallel
+ from nanotron.nn.data_parallel import DataParallel
+ from nanotron.optim import ZeroDistributedOptimizer
+ from nanotron.utils.data import DistributedDataLoader
- dpg = get_process_groups(
- data_parallel_size=self.config.parallelism.dp,
- pipeline_parallel_size=self.config.parallelism.pp,
- tensor_parallel_size=self.config.parallelism.tp,
- )
+ parallel_context = ParallelContext.from_torch(
+ tensor_parallel_size=2,
+ pipeline_parallel_size=4,
+ data_parallel_size=2
+ )
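For reference, a minimal sketch of how the resulting context would be queried. Only `is_first_rank` and `get_local_rank` with `ParallelMode.PIPELINE` appear later in this proposal; the `ParallelMode.TENSOR` member below is an assumption extrapolated from that pattern:

# hypothetical queries against the proposed ParallelContext
pp_rank = parallel_context.get_local_rank(ParallelMode.PIPELINE)  # used in the training loop below
tp_rank = parallel_context.get_local_rank(ParallelMode.TENSOR)    # assumed member, by analogy
if parallel_context.is_first_rank(ParallelMode.PIPELINE):
    print("this rank owns the first pipeline stage")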
class LlamaModel(nn.Module):
def __init__(
self,
config: LlamaConfig,
- dpg: DistributedProcessGroups,
parallel_config: Optional[ParallelismArgs],
+ parallel_context: ParallelContext
):
super().__init__()
# Declare all the nodes
- self.p2p = P2P(dpg.pp_pg, device=torch.device("cuda"))
self.config = config
self.parallel_config = parallel_config
- self.dpg = dpg
self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
tp_linear_async_communication = (
parallel_config.tp_linear_async_communication if parallel_config is not None else False
)
- self.token_position_embeddings = PipelineBlock(
- p2p=self.p2p,
- module_builder=Embedding,
- module_kwargs={
- "tp_pg": dpg.tp_pg,
- "config": config,
- "parallel_config": parallel_config,
- },
- module_input_keys={"input_ids", "input_mask"},
- module_output_keys={"input_embeds"},
- )
+ token_position_embeddings = Embedding(config, parallel_config, parallel_context)
+ self.token_position_embeddings = PipelineBlock(
+ token_position_embeddings, parallel_context,
+ input_keys={"input_ids", "input_mask"}, output_keys={"input_embeds"}
+ )
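To make the new `PipelineBlock` contract explicit, here is an illustrative sketch (not the real implementation) of what this proposal seems to assume: the block wraps an eagerly-built module and routes a dict of named tensors in and out, with a real version also sending/receiving activations between pipeline ranks:

import torch.nn as nn

class PipelineBlockSketch(nn.Module):
    """Illustration only: filter named inputs, run the wrapped module, filter named outputs."""
    def __init__(self, module, parallel_context, input_keys, output_keys):
        super().__init__()
        self.module = module  # built eagerly, unlike the old module_builder/module_kwargs API
        self.parallel_context = parallel_context
        self.input_keys, self.output_keys = input_keys, output_keys

    def forward(self, **kwargs):
        # a real implementation would p2p-communicate activations across pipeline ranks here
        outputs = self.module(**{k: kwargs[k] for k in self.input_keys})
        return {k: outputs[k] for k in self.output_keys}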
- self.decoder = nn.ModuleList(
- [
- PipelineBlock(
- p2p=self.p2p,
- module_builder=LlamaDecoderLayer,
- module_kwargs={
- "config": config,
- "parallel_config": parallel_config,
- "tp_pg": dpg.tp_pg,
- "layer_idx": layer_idx,
- },
- module_input_keys={"hidden_states", "sequence_mask"},
- module_output_keys={"hidden_states", "sequence_mask"},
- )
- for layer_idx in range(config.num_hidden_layers)
- ]
- )
+ # the user specifies how many transformer blocks a rank holds (since this is quite simple)
+ num_local_pipeline_stages = ...
+ decoder = nn.ModuleList([LlamaDecoderLayer(config, layer_idx, parallel_config, parallel_context) for layer_idx in range(num_local_pipeline_stages)])
+ self.decoder = PipelineBlock(
+     decoder, parallel_context,
+     input_keys={"hidden_states", "sequence_mask"},
+     output_keys={"hidden_states", "sequence_mask"}
+ )
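For a uniform split, `num_local_pipeline_stages` could simply be derived from the config; a sketch, assuming `num_hidden_layers` divides evenly by the pipeline size:

# one possible choice: spread decoder layers evenly across pipeline ranks
num_local_pipeline_stages = config.num_hidden_layers // parallel_context.pipeline_parallel_size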
- self.final_layer_norm = PipelineBlock(
- p2p=self.p2p,
- module_builder=RMSNorm,
- module_kwargs={"normalized_shape": config.hidden_size, "eps": config.rms_norm_eps},
- module_input_keys={"input"},
- module_output_keys={"hidden_states"},
- )
+ final_layer_norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+ self.final_layer_norm = PipelineBlock(
+ final_layer_norm, parallel_context,
+ input_keys={"input"}, output_keys={"hidden_states"}
+ )
- self.lm_head = PipelineBlock(
- p2p=self.p2p,
- # Understand that this means that we return sharded logits that are going to need to be gathered
- module_builder=TensorParallelColumnLinear,
- module_kwargs={
- "in_features": config.hidden_size,
- "out_features": config.vocab_size,
- "pg": dpg.tp_pg,
- "bias": False,
- "mode": self.tp_mode,
- "async_communication": tp_linear_async_communication,
- },
- module_input_keys={"x"},
- module_output_keys={"logits"},
- )
+ lm_head = ColumnParallelLinear(
+     config.hidden_size, config.vocab_size, parallel_context,
+     bias=False, mode=self.tp_mode,
+     async_communication=tp_linear_async_communication
+ )
+ self.lm_head = PipelineBlock(
+ lm_head, parallel_context,
+ input_keys={"x"}, output_keys={"logits"}
+ )
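The column-parallel head keeps the vocabulary dimension sharded, so each tensor-parallel rank only holds logits of shape [batch, seq, vocab/tp]. The proposal feeds these shards straight into `ParallelCrossEntropy` (see the training loop below); if full logits were ever needed, a gather along the vocab dimension would look roughly like this (hypothetical helper, built on standard torch.distributed calls):

import torch
import torch.distributed as dist

def gather_logits(sharded_logits: torch.Tensor, tp_group) -> torch.Tensor:
    # concatenate the vocab shards held by each tensor-parallel rank
    shards = [torch.empty_like(sharded_logits) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(shards, sharded_logits.contiguous(), group=tp_group)
    return torch.cat(shards, dim=-1)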
- self.cast_to_fp32 = PipelineBlock(
- p2p=self.p2p,
- module_builder=lambda: lambda x: x.float(),
- module_kwargs={},
- module_input_keys={"x"},
- module_output_keys={"output"},
- )
def forward(
self,
- input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
- input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
+ input_ids: torch.Tensor, # [batch_size, seq_length]
+ input_mask: torch.Tensor, # [batch_size, seq_length]
):
# all tensors are optional as most ranks don't need anything from the dataloader.
output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)
hidden_encoder_states = {
"hidden_states": output["input_embeds"],
"sequence_mask": input_mask,
}
- for encoder_block in self.decoder:
-     hidden_encoder_states = encoder_block(**hidden_encoder_states)
+ hidden_encoder_states = self.decoder(**hidden_encoder_states)
hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
sharded_logits = self.lm_head(x=hidden_states)["logits"]
- fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]
+ fp32_sharded_logits = sharded_logits.float()
return fp32_sharded_logits, hidden_states
-class LlamaForTraining(BRRRModel):
- def __init__(
- self,
- config: LlamaConfig,
- dpg: DistributedProcessGroups,
- parallel_config: Optional[ParallelismArgs],
- random_states: Optional[RandomStates] = None,
- ):
- super().__init__()
- self.model = LlamaModel(config=config, dpg=dpg, parallel_config=parallel_config)
- self.loss = PipelineBlock(
- p2p=self.model.p2p,
- module_builder=Loss,
- module_kwargs={"tp_pg": dpg.tp_pg},
- module_input_keys={
- "sharded_logits",
- "label_ids",
- "label_mask",
- },
- module_output_keys={"loss"},
- )
- self.dpg = dpg
- self.config = config
- self.parallel_config = parallel_config
- def forward(
- self,
- input_ids: Union[torch.Tensor, TensorPointer],
- input_mask: Union[torch.Tensor, TensorPointer],
- label_ids: Union[torch.Tensor, TensorPointer],
- label_mask: Union[torch.Tensor, TensorPointer],
- ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
- sharded_logits = self.model(
- input_ids=input_ids,
- input_mask=input_mask,
- )
- loss = self.loss(
- sharded_logits=sharded_logits,
- label_ids=label_ids,
- label_mask=label_mask,
- )["loss"]
- return {"loss": loss}
- model = init_model(
- model_builder=lambda: LlamaForTraining(config=model_config, dpg=dpg, parallel_config=parallel_config),
- model_config=model_config,
- parallel_config=parallel_config,
- dtype=dtype,
- dpg=dpg,
- make_ddp=False,
- )
+ model = LlamaModel(config, parallel_config, parallel_context)
+ # we eliminate `sync_gradients_across_dp`; `DataParallel` automatically registers backward hooks
+ model = DataParallel(model, parallel_context)
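A minimal sketch of what "automatically registers backward hooks" could mean in practice; the real `DataParallel` would presumably bucket gradients, and how the data-parallel group is obtained here is an assumption:

import torch.distributed as dist

def register_grad_allreduce_hooks(model, dp_group):
    # illustration: average each gradient across data-parallel ranks as soon as it is produced
    world_size = dist.get_world_size(dp_group)
    for param in model.parameters():
        if param.requires_grad:
            def hook(grad, world_size=world_size, dp_group=dp_group):
                dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=dp_group)
                return grad / world_size
            param.register_hook(hook)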
- outputs = pipeline_engine.train_batch_iter(
- model=model,
- pg=dpg.pp_pg,
- batch=(next(data_iterator) for _ in range(n_micro_batches_per_batch)),
- nb_microbatches=n_micro_batches_per_batch,
- grad_accumulator=grad_accumulator,
-)
+ model = PipelineParallel(model, num_microbatches, parallel_context)
- optimizer, grad_accumulator = init_optimizer_and_grad_accumulator(
- model=model, optimizer_args=optimizer_args, dpg=dpg
- )
+ named_parameters = ...
+ optimizer = ZeroDistributedOptimizer(named_parameters, parallel_context)
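The ZeRO-1 idea behind `ZeroDistributedOptimizer`, as a sketch: each data-parallel rank keeps optimizer state only for its own slice of the parameters, steps that slice, and broadcasts the updated values back to the other ranks. The round-robin assignment below is purely illustrative:

def partition_parameters(named_parameters, dp_rank, dp_world_size):
    # illustration: rank i keeps optimizer state for every i-th parameter
    owned, remote = [], []
    for i, (name, param) in enumerate(named_parameters):
        (owned if i % dp_world_size == dp_rank else remote).append((name, param))
    return owned, remote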
- dataloader = get_nemo_dataloader(
- dataset=train_dataset,
- sequence_length=sequence_length,
- micro_batch_size=micro_batch_size,
- global_batch_size=global_batch_size,
- num_workers=config.data.num_loading_workers,
- cfg=config.data.dataset,
- consumed_samples=consumed_train_samples,
- dpg=dpg,
- input_pp_rank=input_pp_rank,
- output_pp_rank=output_pp_rank,
- dataloader_drop_last=True
- )
+ # assume that only the first pipeline stage loads data;
+ # subsequent pipeline stages only receive activations
+ if parallel_context.is_first_rank(ParallelMode.PIPELINE):
+     dataloader = DistributedDataLoader(
+         dataset, sequence_length, microbatch_size, global_batch_size,
+         num_workers, consumed_samples, dataloader_drop_last,
+         parallel_context
+     )
+ for _ in range(epochs):
+     for batch in dataloader:
+         outputs = model(**batch)  # batch supplies input_ids and input_mask
+         optimizer.zero_grad()
+
+         # assume that only the last pipeline stage has the loss
+         if parallel_context.get_local_rank(ParallelMode.PIPELINE) == parallel_context.pipeline_parallel_size - 1:
+             loss = ParallelCrossEntropy(parallel_context)(outputs["logits"], targets)  # consumes the sharded logits
+             loss.backward()
+
+         optimizer.step()
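Assuming one process per GPU, the tp=2 x pp=4 x dp=2 context above needs 2 x 4 x 2 = 16 ranks, so a typical launch (the script name is a placeholder) would be:

torchrun --nproc_per_node=16 train.py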