yahoo / open_nsfw

Not Suitable for Work (NSFW) classification using deep neural network Caffe models.

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

The Python script does not run on Python 3 — probably an issue related to StringIO and PIL

fabianfrz opened this issue · comments

After changing some lines, it still does not work (although the imports and the syntax seem to be OK):

  File "classify_nsfw.py", line 128, in <module>
    main(sys.argv)
  File "classify_nsfw.py", line 104, in main
    image_data = open(args.input_file).read()
  File "/usr/lib/python3.5/codecs.py", line 321, in decode
    (result, consumed) = self._buffer_decode(data, self.errors, final)
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte

Changed

from StringIO import StringIO

to

from io import StringIO

and

print "NSFW score:  " , scores[1]

to

print("NSFW score:  %f" %  float(scores[1]))

I was able to work around the issue by opening the file in binary mode, using

image_data = open(args.input_file,"rb").read()

instead of

image_data = open(args.input_file).read()

However, it is still broken; the remaining error looks related to PIL:

Traceback (most recent call last):
  File "classify_nsfw.py", line 128, in <module>
    main(sys.argv)
  File "classify_nsfw.py", line 119, in main
    scores = caffe_preprocess_and_compute(image_data, caffe_transformer=caffe_transformer, caffe_net=nsfw_net, output_layers=['prob'])
  File "classify_nsfw.py", line 62, in caffe_preprocess_and_compute
    img_data_rs = resize_image(pimg, sz=(256, 256))
  File "classify_nsfw.py", line 31, in resize_image
    im = Image.open(StringIO(img_data))
  File "/usr/lib/python3/dist-packages/PIL/Image.py", line 2319, in open
    % (filename if filename else fp))
OSError: cannot identify image file <_io.StringIO object at 0x7fcb3a242dc8>

I finally got it working by replacing StringIO with BytesIO, since the image data is bytes rather than text.

@fabianfrz I have the same problem as you. Could you paste the final classify_nsfw.py? Thank you!

commented

Runs on Python 3:

#!/usr/bin/env python
"""
Copyright 2016 Yahoo Inc.
Licensed under the terms of the 2 clause BSD license. 
Please see LICENSE file in the project root for terms.
"""

import argparse
import glob
import os
import sys
import time
from io import BytesIO

import caffe
import numpy as np
from PIL import Image


def resize_image(data, sz=(256, 256)):
    """
    Decode an image, resize it, and re-encode it as JPEG.

    Please use this resize logic for best results instead of the caffe one,
    since it was used to generate the training dataset.

    :param bytes data:
        The raw, encoded image data (e.g. the contents of an image file).
    :param tuple sz:
        Target size passed to PIL's ``Image.resize``, i.e. (width, height).
    :returns io.BytesIO:
        An in-memory file object, rewound to offset 0, containing the
        resized image encoded as JPEG.
    """
    # Use the opened image as a context manager so its source is released
    # once the resized copy has been produced.
    with Image.open(BytesIO(data)) as im:
        # Normalise every other mode to RGB so the re-encoded JPEG is always
        # a 3-channel RGB image.
        rgb = im if im.mode == "RGB" else im.convert('RGB')
        imr = rgb.resize(sz, resample=Image.BILINEAR)
    fh_im = BytesIO()
    imr.save(fh_im, format='JPEG')
    # Rewind so the caller can read the JPEG from the beginning.
    fh_im.seek(0)
    return fh_im


