google-parfait / tensorflow-federated

An open-source framework for machine learning and other computations on decentralized data.


Colab stuck at learning_process.initialize()

makabaka2 opened this issue

When using Colab to run the code, it always gets stuck at the initialization step. I am using the latest version of TFF, 0.74.0.
The cell just keeps showing that it is running; after two hours there is still no result and no error or log output.
The code is shown below:

import tensorflow as tf
import tensorflow_federated as tff

def model_fn():
    # create_bi_LSTM_model and its hyperparameters are defined elsewhere in my notebook.
    model = create_bi_LSTM_model(input_shape, output_shape, lstm_layers,
                                 dropout, recurrent_dropout, merge_mode)
    return tff.learning.models.from_keras_model(
        model,
        input_spec=tff_valid_data.element_spec,
        loss=tf.keras.losses.Huber(),
        metrics=[tf.keras.metrics.MeanAbsoluteError()])

# Optimizer instances cannot be passed into the builder directly; doing so raises:
# ValueError: Tensor("SGD/learning_rate:0", shape=(), dtype=resource) must be from the
# same graph as Tensor("zeros_like:0", shape=(1, 1024), dtype=float32).
# Pass no-argument constructors instead:
federated_averaging = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.Adam(),
    server_optimizer_fn=lambda: tf.keras.optimizers.Adam())

federated_evaluation = tff.learning.algorithms.build_fed_eval(model_fn)

state = federated_averaging.initialize()    # <- hangs here
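
For anyone trying to reproduce this without my model code, here is a minimal self-contained sketch that exercises the same initialize() path (the toy dense model and all-zeros data are placeholders, not my actual setup):

import tensorflow as tf
import tensorflow_federated as tff

# Toy dataset, used only to derive an input_spec; the contents don't matter here.
example_dataset = tf.data.Dataset.from_tensor_slices(
    (tf.zeros([8, 4]), tf.zeros([8, 1]))).batch(2)

def simple_model_fn():
  keras_model = tf.keras.Sequential(
      [tf.keras.layers.Dense(1, input_shape=(4,))])
  return tff.learning.models.from_keras_model(
      keras_model,
      input_spec=example_dataset.element_spec,
      loss=tf.keras.losses.MeanSquaredError())

process = tff.learning.algorithms.build_weighted_fed_avg(
    simple_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.Adam())

state = process.initialize()  # hangs here on the affected Colab runtime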

Hi @makabaka2. Can you provide the following information requested in the bug template:

Environment (please complete the following information):

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
  • Python package versions (e.g., TensorFlow Federated, TensorFlow):
  • Python version:
  • Bazel version (if building from source):
  • CUDA/cuDNN version:
  • What TensorFlow Federated execution stack are you using?

Note: You can collect the Python package information by running pip3 freeze
from the command line, and most of the other information can be collected using
TensorFlow's environment capture script.
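
For example, to grab just the relevant packages from a Colab cell (the grep filter is only a convenience; adjust as needed):

!pip3 freeze | grep -i tensorflow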

The code is executed on Colab, and my execution environment information is as follows:

  • OS Platform and Distribution: Ubuntu 22.04.3 LTS
  • Python package versions: tensorflow-2.14.1, tensorflow_federated-0.74.0
  • Python version: 3.10.12
  • Bazel version (if building from source):
  • CUDA/cuDNN version: CUDA-12.2, cuDNN-8.9.6
  • What TensorFlow Federated execution stack are you using? 0.74.0

Ah, the Colab bit helped me repro. You're right, things seem to be hanging indefinitely on Colab right now. Looking into this now.

OMG, that was such a pain. Looking forward to you guys sorting this out soon.

I found a version combination that runs successfully on Colab:
tensorflow 2.12.0
tensorflow_federated 0.61.0
Python 3.10.12
If anyone encounters a similar problem, you can try the versions above, e.g. with the cell below.
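
One way to pin that combination in a fresh Colab cell (remember to restart the runtime afterwards, as noted below):

# Pin the combination reported to work on Colab, then restart the runtime.
!pip install tensorflow==2.12.0 tensorflow-federated==0.61.0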

I am also stuck on initialize() while using Colab. It seems to be waiting indefinitely at wait() in threading.py.
python 3.10.12
tensorflow 2.14.1
tensorflow_federated 0.74.0

I will try changing the versions to what the author suggested.


Remember to restart the Colab runtime after installing tensorflow and tensorflow_federated.

I've narrowed the regression down to TFF v0.69.0. Looking at the change list, there's nothing that jumps out at me. It could be related to organizational changes in the executor stacks, but this seems unlikely.

Regardless, I'd recommend using TFF v0.68.0 (or earlier) for now.

Also, if anyone sees this in non-Colab environments, please let me know. So far I have only been able to repro it on Colab.

In limited testing I've found that:

  • 0.74.0 hangs
  • 0.75.0 appears to succeed
  • 0.76.0 hangs again

Some further updates:

I was able to figure out why the C++ worker subprocess was failing on Colab by locating the binary in the pip package and calling it directly via the subprocess module:

print(tff.__path__)
>>> ['/usr/local/lib/python3.10/dist-packages/tensorflow_federated']

# Using the path above, see if we can find the worker binary
!ls /usr/local/lib/python3.10/dist-packages/tensorflow_federated/data/
>>> worker_binary

Now try to start the binary directly and capture the return code:

import subprocess
import portpicker

port = portpicker.pick_unused_port()
binary = '/usr/local/lib/python3.10/dist-packages/tensorflow_federated/data/worker_binary'
subprocess.check_output(args=[binary, f'--port={port}'], stderr=subprocess.STDOUT)
>>> CalledProcessError: Command '['/usr/local/lib/python3.10/dist-packages/tensorflow_federated/data/worker_binary', '--port=46383']' died with <Signals.SIGILL: 4>.

Bingo! The binary is dying with an illegal instruction (SIGILL, signal 4). So the binary uses some instruction set that isn't supported by the Colab CPU runtime machine. Unfortunately, this doesn't tell us which instruction is problematic…
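
As an aside, the signal number in a traceback like that can be decoded with Python's standard signal module:

import signal

# Signal number 4 on Linux is SIGILL: the process executed an illegal instruction.
print(signal.Signals(4))
>>> Signals.SIGILL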

Let's take a look at what the Colab machine's processor is:

!cat /proc/cpuinfo
>>> processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 79
model name	: Intel(R) Xeon(R) CPU @ 2.20GHz
stepping	: 0
…

A Google search tells me that Xeon chips from family 6, model 79 use the Broadwell microarchitecture (https://en.wikipedia.org/wiki/Broadwell_(microarchitecture)).

Hypothesis: the binary includes newer AVX instructions (possibly AVX-512), which aren't supported until Skylake (the architecture after Broadwell). AVX-512 is sometimes used to speed up ML frameworks, since its wider instructions increase SIMD parallelism. This is similar to the type of issue noticed in tensorflow/tensorflow#18275.
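
A quick way to test this hypothesis is to check whether the runtime's CPU advertises the avx512f flag (a small sketch; it reads /proc/cpuinfo, so Linux-only):

# Check /proc/cpuinfo for the AVX-512 Foundation flag.
with open('/proc/cpuinfo') as f:
    flags = next(line for line in f if line.startswith('flags')).split()
print('avx512f supported:', 'avx512f' in flags)
>>> avx512f supported: False

On the Broadwell Colab runtime this prints False, consistent with the hypothesis.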

To confirm, we'll download the 0.75.0 and 0.76.0 pip packages and inspect the worker_binary:

pip download tensorflow_federated==${VERSION} -d /tmp/tff_${VERSION}

unzip /tmp/tff_${VERSION}/tensorflow_federated-${VERSION}-py3-none-manylinux_2_31_x86_64.whl -d /tmp/tff_${VERSION}

objdump --no-show-raw-insn -M x86-64 -d /tmp/tff_${VERSION}/tensorflow_federated/data/worker_binary | awk '{if ($2 !~ ":" && $2 != "data32" && $2 != "file" && $2 != "of" && length($2) > 0) {print $2}}' | sort -u > /tmp/tff_${VERSION}/instructions.txt

Now we diff the two instructions.txt files and see what's different. The section that stood out to me:

$ diff 0.76.0_instructions.txt 0.75.0_instructions.txt
…
< kmovb
< kmovd
< kmovw
…
> vmaskmovpd
> vmaskmovps
> vpmaskmovd
> vpmaskmovq

kmov* are AVX-512 instructions (specifically AVX-512F; see https://en.wikipedia.org/wiki/AVX-512#New_opmask_instructions) and only appear in the 0.76.0 pip package binary, whereas the 0.75.0 package uses AVX and AVX2 instructions (https://en.wikipedia.org/wiki/Advanced_Vector_Extensions#New_instructions and https://en.wikipedia.org/wiki/Advanced_Vector_Extensions#New_instructions_2).

Okay, we've confirmed that the binaries differ in using AVX2 vs. AVX-512 instructions, and that the Colab runtime doesn't support AVX-512 instructions.

As a final test, let's see if one of the other Colab runtimes has a newer CPU and works. Going to Runtime > Change runtime type and choosing TPUv2 then shows:

!cat /proc/cpuinfo
>>> processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 85
model name	: Intel(R) Xeon(R) CPU @ 2.00GHz
stepping	: 3
…

Family 6, model 85 appears to be the Cascade Lake microarchitecture, which does include AVX-512 instruction support. Lo and behold, executing on this Colab runtime does not hang.

This seems like a smoking gun for the issue.

Why are we getting different instruction sets in the built binaries?
Likely this comes from the build --copt=-march=native configuration here: https://github.com/tensorflow/federated/blob/d4865b22711385f6dbd357b6d8b0e1e978e8986d/.bazelrc#L37. This flag instructs the compiler to optimize for the architecture of the machine building the binary, and our pool of build machines has recently grown to include some with newer architectures that are incompatible with Colab CPU runtimes.

I'll look into configuring our build systems to ensure that the pip package is built for Haswell and newer CPUs, which should enable it to run on default Colab CPU runtimes.
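
A sketch of what that could look like in the .bazelrc (hypothetical; the exact baseline we settle on may differ):

# Target a fixed Haswell baseline (AVX2, no AVX-512) instead of whatever
# CPU the build machine happens to have:
build --copt=-march=haswell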

Incredible sleuthing!

Following up: the fix in #4637 should have landed in the most recent version, 0.77.0 (https://pypi.org/project/tensorflow-federated/0.77.0/). Please give it a try.
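
For example, in a fresh Colab cell (then restart the runtime before importing tff):

!pip install tensorflow-federated==0.77.0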