In [1]:
%matplotlib inline

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session, get_session
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.9
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))

import sys
sys.path.append('/mnt/raid/cheetahs/modules/')

from inception.inception_resnet_v2 import InceptionResNetV2, preprocess_input
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Dense, Dropout, Input, concatenate
from keras.models import Model
from keras.optimizers import Nadam
from scipy.ndimage.interpolation import rotate
from sklearn.metrics import precision_recall_curve, classification_report, accuracy_score, confusion_matrix
import pandas as pd
import numpy as np
import itertools
import matplotlib.pyplot as plt
Using TensorFlow backend.
In [2]:
# Monkey-patch keras DirectoryIterator to also return filename

import keras
from keras_util.util import DirectoryIteratorWithFname
keras.preprocessing.image.DirectoryIterator = DirectoryIteratorWithFname
In [3]:
# Config

data_path = '/mnt/raid/cheetahs/data/train/'
val_data_path = '/mnt/raid/cheetahs/data/val/'
batch_size = 32

Data loading

In [4]:
# Load and parse ImageNet class labels

classes = open('/mnt/raid/cheetahs/modules/imagenet_classes', 'r').readlines()

def strip(c):
    key, value = c.split(':')
    key = key.strip()
    key = key.split('{')[-1]
    value = value.split("'")[1].strip()
    return int(key), value

classes = dict([strip(c) for c in classes])
In [5]:
izw_classes = ('unknown', 'cheetah', 'leopard')
In [6]:
metadata = pd.read_hdf('/mnt/raid/cheetahs/modules/metadata.hdf5')
In [7]:
metadata.head()
Out[7]:
ambient_temp brightness contrast datetime event1 event2 filename hour label path ... sequence_max serial_no set sharpness event_key_simple sortkey timeoffset event_key duplicate new_path
13657 14 0 160 1494269790000000000 0 587 Leopard_000051.jpeg 18 leopard /mnt/raid/cheetahs/data/train/leopard/Leopard_... ... 3 H600HG01173547 train 32 H600HG01173547_2017_128_587 H600HG01173547_2017_128_5871494269790000000000 -9223372036854775808 H600HG01173547_2017_128_587 False /mnt/raid/cheetahs/data2/train/leopard/0_H600H...
14356 14 0 160 1494269791000000000 0 587 Leopard_000052.jpeg 18 leopard /mnt/raid/cheetahs/data/train/leopard/Leopard_... ... 3 H600HG01173547 train 32 H600HG01173547_2017_128_587 H600HG01173547_2017_128_5871494269791000000000 1000000000 H600HG01173547_2017_128_587 False /mnt/raid/cheetahs/data2/train/leopard/1_H600H...
14364 14 0 160 1494269792000000000 0 587 Leopard_000053.jpeg 18 leopard /mnt/raid/cheetahs/data/train/leopard/Leopard_... ... 3 H600HG01173547 train 32 H600HG01173547_2017_128_587 H600HG01173547_2017_128_5871494269792000000000 1000000000 H600HG01173547_2017_128_587 False /mnt/raid/cheetahs/data2/train/leopard/2_H600H...
14975 14 0 160 1494269794000000000 0 588 Leopard_000054.jpeg 18 leopard /mnt/raid/cheetahs/data/train/leopard/Leopard_... ... 3 H600HG01173547 train 32 H600HG01173547_2017_128_588 H600HG01173547_2017_128_5881494269794000000000 2000000000 H600HG01173547_2017_128_587 False /mnt/raid/cheetahs/data2/train/leopard/3_H600H...
13928 14 0 160 1494269795000000000 0 588 Leopard_000055.jpeg 18 leopard /mnt/raid/cheetahs/data/train/leopard/Leopard_... ... 3 H600HG01173547 train 32 H600HG01173547_2017_128_588 H600HG01173547_2017_128_5881494269795000000000 1000000000 H600HG01173547_2017_128_587 False /mnt/raid/cheetahs/data2/train/leopard/4_H600H...

5 rows × 22 columns

In [8]:
# Crop camera metainformation from images