def caffe_preprocess_and_compute(pimg, caffe_transformer=None, caffe_net=None,
                                 output_layers=None):
    """
    Run a Caffe network on an input image after preprocessing it to prepare
    it for Caffe.

    :param bytes pimg:
        Raw, encoded image data (e.g. the bytes of an image file). It is
        decoded and resized to 256x256 by ``resize_image`` before use.
    :param caffe.io.Transformer caffe_transformer:
        Transformer applied to the center crop to build the network input.
        Required whenever ``caffe_net`` is given.
    :param caffe.Net caffe_net:
        The network to run. If None, nothing is computed and ``[]`` is
        returned.
    :param list output_layers:
        A list of the names of the layers from caffe_net whose outputs are to
        be computed. If this is None, the default outputs for the network
        are used. Only the output of the first listed layer is returned.
    :return:
        The output of ``output_layers[0]`` for the single input image,
        converted with ``astype(float)``, or ``[]`` when ``caffe_net`` is
        None.
    :raises ValueError:
        If ``caffe_net`` is given but ``caffe_transformer`` is None.
    """
    if caffe_net is None:
        return []
    if caffe_transformer is None:
        # Fail early with a clear message instead of an AttributeError deep
        # inside the preprocessing below.
        raise ValueError("caffe_transformer is required when caffe_net is given")

    # Grab the default output names if none were requested specifically.
    if output_layers is None:
        output_layers = caffe_net.outputs

    img_bytes = resize_image(pimg, sz=(256, 256))
    image = caffe.io.load_image(img_bytes)

    # Center-crop the resized image down to the network's input size (h, w);
    # offsets are clamped at 0 when the image is smaller than the input.
    H, W, _ = image.shape
    _, _, h, w = caffe_net.blobs['data'].data.shape
    h_off = max((H - h) // 2, 0)
    w_off = max((W - w) // 2, 0)
    crop = image[h_off:h_off + h, w_off:w_off + w, :]
    transformed_image = caffe_transformer.preprocess('data', crop)
    # Prepend a batch dimension of size 1.
    transformed_image.shape = (1,) + transformed_image.shape

    input_name = caffe_net.inputs[0]
    all_outputs = caffe_net.forward_all(blobs=output_layers,
                                        **{input_name: transformed_image})

    # First requested layer, first (and only) item of the batch.
    return all_outputs[output_layers[0]][0].astype(float)


def main(argv):
    """
    Command-line entry point: print the NSFW score of a single image.

    :param list argv:
        Full argument vector including the program name at ``argv[0]``
        (i.e. ``sys.argv``); only ``argv[1:]`` is parsed.
    """
    parser = argparse.ArgumentParser()
    # Required arguments: input file.
    parser.add_argument(
        "input_file",
        help="Path to the input image file"
    )

    # Model files. Both are passed straight to caffe.Net below, so require
    # them here rather than handing None to Caffe.
    parser.add_argument(
        "--model_def",
        required=True,
        help="Model definition file."
    )
    parser.add_argument(
        "--pretrained_model",
        required=True,
        help="Trained model weights file."
    )

    # argv[0] is the program name; parse the given vector instead of
    # implicitly re-reading sys.argv.
    args = parser.parse_args(argv[1:])
    # Read as bytes: image data is binary and must not be text-decoded.
    with open(args.input_file, 'rb') as fh:
        image_data = fh.read()

    # Pre-load caffe model.
    nsfw_net = caffe.Net(args.model_def,  # pylint: disable=invalid-name
                         args.pretrained_model, caffe.TEST)

    # Load transformer
    # Note that the parameters are hard-coded for best results
    caffe_transformer = caffe.io.Transformer({'data': nsfw_net.blobs['data'].data.shape})
    caffe_transformer.set_transpose('data', (2, 0, 1))  # move image channels to outermost
    caffe_transformer.set_mean('data', np.array([104, 117, 123]))  # subtract the dataset-mean value in each channel
    caffe_transformer.set_raw_scale('data', 255)  # rescale from [0, 1] to [0, 255]
    caffe_transformer.set_channel_swap('data', (2, 1, 0))  # swap channels from RGB to BGR

    # Classify.
    scores = caffe_preprocess_and_compute(image_data, caffe_transformer=caffe_transformer, caffe_net=nsfw_net,
                                          output_layers=['prob'])

    # Scores is the array containing SFW / NSFW image probabilities
    # scores[1] indicates the NSFW probability
    print("NSFW score: %s " % scores[1])


if __name__ == "__main__":
    # Only run the classifier when executed as a script, not on import.
    main(sys.argv)