EmbarkStudios / rust-gpu

🐉 Making Rust a first-class language and ecosystem for GPU shaders 🚧

Home Page: https://shader.rs

Conditionals and code generation performance

DGriffin91 opened this issue · comments

I noticed that a rust-gpu shader was running much slower than the equivalent WGSL one.
The WGSL one takes 53ms, and the rust-gpu version takes 67ms.
Looking at the SPIR-V, I tracked part of the issue down to this:

if uvt.x > 0.0 && uvt.y > 0.0 && uvt.z > 0.0 && uvt.x + uvt.y < 1.0 {
    uvt
} else {
    vec3(f32::MAX, f32::MAX, f32::MAX)
}

I used spirv-cross to translate the SPIR-V that rust-gpu produces into GLSL, and it looks like this:

if (_1039 > 0.0)
{
    bool _1057;
    bool _1058;
    if (_1040 > 0.0)
    {
        bool _1047 = _1041 > 0.0;
        bool _1053;
        if (_1047)
        {
            _1053 = fma(_1028, _1023, _1040) < 1.0;
        }
        else
        {
            _1053 = _76;
        }
        _1057 = _1053;
        _1058 = _1047 ? false : true;
    }
    else
    {
        _1057 = _76;
        _1058 = true;
    }
    _1061 = _1057;
    _1062 = _1058;
}
else
{
    _1061 = _76;
    _1062 = true;
}

Whereas if I take wgsl through the same path (wgsl -> spirv -> glsl) it looks like this:

if ((((_87.x > 0.0) && (_87.y > 0.0)) && (_87.z > 0.0)) && ((_87.x + _87.y) < 1.0)) {
    return _87;
} else {
    return vec3(F32MAX);
}

I tried forcing it to not branch but generate a bool, with u32(uvt.x > 0.0 && uvt.y > 0.0 && uvt.z > 0.0 && uvt.x + uvt.y < 1.0) == 1, and while it kept the conversion and the equality check, it still had this same nested branching structure.

I then tried this, which got me a lot closer to the WGSL perf (now 58ms):

if (uvt.x > 0.0) as u32
    & (uvt.y > 0.0) as u32
    & (uvt.z > 0.0) as u32
    & (uvt.x + uvt.y < 1.0) as u32
    == 1
{
    uvt
} else {
    vec3(f32::MAX, f32::MAX, f32::MAX)
}

This actually results in it using mix here:

mix(vec3(F32MAX), vec3(_1024, _1025, _1026), bvec3((((uint(_1024 > 0.0) & uint(_1025 > 0.0)) & uint(_1026 > 0.0)) & uint(fma(_1013, _1008, _1025) < 1.0)) == 1u)).y;
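
For readers unfamiliar with the boolean form of GLSL mix: mix(a, b, c) with a boolean selector picks b where c is true and a where it is false, so the line above is a select rather than a branch. Below is a minimal CPU-side sketch of the same logic in plain Rust. Vec3, mix_select and filter_hit are hypothetical names made up for illustration (this is not the glam or spirv-std API), and the sketch says nothing about what rust-gpu actually emits.

```rust
/// Hypothetical stand-in for a 3-component float vector (not the glam type).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Vec3 {
    pub x: f32,
    pub y: f32,
    pub z: f32,
}

impl Vec3 {
    /// Builds a vector with all three components set to `v`.
    pub const fn splat(v: f32) -> Self {
        Vec3 { x: v, y: v, z: v }
    }
}

/// Mirrors GLSL `mix(a, b, bvec3(c))` with a single condition broadcast to
/// all components: returns `if_true` when `cond` holds, `if_false` otherwise.
pub fn mix_select(if_false: Vec3, if_true: Vec3, cond: bool) -> Vec3 {
    if cond { if_true } else { if_false }
}

/// The condition from the issue, written with integer `&` so that all four
/// comparisons are evaluated and combined without short-circuiting.
pub fn filter_hit(uvt: Vec3) -> Vec3 {
    let inside = (uvt.x > 0.0) as u32
        & (uvt.y > 0.0) as u32
        & (uvt.z > 0.0) as u32
        & (uvt.x + uvt.y < 1.0) as u32
        == 1;
    mix_select(Vec3::splat(f32::MAX), uvt, inside)
}
```

Calling filter_hit with a point that passes all four tests returns it unchanged; any failing test yields the f32::MAX sentinel vector.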

Is it possible to improve the code generation in rust-gpu to avoid the excessive branching in situations like this?

(I'm aware that this could also be written differently to avoid branching, I'm not concerned about this specific impl, but about the code generation in general)

commented

does it improve after running through spirv-opt?

@ickk rust-gpu uses spirv-opt, so no unfortunately.

It seems like the best option currently is to write it with & instead of &&, like:
if (uvt.x > 0.0) & (uvt.y > 0.0) & (uvt.z > 0.0) & (uvt.x + uvt.y < 1.0) {
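
A small sketch of why that rewrite is safe for this particular condition: on bool operands, & evaluates both sides and && short-circuits, but the two agree whenever the operands have no side effects and cannot panic, which is the case for plain float comparisons. The function names below are hypothetical, and whether the & form actually produces fewer branches depends on the rust-gpu version in use, so treat it as something to measure rather than a guarantee.

```rust
/// Short-circuiting form, as in the original shader code.
pub fn inside_short_circuit(x: f32, y: f32, z: f32) -> bool {
    x > 0.0 && y > 0.0 && z > 0.0 && x + y < 1.0
}

/// Eager form: every comparison is evaluated, then combined with bitwise
/// AND on bools. No `as u32` round trip is needed.
pub fn inside_eager(x: f32, y: f32, z: f32) -> bool {
    (x > 0.0) & (y > 0.0) & (z > 0.0) & (x + y < 1.0)
}

/// Checks that both forms agree on a list of sample points.
pub fn forms_agree(samples: &[(f32, f32, f32)]) -> bool {
    samples
        .iter()
        .all(|&(x, y, z)| inside_short_circuit(x, y, z) == inside_eager(x, y, z))
}
```

The eager form does a little more arithmetic on inputs that fail an early test, which is usually cheaper on a GPU than divergent control flow, but that trade-off is workload-dependent.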
Discussion on discord