openxla / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.

Home Page:http://iree.dev/

Turbine Llama2 requires `IREE_TASK_EXECUTOR_MAX_OUTSTANDING_WAITS >= 740`

bjacob opened this issue · comments

Here is a nice thanksgiving dinner of an issue, with a big bird at the center and a few side dishes.

The main issue

Talking about this IREE_TASK_EXECUTOR_MAX_OUTSTANDING_WAITS constant:

// Maximum number of simultaneous waits an executor may perform as part of a
// wait-any operation. A larger value may enable better wake coalescing by the
// kernel. This is only a count limiting wait tasks that have been scheduled and
// been promoted to the root executor waiting list. There may be any number of
// waits deeper in the pipeline so long as they don't all become ready
// simultaneously.
//
// Realistically, though, if we have more than 64 outstanding **root** waits
// it's hard to reason about if/when the executor queue could make forward
// progress and indicates a possible error in task assignment.
//
// Also, the underlying iree_wait_set_t may not support more than 64 handles on
// certain platforms without emulation. Trying to keep us on the fast-path
// with a reasonable number seems fine for now until we have a need for more.
//
// NOTE: we reserve 1 wait handle for our own internal use. This allows us to
// wake the coordination worker when new work is submitted from external
// sources.
#define IREE_TASK_EXECUTOR_MAX_OUTSTANDING_WAITS (64 - 1)

The Turbine Llama2 model (reproduction details below) requires a higher value.

With --task_topology_group_count=1 it consistently requires 740. With the default on my 16-core machine, it fluctuates between just 20 (so it occasionally works) and 660. When enabling any sanitizer, it stays consistently below the limit, which made this interesting to debug.

Besides this main issue, there are surrounding issues in this code that arise when hitting this limit. The symptom here is not a clean "exceeded MAX_OUTSTANDING_WAITS" error. Instead, the immediate symptom is that ppoll(2) fails because its nfds parameter is a huge value: a negative value cast to unsigned. The reason why it's a negative value is... actually a combination of 3 issues, described below.

Reproduction

Download both files in this comment: https://discord.com/channels/973663919757492264/1173330951791706113/1177060782933037166

Compile:

tools/iree-compile --iree-opt-const-eval=false --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=znver4 --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-global-opt-enable-quantized-matmul-reassociation --iree-opt-data-tiling --iree-llvmcpu-enable-ukernels=mmt4d ~/testing/Llama_2_7b_chat_hf.mlir -o /tmp/Llama_2_7b_chat_hf.vmfb

Run:

tools/iree-run-module --module=/tmp/Llama_2_7b_chat_hf.vmfb --function=run_forward --device=local-task --input=1x1xi64 --task_worker_stack_size=4000000 --parameters=model=$HOME/Downloads/Llama2_7b_i4quant.safetensors --task_topology_group_count=1

The --task_topology_group_count=1 is not necessary to reproduce, but it makes this issue much more deterministic.

Side issue 1: mismatched iree_wait_set_erase vs iree_wait_set_insert calls

The basic reason why we end up with a negative nfds is that we've called iree_wait_set_erase, which decrements handle_count, more times than we have called iree_wait_set_insert, which increments it.
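To make the wraparound concrete, here is a minimal sketch, not IREE code. The toy_* names are hypothetical stand-ins, and only the unsigned counter is modeled. It shows how one unmatched erase turns a count of 0 into SIZE_MAX, and how a checked variant would fail at the point of the mismatch instead of later inside the poll call.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

// Hypothetical stand-in for a wait set; only the counter is modeled.
typedef struct {
  size_t handle_count;  // unsigned, like iree_host_size_t
} toy_wait_set_t;

static void toy_insert(toy_wait_set_t* set) { ++set->handle_count; }

// Decrements unconditionally, so an unmatched erase wraps 0 to SIZE_MAX.
static void toy_erase_unchecked(toy_wait_set_t* set) { --set->handle_count; }

// Fails loudly at the unmatched erase instead of at a later poll call.
static void toy_erase_checked(toy_wait_set_t* set) {
  assert(set->handle_count > 0 && "erase without a matching insert");
  --set->handle_count;
}
```

With one insert followed by two unchecked erases, handle_count ends up at SIZE_MAX, which is the kind of value that then reaches ppoll(2) as nfds.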

The issue seems localized within a single caller, iree_task_poller_prepare_task.

It conditionally calls iree_wait_set_insert, and then it conditionally calls iree_wait_set_erase. There doesn't seem to be anything to keep these in sync. This diff seems to fix it, but I don't know if it's correct: it technically still doesn't guarantee matched inserts/erases, it just removes some possibilities of mismatched erases. Rather than trying to rationalize this into a PR, I will leave it here for someone who really understands this code.

diff --git a/runtime/src/iree/task/poller.c b/runtime/src/iree/task/poller.c
index 7fc9b4744..6d0281ba0 100644
--- a/runtime/src/iree/task/poller.c
+++ b/runtime/src/iree/task/poller.c
@@ -323,7 +323,7 @@ static iree_task_poller_prepare_result_t iree_task_poller_prepare_task(
   }
 
   // Remove the system wait handle from the wait set, if assigned.
-  if (iree_all_bits_set(task->header.flags, IREE_TASK_FLAG_WAIT_EXPORTED)) {
+  if (iree_status_is_ok(status) && iree_all_bits_set(task->header.flags, IREE_TASK_FLAG_WAIT_EXPORTED)) {
     iree_wait_handle_t* wait_handle =
         iree_wait_handle_from_source(&task->wait_source);
     if (wait_handle) {

By the way, while looking at the statuses in this function, I wondered whether it is intentional that the return status of iree_wait_source_query is ignored and overwritten by the return status of iree_task_poller_insert_wait_handle?

Side issue 2: iree_wait_set_erase silently proceeds with erasing an unknown entry

This loop over the existing entries finds the one that needs to be erased:

for (iree_host_size_t i = 0; i < set->handle_count; ++i) {
  if (iree_wait_primitive_compare_identical(&set->user_handles[i],
                                            &handle)) {
    index = i;
    break;
  }
}

If the handle to be erased wasn't in the set, this silently continues with the default value. Maybe we should assert that this loop did find the handle?
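One way this check could look, as a rough sketch under assumptions: the toy_set_* names, the int handles, and the swap-with-last removal are all hypothetical and are not taken from the IREE implementation. The point is only that the search uses a "not found" sentinel and the erase refuses to touch the set when the handle is absent.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

#define TOY_CAPACITY 8

// Hypothetical miniature set; handles are plain ints for illustration.
typedef struct {
  size_t handle_count;
  int handles[TOY_CAPACITY];
} toy_set_t;

// Appends a handle; returns false if the set is full.
static bool toy_set_insert(toy_set_t* set, int handle) {
  if (set->handle_count >= TOY_CAPACITY) return false;
  set->handles[set->handle_count++] = handle;
  return true;
}

// Erases a handle; returns false and leaves the set untouched if absent.
static bool toy_set_erase(toy_set_t* set, int handle) {
  size_t index = set->handle_count;  // Sentinel meaning "not found".
  for (size_t i = 0; i < set->handle_count; ++i) {
    if (set->handles[i] == handle) {
      index = i;
      break;
    }
  }
  if (index == set->handle_count) return false;
  // Move the last entry into the hole, then shrink the set.
  set->handles[index] = set->handles[set->handle_count - 1];
  --set->handle_count;
  return true;
}

// Asserting wrapper: turns a mismatched erase into an immediate failure.
static void toy_set_erase_checked(toy_set_t* set, int handle) {
  bool found = toy_set_erase(set, handle);
  assert(found && "erasing a handle that was never inserted");
  (void)found;  // Avoids an unused-variable warning when NDEBUG is set.
}
```

Returning a bool keeps the decision with the caller; the asserting wrapper is the variant suggested above for debug builds.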

Side issue 3: handle_count is unsigned and sometimes treated as signed

There's some evidence that the code is split on whether handle_count is signed or unsigned. Its C type is unsigned (iree_host_size_t) and on Unix it lowers to nfds_t, also unsigned, but some code seems to treat it as if it's signed:

Here it's cast to int:

int tail_index = (int)set->handle_count - 1;

Here there's an x <= 0 condition that would perhaps be phrased x == 0 if the code treated the value as unsigned:

IREE_TRACE_ZONE_BEGIN(z0);

Maybe make handle_count type int, or if we are righteous about the bit width, introduce an iree_host_ssize_t here? I still believe in the value of this element from the Google C/C++ style: use unsigned only for bitfields or when modular arithmetic is intended, use signed for everything else. Maybe the lack of an iree_host_ssize_t is pushing our code to more unsigned usage than we intend?
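A small sketch of the difference, with the caveat that toy_host_ssize_t is a hypothetical name and that mapping it to ptrdiff_t is an assumption (the two types have the same width on mainstream platforms, which the static_assert checks). With a signed count, "count - 1" on an empty set is -1 and an x <= 0 test is meaningful; with an unsigned count, the same expression wraps.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

// Hypothetical signed counterpart to a size_t-like host size type.
typedef ptrdiff_t toy_host_ssize_t;

// Assumption: same width as size_t, true on mainstream platforms.
static_assert(sizeof(toy_host_ssize_t) == sizeof(size_t),
              "signed and unsigned host size types should match in width");

// Unsigned arithmetic: the tail index of an empty set wraps to SIZE_MAX.
static size_t toy_tail_index_unsigned(size_t count) { return count - 1; }

// Signed arithmetic: the tail index of an empty set is simply -1.
static toy_host_ssize_t toy_tail_index_signed(toy_host_ssize_t count) {
  return count - 1;
}

// With a signed count, "<= 0" also catches an over-decremented count.
static int toy_is_empty_signed(toy_host_ssize_t count) { return count <= 0; }

// With an unsigned count, "<= 0" can only ever mean "== 0".
static int toy_is_empty_unsigned(size_t count) { return count == 0; }
```

Note that an over-decremented unsigned count reads as non-empty, while the signed version still reports empty; this is the behavior the style rule quoted above is meant to avoid.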