tensorflow / serving

A flexible, high-performance serving system for machine learning models

Home Page: https://www.tensorflow.org/serving


Reshape layer throws std::bad_alloc

Saoqq opened this issue


System information

  • OS:
    • Windows 11, 21H2, build 22000.856
    • WSL2, Ubuntu 22.04 -- allocated 8Gb in .wslconfig
  • TF Serving installed using docker, version 2.9.1
  • Hardware:
    • AMD Ryzen 5 3600
    • RAM 16Gb

Describe the problem

I'm facing an issue with the Reshape layer in TF Serving 2.9.1.
It writes the following to the Docker logs:

pythonproject-serving-1  | 2022-09-01 21:07:16.379819: I tensorflow_serving/model_servers/server.cc:442] Exporting HTTP/REST API at:localhost:8501 ...
pythonproject-serving-1  | [evhttp_server.cc : 245] NET_LOG: Entering the event loop ...
pythonproject-serving-1  | terminate called after throwing an instance of 'std::bad_alloc'
pythonproject-serving-1  |   what():  std::bad_alloc
pythonproject-serving-1  | /usr/bin/tf_serving_entrypoint.sh: line 3:     7 Aborted                 tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=${MODEL_NAME} --model_base_path=${MODEL_BASE_PATH}/${MODEL_NAME} "$@"

I also tried rolling back to older versions: the same mock model works on TF Serving 2.8.2.

Exact Steps to Reproduce

Mock model to demonstrate the issue (TF version 2.9.1):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(128, 64, 3)),
    tf.keras.layers.Reshape((12, 2048))
])

model.compile()

# ensure it works both on WSL and Windows
model.predict([np.ones(shape=(128, 64, 3)).tolist()])

model.save("./.data/models/v2_96acc/5")
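
To double-check the export itself, the SavedModel can also be loaded and run directly in Python. This is only a minimal sketch; it assumes the serving_default signature that model.save writes by default.

import numpy as np
import tensorflow as tf

loaded = tf.saved_model.load("./.data/models/v2_96acc/5")
infer = loaded.signatures["serving_default"]  # default signature written by model.save

# A single-element batch matching the Input(shape=(128, 64, 3)) layer above.
batch = tf.constant(np.ones(shape=(1, 128, 64, 3), dtype=np.float32))
outputs = infer(batch)
print({name: tensor.shape for name, tensor in outputs.items()})  # expect (1, 12, 2048)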

Then serving with docker-compose using the following config:

version: "3.8"
services:
  serving:
    image: tensorflow/serving:2.9.1
    restart: on-failure
    ports:
      - "8501:8501"
    volumes:
      - "./.data/models:/models"
    environment:
      MODEL_NAME: v2_96acc

And evaluating the results:

import json
import numpy as np
import requests

image = {
    "inputs": [np.ones(shape=(128, 64, 3)).tolist()]
}

res = requests.post('http://localhost:8501/v1/models/v2_96acc/versions/5:predict', data=json.dumps(image))

print(len(res.json()['outputs']))
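
The same prediction can also be sent over gRPC. The following is a minimal sketch, assuming port 8500 is additionally published, the tensorflow-serving-api pip package is installed, and that the signature input key is input_1 (the real key can be checked with saved_model_cli).

import grpc
import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

channel = grpc.insecure_channel("localhost:8500")
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

request = predict_pb2.PredictRequest()
request.model_spec.name = "v2_96acc"
request.model_spec.version.value = 5
request.model_spec.signature_name = "serving_default"
# "input_1" is an assumed input key; it must match the SavedModel signature.
request.inputs["input_1"].CopyFrom(
    tf.make_tensor_proto(np.ones(shape=(1, 128, 64, 3), dtype=np.float32)))

response = stub.Predict(request, timeout=10.0)
print(list(response.outputs.keys()))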

I'm encountering the same error when sending prediction requests to the universal-sentence-encoder-multilingual model running in a Docker container using the tensorflow/serving:2.9.1 image.

This error implies you are running out of memory, or alternatively that a process has been allocated a large block of memory, part of which is left unused and cannot be used by another process.

@Saoqq & @aamontree, kindly share the complete error stack trace so we can debug the issue.

In my case the container is running in a Kubernetes cluster within Google Cloud Platform. Here is what my deployment manifest looks like:

apiVersion: apps/v1
kind: Deployment
metadata:
  name: tensorflow-serving
  namespace: default
  labels:
    app: search-viewed-realtime
    tier: backend
  annotations:
    strategy.spinnaker.io/use-source-capacity: 'true'
spec:
  selector:
    matchLabels:
      app: search-viewed-realtime
      tier: backend
  template:
    metadata:
      labels:
        app: search-viewed-realtime
        tier: backend
    spec:
      containers:
        - name: tensorflow-serving
          image: 'tensorflow/serving:${version}'
          env:
            - name: MODEL_NAME
              value: 'universal-sentence-encoder-multilingual'
          resources:
            limits:
              cpu: 2.0
              memory: 1.0Gi
          ports:
            - containerPort: 8501
          volumeMounts:
            - name: model-storage
              mountPath: /models/
              readOnly: True
      volumes:
        - name: model-storage
          hostPath:
            path: /tmp/files/models/

I have tried increasing the memory allocation to 2.0Gi and that did not fix the issue either.

Here are the full startup logs and error stacktrace:

I tensorflow_serving/model_servers/server.cc:89] Building single TensorFlow model file config:  model_name: universal-sentence-encoder-multilingual model_base_path: /models/universal-sentence-encoder-multilingual
I tensorflow_serving/model_servers/server_core.cc:465] Adding/updating models.
I tensorflow_serving/model_servers/server_core.cc:594] (Re-)adding model: universal-sentence-encoder-multilingual
I tensorflow_serving/core/basic_manager.cc:740] Successfully reserved resources to load servable {name: universal-sentence-encoder-multilingual version: 3}
I tensorflow_serving/core/loader_harness.cc:66] Approving load for servable version {name: universal-sentence-encoder-multilingual version: 3}
I tensorflow_serving/core/loader_harness.cc:74] Loading servable version {name: universal-sentence-encoder-multilingual version: 3}
I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:43] Reading SavedModel from: /models/universal-sentence-encoder-multilingual/3
I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:81] Reading meta graph with tags { serve }
I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:122] Reading SavedModel debug info (if present) from: /models/universal-sentence-encoder-multilingual/3
I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
I external/org_tensorflow/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:354] MLIR V1 optimization pass is not enabled
I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:228] Restoring SavedModel bundle.
I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:212] Running initialization op on SavedModel bundle at path: /models/universal-sentence-encoder-multilingual/3
I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:301] SavedModel load for tags { serve }; Status: success: OK. Took 2499582 microseconds.
I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:71] Starting to read warmup data for model at /models/universal-sentence-encoder-multilingual/3/assets.extra/tf_serving_warmup_requests with model-warmup-options 
terminate called after throwing an instance of 'std::bad_alloc'
what(): std::bad_alloc
/usr/bin/tf_serving_entrypoint.sh: line 3:     7 Aborted                 (core dumped) tensorflow_model_server --port=8500 --rest_api_port=8501 --model_name=${MODEL_NAME} --model_base_path=${MODEL_BASE_PATH}/${MODEL_NAME} "$@"

I am using SavedModel Warmup to send prediction requests to the model on startup, so this error is causing the pod to repeatedly crash and restart.
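
For context, the warmup data TF Serving reads at startup is a TFRecord file of serialized PredictionLog protos under the model version's assets.extra/ directory. Below is a minimal sketch of how such a file can be written; the input key "inputs", the signature name, and the warmup sentence are assumptions, and the path follows the hostPath mount and model version shown above.

import tensorflow as tf
from tensorflow_serving.apis import model_pb2, predict_pb2, prediction_log_pb2

# <model_base_path>/<version>/assets.extra/tf_serving_warmup_requests
warmup_file = ("/tmp/files/models/universal-sentence-encoder-multilingual/3/"
               "assets.extra/tf_serving_warmup_requests")

with tf.io.TFRecordWriter(warmup_file) as writer:
    request = predict_pb2.PredictRequest(
        model_spec=model_pb2.ModelSpec(
            name="universal-sentence-encoder-multilingual",
            signature_name="serving_default"),
        inputs={"inputs": tf.make_tensor_proto(["warmup sentence"])})
    log = prediction_log_pb2.PredictionLog(
        predict_log=prediction_log_pb2.PredictLog(request=request))
    writer.write(log.SerializeToString())

Temporarily removing that file skips warmup, which lets the model load so the crash can instead be reproduced on a live request.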


@singhniraj08 I am not familiar with C/C++ development and don't see any guide in the TF Serving docs on how to produce a stack trace.

I would appreciate it if you could provide some information on how to get one.

I'm also receiving this error when using a predict request with the tensorflow/serving:latest (2.9.1) Docker image. I observed this error for two different NLP models from TF Hub.

Fixed by switching to tensorflow/serving:2.8.2.

I'm running this using the WSL2 Docker backend - I tried adjusting the WSL2 settings to increase the memory limit, but the error still occurs with 2.9.1.

I'm facing the same issue and solved it with @williamih 's solution!
This issue should be fixed in the future release.


Good to hear! Credit for the downgrading fix goes to @Saoqq - I just did that based on @Saoqq's original post!

I am trying to deploy a BERT-based model and encountering the same issue.
The problem is solved after downgrading tensorflow/serving to 2.8.2. Thank you!


I see the same issue when deploying bert_multi_cased_L-12_H-768_A-12 from TF Hub with the corresponding bert_multi_cased_preprocess model, i.e.

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_models as tfm  # tf-models-official, provides tfm.nlp.models

inp = tf.keras.layers.Input(shape=(), dtype=tf.string, name='input')
preprocess = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3', name='preprocessing')
encoded_inp = preprocess(inp)
bert = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/4', trainable=True, name='bert_encoder')
output = bert(encoded_inp)
network = tf.keras.Model(inp, output, name='bert_classifier')
model = tfm.nlp.models.BertTokenClassifier(network, 5)

Before downgrading to 2.8.2 as mentioned above, I tried the latest pre-release, and the bug still persists in 2.10.0-rc3-gpu!

The issue persists with the final 2.10.0 release.

Possibly related to issue: #2061

Hi,

std::bad_alloc can be caused by two things: a model that consumes too much memory, or a bug in TensorFlow. Since TF / TF Serving 2.9, some previously silent bugs surface as exceptions: training the model in Python works fine, but loading the model in TF Serving fails with a std::bad_alloc.

Here is an example where this was the case.

This might also be the case here.

A first solution would be to re-compile TF Serving without support for exceptions. This is likely not sustainable.

A second solution would be to help us figure out the bug. Can you share the failing stack trace in this GitHub issue? Here are the instructions:

# Start a TF Serving docker
docker run -it -v ${PWD}:/working_dir -w /working_dir --entrypoint bash -p 8001:8001 tensorflow/serving:2.10.0

# In the docker

# Install gdb
apt-get update
apt-get install gdb

# Start TF Serving in the docker
gdb -ex run --args tensorflow_model_server \
    --rest_api_port=8001 \
    --model_name=my_model \
    --model_base_path=$(pwd)/my_model
    
# In a new shell, outside of the docker, send a request to TF Serving.
# Note: Keep the docker shell open on the side.

curl http://localhost:8001/v1/models/my_model:predict -X POST -d '{"instances": [{"x":1}]}'

# If following this request, TF Serving crashes, you will see it in the docker. 

# In the docker, in gdb type
bt

# Share the result of "bt" in this issue.

I hope this helps :)

I'm having the same issue with a different model, so I generated a backtrace as described. In this case model warmup was enabled, so it raised the std::bad_alloc when running the warmup inference (i.e., no external request was needed). The backtrace seems too complex for me to make sense of, but maybe it helps @achoum?
#0 __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:51
#1 0x00007f4ee2bab7f1 in __GI_abort () at abort.c:79
#2 0x00007f4ee3200957 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#3 0x00007f4ee3206ae6 in ?? () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#4 0x00007f4ee3206b21 in std::terminate() () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#5 0x00007f4ee3206d54 in __cxa_throw () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#6 0x00007f4ee322f012 in std::__throw_bad_alloc() () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#7 0x000056434bf74f43 in void absl::lts_20211102::inlined_vector_internal::Storage<long, 4ul, std::allocator<long> >::Resize<absl::lts_20211102::inlined_vector_internal::DefaultValueAdapter<std::allocator<long> > >(absl::lts_20211102::inlined_vector_internal::DefaultValueAdapter<std::allocator<long> >, unsigned long) ()
#8 0x0000564356e22639 in tensorflow::ValidateStridedSliceOp(tensorflow::Tensor const*, tensorflow::Tensor const*, tensorflow::Tensor const&, tensorflow::PartialTensorShape const&, int, int, int, int, int, tensorflow::PartialTensorShape*, tensorflow::PartialTensorShape*, bool*, bool*, bool*, absl::lts_20211102::InlinedVector<long, 4ul, std::allocator<long> >*, absl::lts_20211102::InlinedVector<long, 4ul, std::allocator<long> >*, absl::lts_20211102::InlinedVector<long, 4ul, std::allocator<long> >*, tensorflow::StridedSliceShapeSpec*) ()
#9 0x0000564351819a89 in tensorflow::{lambda(tensorflow::shape_inference::InferenceContext*)#34}::operator()(tensorflow::shape_inference::InferenceContext*) const [clone .isra.751] ()
#10 0x0000564351819dd4 in std::_Function_handler<tensorflow::Status (tensorflow::shape_inference::InferenceContext*), tensorflow::{lambda(tensorflow::shape_inference::InferenceContext*)#34}>::_M_invoke(std::_Any_data const&, tensorflow::shape_inference::InferenceContext*&&) ()
#11 0x0000564356e500e6 in tensorflow::shape_inference::InferenceContext::Run(std::function<tensorflow::Status (tensorflow::shape_inference::InferenceContext*)> const&) ()
#12 0x0000564351a77ba4 in tensorflow::grappler::SymbolicShapeRefiner::InferShapes(tensorflow::NodeDef const&, tensorflow::grappler::SymbolicShapeRefiner::NodeContext*) ()
#13 0x0000564351a805f4 in tensorflow::grappler::SymbolicShapeRefiner::UpdateNode(tensorflow::NodeDef const*, bool*) ()
#14 0x0000564351a81903 in tensorflow::grappler::GraphProperties::UpdateShapes(tensorflow::grappler::SymbolicShapeRefiner*, absl::lts_20211102::flat_hash_map<tensorflow::NodeDef const*, tensorflow::NodeDef const*, absl::lts_20211102::container_internal::HashEq<tensorflow::NodeDef const*, void>::Hash, absl::lts_20211102::container_internal::HashEq<tensorflow::NodeDef const*, void>::Eq, std::allocator<std::pair<tensorflow::NodeDef const* const, tensorflow::NodeDef const*> > > const&, tensorflow::NodeDef const*, bool*) const ()
#15 0x0000564351a81acf in tensorflow::grappler::GraphProperties::PropagateShapes(tensorflow::grappler::SymbolicShapeRefiner*, tensorflow::grappler::TopoQueue*, absl::lts_20211102::flat_hash_map<tensorflow::NodeDef const*, tensorflow::NodeDef const*, absl::lts_20211102::container_internal::HashEq<tensorflow::NodeDef const*, void>::Hash, absl::lts_20211102::container_internal::HashEq<tensorflow::NodeDef const*, void>::Eq, std::allocator<std::pair<tensorflow::NodeDef const* const, tensorflow::NodeDef const*> > > const&, int) const ()
#16 0x0000564351a7a6c7 in tensorflow::grappler::GraphProperties::InferStatically(bool, bool, bool, bool) ()
#17 0x00005643519edc53 in tensorflow::grappler::ConstantFolding::Optimize(tensorflow::grappler::Cluster*, tensorflow::grappler::GrapplerItem const&, tensorflow::GraphDef*) ()
#18 0x00005643518b9dbd in tensorflow::grappler::MetaOptimizer::RunOptimizer(tensorflow::grappler::GraphOptimizer*, tensorflow::grappler::Cluster*, tensorflow::grappler::GrapplerItem*, tensorflow::GraphDef*, tensorflow::grappler::MetaOptimizer::GraphOptimizationResult*) ()
#19 0x00005643518bb511 in tensorflow::grappler::MetaOptimizer::OptimizeGraph(tensorflow::grappler::Cluster*, tensorflow::grappler::GrapplerItem&&, tensorflow::GraphDef*) ()
#20 0x00005643518bd272 in tensorflow::grappler::MetaOptimizer::OptimizeConsumeItem(tensorflow::grappler::Cluster*, tensorflow::grappler::GrapplerItem&&, tensorflow::GraphDef*) ()
#21 0x00005643518bee91 in tensorflow::grappler::RunMetaOptimizer(tensorflow::grappler::GrapplerItem&&, tensorflow::ConfigProto const&, tensorflow::DeviceBase*, tensorflow::grappler::Cluster*, tensorflow::GraphDef*) ()
#22 0x00005643518a95ef in tensorflow::GraphExecutionState::OptimizeGraph(tensorflow::BuildGraphOptions const&, tensorflow::Graph const&, tensorflow::FunctionLibraryDefinition const*, std::unique_ptr<tensorflow::Graph, std::default_delete<tensorflow::Graph> >*, std::unique_ptr<tensorflow::FunctionLibraryDefinition, std::default_delete<tensorflow::FunctionLibraryDefinition> >*) ()
#23 0x00005643518aa2a0 in tensorflow::GraphExecutionState::BuildGraph(tensorflow::BuildGraphOptions const&, std::unique_ptr<tensorflow::ClientGraph, std::default_delete<tensorflow::ClientGraph> >*) ()
#24 0x000056434be91de2 in tensorflow::DirectSession::CreateGraphs(tensorflow::BuildGraphOptions const&, std::unordered_map<std::string, std::unique_ptr<tensorflow::Graph, std::default_delete<tensorflow::Graph> >, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, std::unique_ptr<tensorflow::Graph, std::default_delete<tensorflow::Graph> > > > >*, std::unique_ptr<tensorflow::FunctionLibraryDefinition, std::default_delete<tensorflow::FunctionLibraryDefinition> >*, tensorflow::DirectSession::RunStateArgs*, absl::lts_20211102::InlinedVector<tensorflow::DataType, 4ul, std::allocator<tensorflow::DataType> >*, absl::lts_20211102::InlinedVector<tensorflow::DataType, 4ul, std::allocator<tensorflow::DataType> >*, long*) ()
#25 0x000056434be93390 in tensorflow::DirectSession::CreateExecutors(tensorflow::CallableOptions const&, std::unique_ptr<tensorflow::DirectSession::ExecutorsAndKeys, std::default_delete<tensorflow::DirectSession::ExecutorsAndKeys> >*, std::unique_ptr<tensorflow::DirectSession::FunctionInfo, std::default_delete<tensorflow::DirectSession::FunctionInfo> >*, tensorflow::DirectSession::RunStateArgs*) ()
#26 0x000056434be955ef in tensorflow::DirectSession::GetOrCreateExecutors(absl::lts_20211102::Span<std::string const>, absl::lts_20211102::Span<std::string const>, absl::lts_20211102::Span<std::string const>, tensorflow::DirectSession::ExecutorsAndKeys**, tensorflow::DirectSession::RunStateArgs*) ()
#27 0x000056434be992c1 in tensorflow::DirectSession::Run(tensorflow::RunOptions const&, std::vector<std::pair<std::string, tensorflow::Tensor>, std::allocator<std::pair<std::string, tensorflow::Tensor> > > const&, std::vector<std::string, std::allocator<std::string> > const&, std::vector<std::string, std::allocator<std::string> > const&, std::vector<tensorflow::Tensor, std::allocator<tensorflow::Tensor> >*, tensorflow::RunMetadata*, tensorflow::thread::ThreadPoolOptions const&) ()
#28 0x000056434b8fb56f in tensorflow::serving::ServingSessionWrapper::Run(tensorflow::RunOptions const&, std::vector<std::pair<std::string, tensorflow::Tensor>, std::allocator<std::pair<std::string, tensorflow::Tensor> > > const&, std::vector<std::string, std::allocator<std::string> > const&, std::vector<std::string, std::allocator<std::string> > const&, std::vector<tensorflow::Tensor, std::allocator<tensorflow::Tensor> >*, tensorflow::RunMetadata*, tensorflow::thread::ThreadPoolOptions const&) ()
#29 0x000056434bda363a in tensorflow::serving::internal::RunPredict(tensorflow::RunOptions const&, tensorflow::MetaGraphDef const&, std::optional<long> const&, tensorflow::serving::internal::PredictResponseTensorSerializationOption, tensorflow::Session*, tensorflow::serving::PredictRequest const&, tensorflow::serving::PredictResponse*, tensorflow::thread::ThreadPoolOptions const&) ()
#30 0x000056434bda39cf in tensorflow::serving::RunPredict(tensorflow::RunOptions const&, tensorflow::MetaGraphDef const&, std::optional<long> const&, tensorflow::Session*, tensorflow::serving::PredictRequest const&, tensorflow::serving::PredictResponse*, tensorflow::thread::ThreadPoolOptions const&) ()
#31 0x000056434bd98ca0 in tensorflow::serving::(anonymous namespace)::RunWarmupRequest(tensorflow::serving::PredictionLog const&, tensorflow::RunOptions const&, tensorflow::MetaGraphDef const&, tensorflow::Session*) ()
#32 0x000056434bd99018 in std::_Function_handler<tensorflow::Status (tensorflow::serving::PredictionLog), tensorflow::serving::RunSavedModelWarmup(tensorflow::serving::ModelWarmupOptions const&, tensorflow::RunOptions const&, std::string const&, tensorflow::SavedModelBundle*)::{lambda(tensorflow::serving::PredictionLog)#1}>::_M_invoke(std::_Any_data const&, tensorflow::serving::PredictionLog&&) ()
#33 0x000056434bda7ff8 in tensorflow::serving::internal::RunSavedModelWarmup(tensorflow::serving::ModelWarmupOptions const&, std::string, std::function<tensorflow::Status (tensorflow::serving::PredictionLog)>) ()
#34 0x000056434bd98714 in tensorflow::serving::RunSavedModelWarmup(tensorflow::serving::ModelWarmupOptions const&, tensorflow::RunOptions const&, std::string const&, tensorflow::SavedModelBundle*) ()
#35 0x000056434b8ef041 in std::_Function_handler<tensorflow::Status (std::unique_ptr<tensorflow::SavedModelBundle, std::default_delete<tensorflow::SavedModelBundle> >*), tensorflow::serving::SavedModelBundleSourceAdapter::GetServableCreator(std::shared_ptr<tensorflow::serving::SavedModelBundleFactory>, std::string const&) const::{lambda(std::unique_ptr<tensorflow::SavedModelBundle, std::default_delete<tensorflow::SavedModelBundle> >*)#2}>::_M_invoke(std::_Any_data const&, std::unique_ptr<tensorflow::SavedModelBundle, std::default_delete<tensorflow::SavedModelBundle> >*&&) ()
#36 0x000056434b8eb7aa in tensorflow::serving::SimpleLoader<tensorflow::SavedModelBundle>::LoadWithMetadata(tensorflow::serving::Loader::Metadata const&) ()
#37 0x000056434b8d12fd in std::_Function_handler<tensorflow::Status (), tensorflow::serving::LoaderHarness::Load()::{lambda()#1}>::_M_invoke(std::_Any_data const&) ()
#38 0x000056434b8d41e6 in tensorflow::serving::Retry(std::string const&, unsigned int, long, std::function<tensorflow::Status ()> const&, std::function<bool ()> const&) ()
#39 0x000056434b8d3187 in tensorflow::serving::LoaderHarness::Load() ()
#40 0x000056434b8cd2ea in tensorflow::serving::BasicManager::ExecuteLoad(tensorflow::serving::LoaderHarness*) ()
#41 0x000056434b8cdacc in tensorflow::serving::BasicManager::ExecuteLoadOrUnload(tensorflow::serving::BasicManager::LoadOrUnloadRequest const&, tensorflow::serving::LoaderHarness*) ()
#42 0x000056434b8cfa1e in tensorflow::serving::BasicManager::HandleLoadOrUnloadRequest(tensorflow::serving::BasicManager::LoadOrUnloadRequest const&, std::function<void (tensorflow::Status const&)>) ()
#43 0x000056434b8cfad1 in std::_Function_handler<void (), tensorflow::serving::BasicManager::LoadOrUnloadServable(tensorflow::serving::BasicManager::LoadOrUnloadRequest const&, std::function<void (tensorflow::Status const&)>)::{lambda()#2}>::_M_invoke(std::_Any_data const&) ()
#44 0x0000564357182c71 in Eigen::ThreadPoolTempl<tensorflow::thread::EigenEnvironment>::WorkerLoop(int) ()
#45 0x0000564357180c13 in std::_Function_handler<void (), tensorflow::thread::EigenEnvironment::CreateThread(std::function<void ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) ()
#46 0x000056435716d985 in tensorflow::(anonymous namespace)::PThread::ThreadFn(void*) ()
#47 0x00007f4ee3cae6db in start_thread (arg=0x7f4ead7fa700) at pthread_create.c:463
#48 0x00007f4ee2c8c61f in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:95

Hi @4sfaloth,

Thanks. This seems to be the same issue as the one spotted here, that is, the InlinedVector allocation in ValidateStridedSliceOp.

The issue is fixed in TensorFlow and TensorFlow Serving 2.11 (not yet released). I'll try to patch older versions of TensorFlow (e.g., 2.9 and 2.10).
The fix will be included in the next nightly release of TensorFlow Serving (see https://hub.docker.com/r/tensorflow/serving/tags; not yet available at the time I am writing those lines).

In the meantime, you can re-compile TF Serving from head, or use this pre-compiled version.

@Saoqq,

TF serving 2.11.0 is released. Please try the new release and let us know if your issue has been resolved. Thank you!

✔️ issue fixed with 2.11 for me


@singhniraj08 Thanks, I don't have a chance to test it in the near future.

I see that at least one person has replied that it has been fixed, so I'm closing the issue.