pytorch / glow

Compiler for Neural Network hardware accelerators

Model compilation miss some weights

jakubmiernik opened this issue

A few days ago I tried to compile the model we have been working on. It turned out to give incorrect outputs, and after some investigation I realized that some layers' weights are missing from the generated bundle.

After recompiling with the NXP version of Glow it started working, so the problem is probably either in the latest version or in my build of it.

Reproduction:

  1. Export the model to ONNX:
import torch
import torch.nn as nn
import numpy as np

class test_CNN(nn.Module):
    def __init__(self, drop=0.25):
        super(test_CNN, self).__init__()
        self.conv1 = nn.Conv2d(2, 16, (3, 3), padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = torch.nn.ReLU()
        self.dropout = nn.Dropout(drop)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.bn1(x)
        x = self.dropout(x)
        x = x.reshape([-1, 1024])
        return x

model = test_CNN()
x = torch.rand(1, 2, 8, 8, requires_grad=True)

onnx_path = "./test_nn.onnx"

torch.onnx.export(model,
                  x,
                  onnx_path,
                  export_params=True,
                  opset_version=10,
                  do_constant_folding=True,
                  input_names=['input'],
                  output_names=['output'])
  2. Compile:
    model-compiler --verbose-compilation -backend=CPU -model=<path_to_model>/test_nn.onnx -emit-bundle=build_test -model-input="input,float,[1,2,8,8]" --bundle-api-verbose

  3. Output from the latest Glow:

// Bundle API auto-generated header file. Do not edit!
// Glow Tools version: 2021-11-16 (3563865a2) ()

#ifndef _GLOW_BUNDLE_TEST_NN_H
#define _GLOW_BUNDLE_TEST_NN_H

#include <stdint.h>

// ---------------------------------------------------------------
//                       Common definitions
// ---------------------------------------------------------------
#ifndef _GLOW_BUNDLE_COMMON_DEFS
#define _GLOW_BUNDLE_COMMON_DEFS

// Glow bundle error code for correct execution.
#define GLOW_SUCCESS 0

// Memory alignment definition with given alignment size
// for static allocation of memory.
#define GLOW_MEM_ALIGN(size)  __attribute__((aligned(size)))

// Macro function to get the absolute address of a
// placeholder using the base address of the mutable
// weight buffer and placeholder offset definition.
#define GLOW_GET_ADDR(mutableBaseAddr, placeholderOff)  (((uint8_t*)(mutableBaseAddr)) + placeholderOff)

#endif

// ---------------------------------------------------------------
//                          Bundle API
// ---------------------------------------------------------------
// Model name: "test_nn"
// Total data size: 10432 (bytes)
// Activations allocation efficiency: 1.0000
// Placeholders:
//
//   Name: "input"
//   Type: float<1 x 2 x 8 x 8>
//   Size: 128 (elements)
//   Size: 512 (bytes)
//   Offset: 0 (bytes)
//
//   Name: "output"
//   Type: float<1 x 1024>
//   Size: 1024 (elements)
//   Size: 4096 (bytes)
//   Offset: 512 (bytes)
//
// Constants:
//
//   Name: "conv1_bias"
//   Type: float<16>
//   Size: 16 (elements)
//   Size: 64 (bytes)
//   Offset: 0 (bytes)
//
//   Name: "conv1_weight__1"
//   Type: float<16 x 3 x 3 x 2>
//   Size: 288 (elements)
//   Size: 1152 (bytes)
//   Offset: 64 (bytes)
//
// NOTE: Placeholders are allocated within the "mutableWeight"
// buffer and are identified using an offset relative to base.
// ---------------------------------------------------------------
#ifdef __cplusplus
extern "C" {
#endif

// Placeholder address offsets within mutable buffer (bytes).
#define TEST_NN_input   0
#define TEST_NN_output  512

// Memory sizes (bytes).
#define TEST_NN_CONSTANT_MEM_SIZE     1216
#define TEST_NN_MUTABLE_MEM_SIZE      4608
#define TEST_NN_ACTIVATIONS_MEM_SIZE  4608

// Memory alignment (bytes).
#define TEST_NN_MEM_ALIGN  64

// Bundle entry point (inference function). Returns 0
// for correct execution or some error code otherwise.
int test_nn(uint8_t *constantWeight, uint8_t *mutableWeight, uint8_t *activations);

#ifdef __cplusplus
}
#endif
#endif
  4. Output from NXP Glow:
// Bundle API auto-generated header file. Do not edit!
// Glow Tools version: 2021-10-01 (55a459bef)

#ifndef _GLOW_BUNDLE_TEST_NN_H
#define _GLOW_BUNDLE_TEST_NN_H

#include <stdint.h>

// ---------------------------------------------------------------
//                       Common definitions
// ---------------------------------------------------------------
#ifndef _GLOW_BUNDLE_COMMON_DEFS
#define _GLOW_BUNDLE_COMMON_DEFS

// Glow bundle error code for correct execution.
#define GLOW_SUCCESS 0

// Memory alignment definition with given alignment size
// for static allocation of memory.
#define GLOW_MEM_ALIGN(size)  __attribute__((aligned(size)))

// Macro function to get the absolute address of a
// placeholder using the base address of the mutable
// weight buffer and placeholder offset definition.
#define GLOW_GET_ADDR(mutableBaseAddr, placeholderOff)  (((uint8_t*)(mutableBaseAddr)) + placeholderOff)

#endif

// ---------------------------------------------------------------
//                          Bundle API
// ---------------------------------------------------------------
// Model name: "test_nn"
// Total data size: 18624 (bytes)
// Activations allocation efficiency: 1.0000
// Placeholders:
//
//   Name: "input"
//   Type: float<1 x 2 x 8 x 8>
//   Size: 128 (elements)
//   Size: 512 (bytes)
//   Offset: 0 (bytes)
//
//   Name: "output"
//   Type: float<1 x 1024>
//   Size: 1024 (elements)
//   Size: 4096 (bytes)
//   Offset: 512 (bytes)
//
// Constants:
//
//   Name: "conv1_bias"
//   Type: float<16>
//   Size: 16 (elements)
//   Size: 64 (bytes)
//   Offset: 0 (bytes)
//
//   Name: "conv1_weight__1"
//   Type: float<16 x 3 x 3 x 2>
//   Size: 288 (elements)
//   Size: 1152 (bytes)
//   Offset: 64 (bytes)
//
//   Name: "BatchNormalization_2__1_muBroadcasted_tile2__1_constfold"
//   Type: float<1 x 8 x 8 x 16>
//   Size: 1024 (elements)
//   Size: 4096 (bytes)
//   Offset: 1216 (bytes)
//
//   Name: "BatchNormalization_2__1_coefBroadcasted_tile2__1_constfold"
//   Type: float<1 x 8 x 8 x 16>
//   Size: 1024 (elements)
//   Size: 4096 (bytes)
//   Offset: 5312 (bytes)
//
// NOTE: Placeholders are allocated within the "mutableWeight"
// buffer and are identified using an offset relative to base.
// ---------------------------------------------------------------
#ifdef __cplusplus
extern "C" {
#endif

// Placeholder address offsets within mutable buffer (bytes).
#define TEST_NN_input   0
#define TEST_NN_output  512

// Memory sizes (bytes).
#define TEST_NN_CONSTANT_MEM_SIZE     9408
#define TEST_NN_MUTABLE_MEM_SIZE      4608
#define TEST_NN_ACTIVATIONS_MEM_SIZE  4608

// Memory alignment (bytes).
#define TEST_NN_MEM_ALIGN  64

// Bundle entry point (inference function). Returns 0
// for correct execution or some error code otherwise.
int test_nn(uint8_t *constantWeight, uint8_t *mutableWeight, uint8_t *activations);

#ifdef __cplusplus
}
#endif
#endif
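To make the difference between the two headers concrete, the following sketch recomputes the constant and total memory sizes from the constants each header lists. It assumes constants are packed in the listed order, each starting at the next offset aligned to TEST_NN_MEM_ALIGN (64 bytes), which matches the offsets printed above; the helper names are hypothetical.

```python
# Sketch: recompute the bundle memory sizes from the constants listed in
# each header above. Assumption: constants are packed in the listed order,
# each starting at the next 64-byte-aligned offset (TEST_NN_MEM_ALIGN).
from math import prod

ALIGN = 64
FLOAT_BYTES = 4


def align_up(n, align=ALIGN):
    """Round n up to the next multiple of align."""
    return (n + align - 1) // align * align


def layout(shapes):
    """Return (offsets, end offset in bytes) for float tensors of the given shapes."""
    offsets, cursor = [], 0
    for shape in shapes:
        cursor = align_up(cursor)
        offsets.append(cursor)
        cursor += prod(shape) * FLOAT_BYTES
    return offsets, cursor


# conv1_bias, conv1_weight__1 (present in both headers)
latest_constants = [(16,), (16, 3, 3, 2)]
# plus the two broadcasted BatchNorm constants only the NXP build emits
nxp_constants = latest_constants + [(1, 8, 8, 16), (1, 8, 8, 16)]

mutable = 512 + 4096  # "input" + "output" placeholders
activations = 4608    # TEST_NN_ACTIVATIONS_MEM_SIZE in both headers

for name, consts in [("latest", latest_constants), ("NXP", nxp_constants)]:
    offsets, const_size = layout(consts)
    total = const_size + mutable + activations
    print(f"{name}: offsets={offsets} constants={const_size} total={total}")
```

This reproduces the offsets, the 1216 and 9408 byte constant sizes, and the 10432 and 18624 byte totals shown in the two headers. The 8192 byte gap is exactly the two float<1 x 8 x 8 x 16> BatchNorm tensors, consistent with the BatchNorm having been folded away entirely in the latest build.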

I don't have access to the NXP version, but I tried your test case with our torch_glow flow by adding a test to torch_glow/tests/nodes/batchnorm2d_test.py:

    def test_conv_relu_batchnorm(self):

        class ConvReluBatchNormMod(nn.Module):
            def __init__(self, drop=0.25):
                super(ConvReluBatchNormMod, self).__init__()
                self.conv1 = nn.Conv2d(2, 16, (3, 3), padding=1)
                self.bn1 = nn.BatchNorm2d(16)
                self.relu1 = torch.nn.ReLU()
                self.dropout = nn.Dropout(drop)

            def forward(self, x):
                x = self.conv1(x)
                x = self.relu1(x)
                x = self.bn1(x)
                x = self.dropout(x)
                x = x.reshape([-1, 1024])
                return x

        inputs = torch.rand(1, 2, 8, 8, requires_grad=True)
        model = ConvReluBatchNormMod()
        model.eval()

        utils.compare_tracing_methods(model, inputs)

And I was unable to reproduce this issue on our latest master. If you can provide a reproduction of the issue via this flow then it'd be easier to diagnose. CC: @LuisTbx

@jfix71 Thanks for looking into this issue. I'm looking at utils.compare_tracing_methods(): if we pass the argument fusible_ops=BN,Conv2D, it should fail the same way as @jakubmiernik's test, since the error comes from trying to fuse the lowered ReLU layer with the BN.

@jfix71 Thanks for having a look at this issue.

I checked your test and it does pass, but I'm not familiar enough with this framework to say whether all optimizations are applied there in the same way as during model compilation.

My first conclusion when I opened this issue was slightly incorrect. The problem is not really missing weights; fusing conv and BN layers is fine in general.
The problem is that in this case there is a ReLU layer between the conv and the BN, so a simple fusion of those layers should not happen.

I think the problem is that two optimizations are applied without taking each other into account. It seems that the ReLU is first fused into the convolution layer, and then the convolution is fused with the BN (because there is no standalone ReLU in between anymore), ignoring the previously fused ReLU.
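The reasoning above can be checked numerically with a single-channel 1x1 stand-in for the convolution (a plain affine map w * x + b). The parameter values below are arbitrary assumptions for illustration, not taken from the model. Folding the BN into the conv weights while the ReLU stays fused computes relu(bn(conv(x))), whereas the exported graph computes bn(relu(conv(x))).

```python
# Sketch: folding BatchNorm into a convolution that already has a fused ReLU
# changes the result. A single-channel 1x1 convolution is modeled as w * x + b.
import math


def relu(v):
    return max(0.0, v)


def bn(v, gamma, beta, mean, var, eps):
    # Inference-mode BatchNorm: (v - mean) * gamma / sqrt(var + eps) + beta
    return (v - mean) * gamma / math.sqrt(var + eps) + beta


def reference(x, w, b, gamma, beta, mean, var, eps):
    # Order in the exported model: Conv -> ReLU -> BN
    return bn(relu(w * x + b), gamma, beta, mean, var, eps)


def wrongly_folded(x, w, b, gamma, beta, mean, var, eps):
    # BN folded into the conv weights while the ReLU stays fused to the conv,
    # i.e. relu(bn(conv(x))) instead of bn(relu(conv(x))).
    coef = gamma / math.sqrt(var + eps)
    w_folded = w * coef
    b_folded = (b - mean) * coef + beta
    return relu(w_folded * x + b_folded)


# Arbitrary illustrative parameters (not taken from the model in this issue).
params = dict(w=1.0, b=0.0, gamma=1.0, beta=0.0, mean=0.5, var=1.0, eps=1e-5)
for x in (-1.0, 0.25, 2.0):
    print(x, reference(x, **params), wrongly_folded(x, **params))
```

For inputs whose conv output is negative or below the BN mean, the two results differ (the folded version clamps to 0 after normalizing), while for large positive inputs they agree. The fused graph therefore still runs, but produces wrong values for part of the input range.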

I found that removing one foldActivations call in lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp:6659 prevents the ReLU from being folded into the convolution:

  // Fold activations before lowering to enable cases which would not fuse after
  // lowering. This concerns particularly convolution&relu since relu will be
  // lowered to max(0, x).
  foldActivations(F, cctx, &B);

Removing this call makes the compiled model work correctly, but it probably gives up some optimization opportunities.

@jakubmiernik @LuisTbx Thanks for finding/investigating this. I agree this is a bug. I have put up #5898 which should hopefully fix it -- please try this PR locally and let me know if it works for you.

@jfix71 Thanks for the patch.

It seems it's not enough: compilation now fails with:

Unsupported node found while compiling Function <path_to_onnx_model> for backend CPU: BatchNormalization
Name : BatchNorm
Input : float<1 x 8 x 8 x 16>
Scale : float<16>
Bias : float<16>
Mean : float<16>
Var : float<16>
ChannelIdx : 3
Epsilon : 0.000000e+00
Momentum : 0.000000e+00
Users : 1
Result : float<1 x 8 x 8 x 16>
...

Well, I believe most ops aren't supported in FP16 on the CPU backend, so this makes sense to some extent. It'd be good to see the stack trace for this issue, though, because generally we shouldn't see any BatchNormalization nodes at all for the CPU backend, as they should be lowered to other nodes here:

static void lowerBatchNormalizationNode(Function *F, CompilationContext &cctx,
                                        const BatchNormalizationNode &BN) {
  LOG_SCOPE(F->getLogContext(), "lowerBatchNormalizationNode")
  auto in = BN.getInput();
  auto out = BN.getResult();
  auto beta = BN.getBias();
  auto gamma = BN.getScale();
  auto var = BN.getVar();
  auto mean = BN.getMean();
  // http://cthorey.github.io/backpropagation/
  //
  // mu = 1/N*np.sum(h,axis =0)
  // sigma2 = 1/N*np.sum((h-mu)**2)
  // hath = (h-mu)*(sigma2+epsilon)**(-1./2.)
  // y = gamma*hath+beta
  // In inference mode just apply the transformation:
  // y[i] = (x - mu) * gamma / stdvar + beta;
  auto channelIdx = BN.getChannelIdx();
  auto epsilon = BN.getEpsilon();
  auto *epsilonSplat =
      F->createSplat(DECORATE_NODE_NAME(BN, "epsilon"), var.getType(), epsilon);
  Node *coef =
      F->createAdd(DECORATE_NODE_NAME(BN, "var_plus_eps"), var, epsilonSplat);
  coef = F->createPow(DECORATE_NODE_NAME(BN, "sqrt_var_plus_eps"), coef, 0.5);
  coef = F->createDiv(DECORATE_NODE_NAME(BN, "inverse_sqrt_var_plus_eps"),
                      gamma, coef);
  // Apply: out := (in - mean) * coef + beta
  // in and out are of the same size, while others must be broadcasted.
  auto *meanB = F->createBroadcast(DECORATE_NODE_NAME(BN, "muBroadcasted"),
                                   mean, in.dims(), channelIdx);
  auto *coefB = F->createBroadcast(DECORATE_NODE_NAME(BN, "coefBroadcasted"),
                                   coef, in.dims(), channelIdx);
  auto *betaB = F->createBroadcast(DECORATE_NODE_NAME(BN, "betaBroadcasted"),
                                   beta, in.dims(), channelIdx);
  Node *newResult =
      F->createSub(DECORATE_NODE_NAME(BN, "in_minus_mean"), in, meanB);
  newResult =
      F->createMul(DECORATE_NODE_NAME(BN, "mul_coef"), newResult, coefB);
  newResult = F->createAdd(DECORATE_NODE_NAME(BN, "result"), newResult, betaB);
  replaceAllUsesOfWith(cctx.loweredInfoMap, BN.getResult(), newResult);
}
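Per element, the lowered graph above computes out = (in - mean) * gamma / sqrt(var + epsilon) + beta, with the per-channel parameters broadcast along ChannelIdx. The plain-Python sketch below reproduces that arithmetic for an NHWC tensor flattened row-major (ChannelIdx = 3, as in the error output later in this thread). It is only an illustration of the formula, not Glow code, and the function name is hypothetical.

```python
# Sketch of the arithmetic produced by the BatchNorm lowering above, for an
# NHWC tensor flattened row-major (ChannelIdx = 3, so the channel is the
# fastest-varying dimension). Hypothetical helper, not part of Glow.
import math


def lowered_batchnorm(x, gamma, beta, mean, var, epsilon):
    channels = len(gamma)
    assert len(x) % channels == 0, "input size must be a multiple of C"
    # coef = gamma / (var + epsilon) ** 0.5, computed once per channel
    coef = [g / math.pow(v + epsilon, 0.5) for g, v in zip(gamma, var)]
    out = []
    for i, value in enumerate(x):
        c = i % channels  # channel of element i when C is the last dimension
        # out := (in - mean) * coef + beta, with per-channel params broadcast
        out.append((value - mean[c]) * coef[c] + beta[c])
    return out


# Example: a 1 x 1 x 2 x 2 NHWC tensor (C = 2) with illustrative parameters.
print(lowered_batchnorm([1.0, 10.0, 3.0, 20.0],
                        gamma=[1.0, 2.0], beta=[0.0, 1.0],
                        mean=[2.0, 15.0], var=[4.0, 25.0], epsilon=0.0))
```

Because every step is an elementwise Sub, Mul or Add on broadcast operands, a backend that supports those ops never needs a native BatchNormalization kernel.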

@jfix71 It's not an FP16 operation, as I understand it. float<16> here means a float buffer of size 16.

Because generally we shouldn't see any BatchNormalization nodes at all for the CPU backend,
The output I gave there is with your patch (#5898) applied, so it turned off fusing the conv with the BN because the conv already has its activations modified by the ReLU. Am I missing something?

Your patch deactivates the lowering of the BN if there is a ReLU. But we should not remove this optimisation.

I think the way to go would be to apply the lowering of the BN: apply the BN formula to the output tensor of the already lowered ReLU.

Currently, what happens instead is that Glow tries to fuse the BN with the conv despite the lowered ReLU.

It's not an FP16 operation, as I understand it. float<16> here means a float buffer of size 16.

Whoops. Yeah, I misread it.

The output I gave there is with your patch (#5898) applied, so it turned off fusing the conv with the BN because the conv already has its activations modified by the ReLU. Am I missing something?

That's correct.

Your patch deactivates the lowering of the BN if there is a ReLU. But we should not remove this optimisation.
I think the way to go would be to apply the lowering of the BN: apply the BN formula to the output tensor of the already lowered ReLU.
Currently, what happens instead is that Glow tries to fuse the BN with the conv despite the lowered ReLU.

The patch keeps the optimization in place but makes it only apply when the conv doesn't have a fused activation.

I think the term "lower" is generally overloaded. Let's make sure we're on the same page: first, the whole graph is loaded as a series of Glow nodes, so we have Conv-Relu-BN nodes.

Next, the Conv and Relu nodes are fused into a single node. So the Conv node now has Relu fused in (let's call this single node ConvRelu). Now we have ConvRelu-BN.

Previous to my PR, it would have then fused the BN into the ConvRelu's weights. But that's incorrect.

After my PR it will not fuse them, so we still have ConvRelu-BN. At this point the "lowering" logic I pointed to in my previous comment is supposed to turn the BN into a series of other nodes (add, pow, div, sub, mul, etc.) that together are equivalent to the BN node, all of which should be supported by the backend.
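The decision described above (fold the BN only when the conv carries no fused activation, otherwise keep it and lower it later) can be modeled as follows. This is a hypothetical Python sketch of the idea, not the actual change in #5898; the Conv and BatchNorm classes and the fused_activation field are stand-ins for Glow's C++ nodes.

```python
# Hypothetical sketch of the guard described above: BatchNorm may be folded
# into the preceding convolution only if that convolution has no fused
# activation. The node classes are stand-ins, not Glow's C++ API.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Conv:
    name: str
    fused_activation: Optional[str] = None  # e.g. "RELU" after activation folding


@dataclass
class BatchNorm:
    name: str
    conv: Conv  # the node producing the BN input


def can_fold_bn_into_conv(bn: BatchNorm) -> bool:
    # Folding rewrites the conv weights so the conv itself produces bn(conv(x)).
    # With a fused ReLU the conv really computes relu(conv(x)), so folding would
    # yield relu(bn(conv(x))) instead of bn(relu(conv(x))).
    return bn.conv.fused_activation is None


def plan(bn: BatchNorm) -> str:
    if can_fold_bn_into_conv(bn):
        return f"fold {bn.name} into {bn.conv.name}"
    # Otherwise the BN is kept and later lowered to arithmetic nodes.
    return f"lower {bn.name} to arithmetic nodes"


print(plan(BatchNorm("bn1", Conv("conv1"))))
print(plan(BatchNorm("bn1", Conv("conv1", fused_activation="RELU"))))
```

The second case is the Conv-Relu-BN pattern from this issue: the BN survives fusion, so the backend only sees it correctly if the lowering step runs afterwards.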

So I am wondering where this "Unsupported node found while compiling Function" error came from. I would like to see the stack trace so I can better understand where in the compilation flow it occurred.

@jfix71 Thanks for the explanation. I think we are on the same page. The only difference is that I didn't expect the BN to be split into separate operations, and I think that is where we misunderstood each other.

Here is the full error output from running model-compiler on our model:

>>>> LLVM_SYMBOLIZER_PATH=llvm-symbolizer-8 ./bin/model-compiler --verbose-compilation -backend=CPU -model=<model_path> -emit-bundle=build_linux -model-input="input,float,[1,2,8,8]" --bundle-api-verbose -onnx-define-symbol="batch_size,1"
Unsupported node found while compiling Function <model_path> for backend CPU: BatchNormalization
Name : BatchNorm
Input : float<1 x 8 x 8 x 16>
Scale : float<16>
Bias : float<16>
Mean : float<16>
Var : float<16>
ChannelIdx : 3
Epsilon : 0.000000e+00
Momentum : 0.000000e+00
Users : 1
Result : float<1 x 8 x 8 x 16>
Unsupported node found while compiling Function <model_path> for backend CPU: BatchNormalization
Name : BatchNorm__1
Input : float<1 x 8 x 8 x 32>
Scale : float<32>
Bias : float<32>
Mean : float<32>
Var : float<32>
ChannelIdx : 3
Epsilon : 0.000000e+00
Momentum : 0.000000e+00
Users : 1
Result : float<1 x 8 x 8 x 32>
Unsupported node found while compiling Function <model_path> for backend CPU: BatchNormalization
Name : BatchNorm__2
Input : float<1 x 4 x 4 x 64>
Scale : float<64>
Bias : float<64>
Mean : float<64>
Var : float<64>
ChannelIdx : 3
Epsilon : 0.000000e+00
Momentum : 0.000000e+00
Users : 1
Result : float<1 x 4 x 4 x 64>
Unsupported node found while compiling Function <model_path> for backend CPU: BatchNormalization
Name : BatchNorm__3
Input : float<1 x 1 x 1 x 32>
Scale : float<32>
Bias : float<32>
Mean : float<32>
Var : float<32>
ChannelIdx : 3
Epsilon : 0.000000e+00
Momentum : 0.000000e+00
Users : 1
Result : float<1 x 1 x 1 x 32>
WARNING: Logging before InitGoogleLogging() is written to STDERR
F0112 19:22:21.749035 24354 Error.cpp:123] exitOnError(Error) got an unexpected ErrorValue:
Error code: COMPILE_UNSUPPORTED_NODE_AFTER_OPTIMIZE
Error message: Unsupported node(s) found after optimizing Function <model_path> for backend CPU

Error return stack:
--------------------------------------------------------------------------------
<glow_repo_path>/lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp:6728
--------------------------------------------------------------------------------
<glow_repo_path>/tools/loader/Loader.cpp:736
--------------------------------------------------------------------------------
*** Check failure stack trace: ***
 #0 0x0000000000c64301 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (./bin/model-compiler+0xc64301)
 #1 0x0000000000c62402 llvm::sys::RunSignalHandlers() (./bin/model-compiler+0xc62402)
 #2 0x0000000000c646a2 SignalHandler(int) (./bin/model-compiler+0xc646a2)
 #3 0x00007f791d9383c0 __restore_rt (/lib/x86_64-linux-gnu/libpthread.so.0+0x153c0)
 #4 0x00007f791d3c118b gsignal /build/glibc-eX1tMB/glibc-2.31/signal/../sysdeps/unix/sysv/linux/internal-signals.h:86:3
 #5 0x00007f791d3a0859 abort /build/glibc-eX1tMB/glibc-2.31/stdlib/abort.c:81:7
 #6 0x00007f791e388e4c _ULx86_64_init_local (/usr/lib/x86_64-linux-gnu/libglog.so.0+0x9e4c)
 #7 0x00007f791e38c1c3 (/usr/lib/x86_64-linux-gnu/libglog.so.0+0xd1c3)
 #8 0x00007f791e39125b google::LogMessage::SendToLog() (/usr/lib/x86_64-linux-gnu/libglog.so.0+0x1225b)
 #9 0x00007f791e38bebf google::LogMessage::Flush() (/usr/lib/x86_64-linux-gnu/libglog.so.0+0xcebf)
#10 0x00007f791e38c6ef google::LogMessageFatal::~LogMessageFatal() (/usr/lib/x86_64-linux-gnu/libglog.so.0+0xd6ef)
#11 0x000000000325e49a glow::detail::exitOnError(char const*, unsigned long, glow::detail::GlowError) <glow_repo_path>/lib/Support/Error.cpp:126:1
#12 0x00000000004f3c00 glow::Loader::compile(glow::CompilationContext&) <glow_repo_path>/tools/loader/Loader.cpp:736:5
#13 0x000000000052ed70 main <glow_repo_path>/tools/loader/ModelCompiler.cpp:39:10
#14 0x00007f791d3a20b3 __libc_start_main /build/glibc-eX1tMB/glibc-2.31/csu/../csu/libc-start.c:308:16
#15 0x00000000004ef7fe _start (./bin/model-compiler+0x4ef7fe)

The multiple BN errors come from the fact that I tested it on a model with multiple conv-relu-BN blocks.

Unfortunately I'm not able to reproduce this issue on my end, at least on our current master. I'm going to at least land the fix as is. If you can reproduce it on master, please let me know and I can investigate further.

Hi,

With commit cf84c73 from your PR, it's now impossible to compile a model with conv-relu-BN.

Please take this model: test_nn.zip (I had to zip it because of GitHub rules).
It was generated by the script at the start of this issue.

When you run something like this:

./bin/model-compiler --verbose-compilation -backend=CPU -model=../test_nn.onnx -emit-bundle=build_test -model-input="input,float,[1,2,8,8]" --bundle-api-verbose -onnx-define-symbol="batch_size,1"

You will get an error like this:

Unsupported node found while compiling Function ../test_nn.onnx for backend CPU: BatchNormalization
Name : BatchNorm
Input : float<1 x 8 x 8 x 16>
Scale : float<16>
Bias : float<16>
Mean : float<16>
Var : float<16>
ChannelIdx : 3
Epsilon : 0.000000e+00
Momentum : 0.000000e+00
Users : 1
Result : float<1 x 8 x 8 x 16>
WARNING: Logging before InitGoogleLogging() is written to STDERR
F0114 22:32:23.325853 32153 Error.cpp:123] exitOnError(Error) got an unexpected ErrorValue:
Error code: COMPILE_UNSUPPORTED_NODE_AFTER_OPTIMIZE
Error message: Unsupported node(s) found after optimizing Function ../test_nn.onnx for backend CPU

Error return stack:
--------------------------------------------------------------------------------
/home/jmiernik/repos/glow/lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp:6728
--------------------------------------------------------------------------------
/home/jmiernik/repos/glow/tools/loader/Loader.cpp:736
--------------------------------------------------------------------------------
*** Check failure stack trace: ***
 #0 0x0000000000c64301 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (./bin/model-compiler+0xc64301)
 #1 0x0000000000c62402 llvm::sys::RunSignalHandlers() (./bin/model-compiler+0xc62402)
 #2 0x0000000000c646a2 SignalHandler(int) (./bin/model-compiler+0xc646a2)
 #3 0x00007fa0c3d833c0 __restore_rt (/lib/x86_64-linux-gnu/libpthread.so.0+0x153c0)
 #4 0x00007fa0c380c18b gsignal /build/glibc-eX1tMB/glibc-2.31/signal/../sysdeps/unix/sysv/linux/internal-signals.h:86:3
 #5 0x00007fa0c37eb859 abort /build/glibc-eX1tMB/glibc-2.31/stdlib/abort.c:81:7
 #6 0x00007fa0c47d3e4c _ULx86_64_init_local (/usr/lib/x86_64-linux-gnu/libglog.so.0+0x9e4c)
 #7 0x00007fa0c47d71c3 (/usr/lib/x86_64-linux-gnu/libglog.so.0+0xd1c3)
 #8 0x00007fa0c47dc25b google::LogMessage::SendToLog() (/usr/lib/x86_64-linux-gnu/libglog.so.0+0x1225b)
 #9 0x00007fa0c47d6ebf google::LogMessage::Flush() (/usr/lib/x86_64-linux-gnu/libglog.so.0+0xcebf)
#10 0x00007fa0c47d76ef google::LogMessageFatal::~LogMessageFatal() (/usr/lib/x86_64-linux-gnu/libglog.so.0+0xd6ef)
#11 0x000000000325e49a glow::detail::exitOnError(char const*, unsigned long, glow::detail::GlowError) /home/jmiernik/repos/glow/lib/Support/Error.cpp:126:1
#12 0x00000000004f3c00 glow::Loader::compile(glow::CompilationContext&) /home/jmiernik/repos/glow/tools/loader/Loader.cpp:736:5
#13 0x000000000052ed70 main /home/jmiernik/repos/glow/tools/loader/ModelCompiler.cpp:39:10
#14 0x00007fa0c37ed0b3 __libc_start_main /build/glibc-eX1tMB/glibc-2.31/csu/../csu/libc-start.c:308:16
#15 0x00000000004ef7fe _start (./bin/model-compiler+0x4ef7fe)

IMHO it shouldn't be merged in this state.

@jakubmiernik Thanks -- I was able to repro and determined the issue. I will try to get a fix up soon, either in #5898 or separately.

BTW -- I may still land #5898 as is, i.e. where it causes explicit breakage for you. The current state of the codebase instead silently messes up numerics. I think it's preferable that we explicitly error out.

@jakubmiernik #5901 should fix this when combined with #5898.

@jfix71 Sorry it took me so long :)

I tested it on my side and indeed now it seems to work fine.
Thank you for this fix!