def preprocess(data, rotate_range=None):
    for x, y, fns in data:
        batch_metadata = []
        for fname in fns:
            fname_splitted = fname.split('_')
            index = fname_splitted[0]
            rest = '_'.join(fname_splitted[1:]).split('.jpeg')[0]
            f_metadata = metadata.iloc[int(index)]
            batch_metadata.append((
                f_metadata.ambient_temp,
                f_metadata.hour))
            # optionally use metadata
        temperatures = np.array(batch_metadata).astype(np.float32)
        x = x[:, 10:-10, 10:-10, :]
        if rotate_range is not None:
            for idx in range(batch_size):
                x[idx] = rotate(x[idx], np.random.random() * rotate_range * 2 - rotate_range, 
                                mode='reflect', reshape=False).astype(np.int64)
        yield [preprocess_input(x), temperatures], y
In [9]:
# Augment train data with horizontal flips, scale to ImageNet input size

generator = ImageDataGenerator(horizontal_flip=True)
val_generator = ImageDataGenerator(horizontal_flip=False)

train_gen = preprocess(generator.flow_from_directory(
    data_path, 
    target_size=(299+20, 299+20),
    classes=izw_classes,
    batch_size=batch_size), rotate_range=10)

val_gen = preprocess(val_generator.flow_from_directory(
    val_data_path, 
    target_size=(299+20, 299+20),
    classes=izw_classes,
    batch_size=batch_size))
Found 17857 images belonging to 3 classes.
Found 1915 images belonging to 3 classes.
In [10]:
# Test data loader

plt.figure(figsize=(7, 7))
plt.imshow(1 - ((next(train_gen)[0][0][0] / 2) + .5) * 255, vmin=0, vmax=255)
Out[10]:
<matplotlib.image.AxesImage at 0x7f497c42bf98>

Test pretrained model

In [11]:
# Load pretrained model
#
# http://arxiv.org/abs/1602.07261
#
# Inception-v4, Inception-ResNet and the Impact of Residual Connections
# on Learning
#
# Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi

model = InceptionResNetV2()
In [12]:
# Test pretrained model on IZW data

batch, true_labels = next(val_gen)
fig, axes = plt.subplots(16, 2, figsize=(14, (14 / 2) * 16))
for idx, (image, label, true_label) in enumerate(zip(batch[0], model.predict(batch[0]), true_labels)):
    r, c = divmod(idx, 2)
    axes[r, c].imshow(1 - ((image / 2) + .5) * 255, vmin=0, vmax=255)
    axes[r, c].set_title('P: {} ({:.1f}%) - L: {}'.format(
        classes[label.argmax()],
        label[label.argmax()] * 100,
        izw_classes[true_label.argmax()]))
    axes[r, c].grid('off')
    axes[r, c].set_axis_off()

Model finetuning

In [13]:
# Use pretrained model, but skip last (classifier) layer and replace it with a new layer for the three IZW classes
# Fix pretrained layers, only learn last weights for last layer
# Also use metadata as additional input

metadata_input = Input(shape=(2, ))

base_model = InceptionResNetV2(include_top=False, pooling='avg')
h = concatenate([base_model.output, metadata_input])
h = Dropout(.2, name='Dropout')(h)
outputs = Dense(3, activation='softmax')(h)

for layer in base_model.layers:
    layer.trainable = False
    
model = Model(base_model.inputs + [metadata_input], outputs)
In [14]:
optim = Nadam(0.001)
model.compile(loss='categorical_crossentropy', optimizer=optim, metrics=['accuracy'])
model.fit_generator(train_gen, 17857 / batch_size, validation_data=(val_gen), validation_steps=1915 / batch_size,
                    epochs=3, workers=16, use_multiprocessing=True)
