marian-nmt / marian-dev

Fast Neural Machine Translation in C++ - development repository

Home Page: https://marian-nmt.github.io

Crash during training/inference with fp16 & factors-combine=concat

arturnn opened this issue

Bug description

Marian crashes during training/inference if --fp16 and --factors-combine=concat are provided at the same time.

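The problem comes from the embedWithConcat method. It builds the factor-index constant with an explicit Type::float32, while the other operand of the dot call is float16 when --fp16 is set, so the operand types no longer match. Removing the explicit Type::float32 argument from the graph->constant call in layers/embedding.cpp:61 seems to fix the issue. Would that be a sufficient fix, or could it break something in the long term?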

/*private*/ Expr Embedding::embedWithConcat(const Words& data) const {
  auto graph = E_->graph();
  std::vector<IndexType> lemmaIndices;
  std::vector<float> factorIndices;
  factoredVocab_->lemmaAndFactorsIndexes(data, lemmaIndices, factorIndices);
  auto lemmaEmbs = rows(E_, lemmaIndices);
  int dimFactors = FactorEmbMatrix_->shape()[0];
  auto factEmbs
      = dot(graph->constant(
                {(int)data.size(), dimFactors}, inits::fromVector(factorIndices), Type::float32),
            FactorEmbMatrix_);

  return concatenate({lemmaEmbs, factEmbs}, -1);
}
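To make the mismatch easier to see outside Marian, here is a minimal standalone sketch. It does not use Marian's API: Type, Node, Graph, commonType and dotTypeChecks are hypothetical stand-ins that only mimic the idea of the type check reported in the error log. It also assumes that the graph's default element type is float16 under --fp16, and that a constant created without an explicit type takes that default.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins for Marian's element types and graph nodes.
enum class Type { float16, float32 };

struct Node {
  Type type;
};

// Mimics the idea behind the commonType check: all children must share one type.
Type commonType(const std::vector<Node>& children) {
  if(children.empty())
    throw std::invalid_argument("no children");
  Type first = children[0].type;
  for(std::size_t i = 1; i < children.size(); ++i) {
    if(children[i].type != first)
      throw std::runtime_error("Child " + std::to_string(i) + " has different type");
  }
  return first;
}

// Hypothetical graph; defaultType is assumed to follow the --fp16 setting.
struct Graph {
  Type defaultType;
  Node constant(Type t) const { return Node{t}; }      // explicit element type
  Node constant() const { return Node{defaultType}; }  // falls back to the graph default
};

// True if a dot-like binary op over a and b would pass the type check.
bool dotTypeChecks(const Node& a, const Node& b) {
  try {
    commonType({a, b});
    return true;
  } catch(const std::runtime_error&) {
    return false;
  }
}

// Small self-check covering the failing and the proposed variant.
void selfCheck() {
  Graph fp16Graph{Type::float16};
  Node factorEmbMatrix = fp16Graph.constant();  // parameter stored as float16

  // Explicit float32 constant against a float16 matrix: the check fails.
  assert(!dotTypeChecks(fp16Graph.constant(Type::float32), factorEmbMatrix));

  // Constant created with the graph default: both operands are float16.
  assert(dotTypeChecks(fp16Graph.constant(), factorEmbMatrix));
}
```

In this model, dropping the explicit type lets the constant follow whatever precision the graph runs in, so the same code also passes the check on a float32 graph. Whether the factor indices stay exactly representable once they are stored as float16 is a separate question that this sketch does not address.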

How to reproduce

Train or decode a factored model with both the --factors-combine=concat and --fp16 options provided.

Context

  • Marian version: v1.11.3 b8bf086 2022-02-11 06:04:38 -0800

[2022-02-11 21:58:10] Error: Child 1 has different type (first: float32 != child: float16)
[2022-02-11 21:58:10] Error: Aborted from static marian::Type marian::NaryNodeOp::commonType(const std::vector<IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase> > > >&) in /home/anowakowski/MTExperiments/marian/tools/marian-dev/src/graph/node.h:207
[CALL STACK]
[0x55685e45c1f0] marian::NaryNodeOp:: commonType (std::vector<IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>,std::allocator<IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>>> const&) + 0x2b0
[0x55685e46377f] IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>> marian:: Expression <marian::DotNodeOp,IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>&,IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>&,bool&,bool&,float&>(IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>&, IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>&, bool&, bool&, float&) + 0x12f
[0x55685e39b635] marian:: dot (IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>, IntrusivePtr<marian::Chainable<IntrusivePtr<marian::TensorBase>>>, bool, bool, float) + 0x3a5
[0x55685e7b3206] marian::Embedding:: embedWithConcat (std::vector<marian::Word,std::allocator<marian::Word>> const&) const + 0x266

Hi Artur, thanks for reporting this. I think your solution should work. Would you mind opening a PR?