jgraving / DeepPoseKit

a toolkit for pose estimation using deep learning

Home Page: http://deepposekit.org

Model config isn't properly saved

arminbahl opened this issue

I am trying to run one of the example datasets in a fresh Anaconda environment set up for DeepPoseKit:

import tensorflow as tf
from deepposekit.io import VideoReader
from deepposekit.io import DataGenerator, TrainingGenerator
from deepposekit.models import StackedDenseNet
from deepposekit.models import load_model

data_generator = DataGenerator('/Users/arminbahl/deepposekit-data/datasets/fly/annotation_data_release.h5')
train_generator = TrainingGenerator(data_generator)
model = StackedDenseNet(train_generator)
model.fit(batch_size=4, n_workers=1)


model = load_model('/Users/arminbahl/deepposekit-data/datasets/fly/example_annotation_model.h5')
reader = VideoReader('/Users/arminbahl/deepposekit-data/datasets/fly/video.avi')
predictions = model.predict(reader)

The training runs fine and the model is saved.

However, when loading the model, I get an error:

model = load_model('/Users/arminbahl/deepposekit-data/datasets/fly/example_annotation_model.h5')
  File "/Users/arminbahl/opt/anaconda3/envs/deepposekit/lib/python3.7/site-packages/deepposekit/models/loading.py", line 124, in load_model
KeyError: 'image_shape'

Any idea what might be wrong here?

Thanks for helping!

I use a dedicated conda environment:

conda create --name deepposekit --channel conda-forge python=3.7
conda activate deepposekit
conda install --yes -c conda-forge shapely scikit-image scikit-learn
pip install tensorflow deepposekit

Thanks again. Should be fixed now. Reopen if you're still having this issue. Posting the relevant part of your email here for posterity:

I dug into your code a little and found a solution in the function get_config() in engine.py (https://github.com/jgraving/DeepPoseKit/blob/master/deepposekit/models/engine.py#L235). self.train_generator exists (<deepposekit.io.TrainingGenerator.TrainingGenerator object at 0x7ff1f0078438>), but the check if self.train_generator evaluates to False. I think it should rather say if self.train_generator is not None. Having changed this, the dictionary gets properly populated and contains 'image_shape' now…
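
For anyone landing here later, here is a minimal sketch of why the two checks can give different results. In Python, an object that defines __len__ is treated as False in an if statement whenever its length is 0, even though the object exists. Whether this is exactly what happened inside TrainingGenerator is an assumption; FakeGenerator, its n_batches argument, and the image shape below are hypothetical stand-ins, not DeepPoseKit code.

```python
class FakeGenerator:
    """Hypothetical stand-in for a data generator (not DeepPoseKit's class)."""

    def __init__(self, n_batches):
        self.n_batches = n_batches
        self.image_shape = (192, 192, 1)  # made-up shape for illustration

    def __len__(self):
        # Python falls back to __len__ for truth testing when __bool__
        # is not defined, so a length of 0 makes the object falsy.
        return self.n_batches


def get_config_truthy(generator):
    """Populate the config only if the generator is truthy."""
    config = {}
    if generator:
        config["image_shape"] = generator.image_shape
    return config


def get_config_is_not_none(generator):
    """Populate the config whenever a generator object was passed."""
    config = {}
    if generator is not None:
        config["image_shape"] = generator.image_shape
    return config


empty = FakeGenerator(n_batches=0)
print("truthiness check: ", get_config_truthy(empty))
print("is not None check:", get_config_is_not_none(empty))
```

With the truthiness check the 'image_shape' key is silently skipped, so a later config['image_shape'] lookup would raise a KeyError like the one in the traceback. The explicit is not None check only skips the key when no generator was passed at all.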