/usr/local/lib/python3.5/dist-packages/keras/engine/training.py:1786: UserWarning: Using a generator with `use_multiprocessing=True` and multiple workers may duplicate your data. Please consider using the`keras.utils.Sequence class.
  UserWarning('Using a generator with `use_multiprocessing=True`'
Epoch 1/3
558/558 [============================>.] - ETA: 0s - loss: 0.3244 - acc: 0.8794
/usr/local/lib/python3.5/dist-packages/keras/engine/training.py:1937: UserWarning: Using a generator with `use_multiprocessing=True` and multiple workers may duplicate your data. Please consider using the`keras.utils.Sequence class.
  UserWarning('Using a generator with `use_multiprocessing=True`'
559/558 [==============================] - 248s - loss: 0.3240 - acc: 0.8795 - val_loss: 0.2662 - val_acc: 0.9104
Epoch 2/3
559/558 [==============================] - 244s - loss: 0.2381 - acc: 0.9153 - val_loss: 0.2492 - val_acc: 0.8938
Epoch 3/3
559/558 [==============================] - 243s - loss: 0.2462 - acc: 0.9052 - val_loss: 0.2229 - val_acc: 0.9125
Out[14]:
<keras.callbacks.History at 0x7f48c45f61d0>
In [15]:
# Now finetune the whole model for five more epochs with reduced learning rate

for layer in base_model.layers:
    layer.trainable = True

optim = Nadam(0.0001)
model.compile(loss='categorical_crossentropy', optimizer=optim, metrics=['accuracy'])
model.fit_generator(train_gen, 17857 / batch_size, validation_data=(val_gen), validation_steps=1915 / batch_size,
                    epochs=5, workers=16, use_multiprocessing=True)
/usr/local/lib/python3.5/dist-packages/keras/engine/training.py:1786: UserWarning: Using a generator with `use_multiprocessing=True` and multiple workers may duplicate your data. Please consider using the`keras.utils.Sequence class.
  UserWarning('Using a generator with `use_multiprocessing=True`'
Epoch 1/5
558/558 [============================>.] - ETA: 0s - loss: 0.0378 - acc: 0.9909
/usr/local/lib/python3.5/dist-packages/keras/engine/training.py:1937: UserWarning: Using a generator with `use_multiprocessing=True` and multiple workers may duplicate your data. Please consider using the`keras.utils.Sequence class.
  UserWarning('Using a generator with `use_multiprocessing=True`'
559/558 [==============================] - 657s - loss: 0.0377 - acc: 0.9909 - val_loss: 0.4755 - val_acc: 0.9208
Epoch 2/5
559/558 [==============================] - 638s - loss: 0.0230 - acc: 0.9940 - val_loss: 0.2539 - val_acc: 0.9437
Epoch 3/5
559/558 [==============================] - 639s - loss: 0.0223 - acc: 0.9944 - val_loss: 0.2214 - val_acc: 0.9437
Epoch 4/5
559/558 [==============================] - 639s - loss: 0.0171 - acc: 0.9959 - val_loss: 0.2924 - val_acc: 0.9313
Epoch 5/5
559/558 [==============================] - 638s - loss: 0.0108 - acc: 0.9974 - val_loss: 0.3764 - val_acc: 0.8953
Out[15]:
<keras.callbacks.History at 0x7f48b1ac9e48>
In [18]:
model.save('/mnt/raid/cheetahs/results/inception-resnet-v2-cheetahs.h5')

Evaluation

In [19]:
# Visualize model predictions

batch, true_labels = next(val_gen)
fig, axes = plt.subplots(16, 2, figsize=(14, (14 / 2) * 16))
for idx, (image, label, true_label) in enumerate(zip(batch[0], model.predict(batch), true_labels)):
    r, c = divmod(idx, 2)
    axes[r, c].imshow(1 - ((image / 2) + .5) * 255, vmin=0, vmax=255)
    axes[r, c].set_title('P: {} ({:.1f}%) - L: {}'.format(
        izw_classes[label.argmax()],
        label[label.argmax()] * 100,
        izw_classes[true_label.argmax()]))
    axes[r, c].grid('off')
    axes[r, c].set_axis_off()
In [20]:
val_labels = []
val_preds = []
val_probs = []

for idx, (batch, labels) in enumerate(val_gen):
    pred = model.predict(batch)
    val_preds.append(pred.argmax(axis=1))
    val_probs.append(pred)
    val_labels.append(labels.argmax(axis=1))
            
    if idx == (2256 // batch_size):
        break
In [21]:
val_preds = list(itertools.chain(*val_preds))
val_labels = list(itertools.chain(*val_labels))
In [22]:
print('Accuracy: {:.3f}%\n'.format(accuracy_score(val_labels, val_preds) * 100))
print(classification_report(val_labels, val_preds, target_names=izw_classes))
Accuracy: 94.795%

             precision    recall  f1-score   support

    unknown       0.88      1.00      0.94       447
    cheetah       0.99      0.95      0.97      1686
    leopard       0.74      0.78      0.76       134

avg / total       0.95      0.95      0.95      2267

In [23]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    thresh = cm.max() / 1.001
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        s = '{:.3f}'.format(cm[i, j]) if normalize else cm[i,j]
        plt.text(j, i, s,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
In [24]:
cnf_matrix = confusion_matrix(val_labels, val_preds)
plt.figure(figsize=(6, 6))
plot_confusion_matrix(cnf_matrix, classes=izw_classes,
                      normalize=False, title='Confusion matrix')
In [25]:
cheetah_val_probs = list(itertools.chain(*[p[:, izw_classes.index('cheetah')] for p in val_probs]))
cheetah_val_labels = np.array(val_labels) == izw_classes.index('cheetah')
In [26]:
precision, recall, thresholds = precision_recall_curve(cheetah_val_labels, cheetah_val_probs)
In [27]:
plt.style.use('seaborn-notebook')
plt.figure(figsize=(8, 6))
plt.plot(recall, precision)
plt.plot(recall[precision > .999], precision[precision > .999], c='green')
plt.plot(recall[recall > .999], precision[recall > .999], c='green')
plt.ylabel('Precision')
plt.xlabel('Recall')
plt.title('Precision-Recall curve for cheetahs')
plt.ylim([precision.min()-0.01, precision.max()+0.01])
plt.xlim([-0.05, 1.05])
Out[27]:
(-0.05, 1.05)
In [28]:
precision_max_recall = precision[np.where(recall > .999)[0][-1]]
print('Precision for cheetahs for recall > 99.9%: {:.1f}%'.format(precision_max_recall * 100))
Precision for cheetahs for recall > 99.9%: 88.9%
In [29]:
def most_common(lst):
    return max(set(lst), key=lst.count)
In [48]:
from keras.preprocessing.image import load_img, img_to_array
import keras.backend as K
import os
from tqdm import tqdm_notebook

metadata['new_set'] = metadata.new_path.apply(lambda p: 'val' if 'val' in p else 'train')

group_labels = []
group_preds = []
pred_vars = []

t_iterator = tqdm_notebook(metadata[metadata.new_set == 'val'].groupby('event_key'))
for event_key, group in t_iterator:
    batch_x = np.zeros((len(group), 299+20, 299+20, 3), dtype=K.floatx())
    batch_t = np.zeros((len(group), ), dtype=K.floatx())
    batch_h = np.zeros((len(group), ), dtype=K.floatx())
    group_iterator = enumerate(zip(group.index, group.new_path, group.ambient_temp, group.hour))
    for i, (file_idx, path, temp, hour) in group_iterator:
        file_path = os.path.join(*([val_data_path] + path.split('/')[-2:]))
        img = load_img(file_path, grayscale=False, target_size=(299+20, 299+20))
        x = img_to_array(img)
        x = val_generator.random_transform(x)
        x = val_generator.standardize(x)
        batch_x[i] = x
        batch_t[i] = temp
        batch_h[i] = hour
    batch_x = batch_x[:, 10:-10, 10:-10, :]
    batch_x = [preprocess_input(batch_x), np.stack((batch_t, batch_h), axis=1)]
    group_label = group.label.unique()
    assert(len(group_label) == 1)
    group_label = izw_classes.index(group_label[0])
    pred = model.predict_on_batch(batch_x)
    group_pred = np.where(pred == pred.max())[1][0]#most_common(list(pred.argmax(axis=1)))# #pred.mean(axis=0).argmax()
    pred_vars.append(pred.var(axis=0).mean())
    
    for _ in range(len(group)):
        group_labels.append(group_label)
        group_preds.append(group_pred)
    
    if group_pred != group_label:
        print('Prediction => {}: {:.2f}, {}: {:.2f}, {}: {:.2f}, variance: {:.2f}'.format(
            *(list((itertools.chain(*(zip(izw_classes, pred[np.where(pred == pred.max())[0], :][0])))))) + [pred_vars[-1]]))
        print('Label: {}'.format(izw_classes[group_label]))

        for img in batch_x[0]:
            plt.figure(figsize=(4, 4))
            plt.imshow(1 - ((img / 2) + .5) * 255, vmin=0, vmax=255)
            plt.show()
            plt.close()
    
    t_iterator.set_description('Accuracy: {:.3f}, Mean variance: {:.3f}'.format(
        accuracy_score(group_labels, group_preds), np.mean(pred_vars)))
Widget Javascript not detected.  It may not be installed or enabled properly.
Prediction => unknown: 0.00, cheetah: 0.98, leopard: 0.02, variance: 0.00
Label: leopard
Prediction => unknown: 0.00, cheetah: 1.00, leopard: 0.00, variance: 0.00
Label: leopard
Prediction => unknown: 1.00, cheetah: 0.00, leopard: 0.00, variance: 0.00
Label: cheetah
Prediction => unknown: 0.06, cheetah: 0.82, leopard: 0.12, variance: 0.05
Label: leopard
Prediction => unknown: 0.97, cheetah: 0.03, leopard: 0.01, variance: 0.08
Label: cheetah