A simple image-retrieval algorithm for the DeepFashion dataset with PyTorch
- Python (compatible with both 2 and 3)
- OpenCV (`cv2`), used only for visualization
Anaconda is recommended.
- Download the dataset from DeepFashion: Attribute Prediction
- Unzip all files and set the dataset path in `config.py`
The trained models will be saved to the directory configured in `config.py`.
My trained model: download from Google Drive
Deep feature: ResNet50 - (Linear 1024 to 512) - (Linear 512 to 20); the 512-dim vector is used as the image's identity feature.
Loss: CrossEntropyLoss + TripletMarginLoss * Weight
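The combined objective above can be sketched as follows. This is a minimal illustration, not the repo's actual training code: `TRIPLET_WEIGHT` is an assumed name for the weight factor, and the tensors are random stand-ins.

```python
import torch
import torch.nn as nn

# Assumed name for the weight on the triplet term; the real value would
# live in the repo's configuration.
TRIPLET_WEIGHT = 0.5

ce_loss = nn.CrossEntropyLoss()
triplet_loss = nn.TripletMarginLoss(margin=1.0)

def combined_loss(logits, labels, anchor, positive, negative):
    """CrossEntropyLoss + TripletMarginLoss * weight, as described above."""
    return ce_loss(logits, labels) + TRIPLET_WEIGHT * triplet_loss(
        anchor, positive, negative)

# Smoke test with random tensors: batch of 4, 20 classes, 512-dim features.
logits = torch.randn(4, 20)
labels = torch.randint(0, 20, (4,))
anchor, pos, neg = (torch.randn(4, 512) for _ in range(3))
loss = combined_loss(logits, labels, anchor, pos, neg)
```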
Color feature: take the final conv layer output of ResNet50 (N * C * 7 * 7), then average-pool over the channel dimension. Choose the max-N responses and extract the corresponding blocks from the avg-pooled map of the original image.
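The region-selection step can be sketched like this. It is an illustrative reconstruction, assuming a standard (N, C, 7, 7) conv output; the function name and `n_top` are not from the repo.

```python
import torch

def top_response_cells(conv_out, n_top=3):
    """Average the conv output over the channel dim, then return the
    (row, col) indices of the n_top strongest cells per 7x7 response map.
    These cells would then be mapped back to blocks of the input image."""
    response = conv_out.mean(dim=1)              # (N, 7, 7)
    flat = response.view(response.size(0), -1)   # (N, 49)
    _, idx = flat.topk(n_top, dim=1)             # indices of the max-N responses
    rows = idx // response.size(2)
    cols = idx % response.size(2)
    return rows, cols                            # each of shape (N, n_top)

# Random stand-in for a ResNet50 final conv output (batch of 2, 2048 channels).
conv_out = torch.randn(2, 2048, 7, 7)
rows, cols = top_response_cells(conv_out)
```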
Training details: freeze the conv parameters and train the net until accuracy and loss stabilize, then set FREEZE to False and train again.
If you have also applied for the DeepFashion: In-shop Clothes Retrieval dataset, you can set its path as well.
Generating feature databases
Set the trained model path in `config.py`.
The features will be saved to the configured path.
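The dump step boils down to running every image through the trained network and stacking the resulting 512-dim vectors. A hypothetical sketch; `extract_feature`, the file names, and the paths are illustrative, not the repo's actual API:

```python
import numpy as np

def extract_feature(path):
    # Stand-in for a forward pass through the trained network; the real
    # version would load the image at `path` and return its 512-dim feature.
    rng = np.random.default_rng(abs(hash(path)) % 2**32)
    return rng.standard_normal(512)

image_paths = ['img/a.jpg', 'img/b.jpg']
features = np.stack([extract_feature(p) for p in image_paths])

# Persist the stacked matrix as the feature database.
np.save('features.npy', features.astype(np.float32))
```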
Accelerating querying by clustering
Run `kmeans.py` to train the clustering model (default: 50 clusters).
The clustering model will be saved to the configured path.
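The clustering step can be sketched with scikit-learn; the repo's own `kmeans.py` may differ, and the toy array here stands in for the 139,709 512-dim features of the real database. At query time only the members of the query's cluster are scanned, which is what makes the k-means query faster than the naive one.

```python
import numpy as np
from sklearn.cluster import KMeans

# Toy stand-in for the feature database (200 vectors of 512 dims).
rng = np.random.default_rng(0)
features = rng.standard_normal((200, 512)).astype(np.float32)

# 50 clusters, matching the default mentioned above.
kmeans = KMeans(n_clusters=50, n_init=10, random_state=0).fit(features)

# At query time: assign the query to its nearest centroid, then compare
# only against that cluster's members instead of the whole database.
query = features[0]
cluster_id = int(kmeans.predict(query[None, :])[0])
members = np.flatnonzero(kmeans.labels_ == cluster_id)
```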
Query with a picture
Run `retrieval.py img_path`, for example:
`python retrieval.py img/Sheer_Pleated-Front_Blouse/img_00000005.jpg`
Edit `config.py` to use a different metric, such as `euclidean`, on the deep feature and the color feature.
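A naive query with the euclidean metric is just a full scan over the database. A minimal sketch with assumed shapes (the real database holds 139,709 features):

```python
import numpy as np

def naive_query(db, q, k=5):
    """Return the indices of the k database features nearest to query q,
    using the euclidean distance."""
    dists = np.linalg.norm(db - q, axis=1)  # distance from q to every row
    return np.argsort(dists)[:k]

# Toy database; querying with one of its own vectors should rank it first.
db = np.random.default_rng(0).standard_normal((1000, 512))
q = db[42]
idx = naive_query(db, q)
```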
- 2.854 sec to load the model
- 0.078 sec to load the feature database
- 0.519 sec to extract the features of a given image
- 0.122 sec for a naive query (139,709 features)
- 0.038 sec for a query with k-means (139,709 features)
- Intel(R) Core(TM) i7-4790K CPU @ 4.00GHz with 32GB RAM
- GeForce GTX TITAN X with CUDA 7.5
- Ubuntu 14.04
- PyTorch 0.2.0_4
- Add web support
- Add more models and fuse them