lexxai / goit_python_ds_hw_09

Модуль 9. Підбір гіперпараметрів НМ. Глибоке навчання. Tensorflow. Keras.

Geek Repo:Geek Repo

Github PK Tool:Github PK Tool

Модуль 9. Підбір гіперпараметрів НМ. Глибоке навчання. Tensorflow. Keras.

З циклу домашніх завдань Python Data Science.

Домашнє завдання

В якості домашнього завдання вам пропонується створити нейронну мережу за допомогою механізмів Keras, яка буде класифікувати товари із датасету fasion_mnist.

Вам належить запропонувати свою власну архітектуру мережі. Точність найнаївнішої, але адекватної нейромережі становить приблизно 91%. Точність вашої моделі повинна бути не нижчою за цей показник. Щоб досягти таких значень вам знадобиться поекспериментувати з гіперпараметрами мережі:

  • кількість шарів;
  • кількість нейронів;
  • функції активації;
  • кількість епох;
  • розмір батчу;
  • вибір оптимізатора;
  • різні техніки регуляризації і т.д.

Використайте вивчені техніки виявлення проблем навчання нейронної мережі, і потім поекспериментуйте.

Рішення оформіть у вигляді окремого ноутбука.

Результати

Типова архітектура нейронної мережі моделі 1.

_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 INPUT (Flatten)             (None, 784)               0         
                                                                 
 Hidden_Layer_1 (Dense)      (None, 256)               200960    
                                                                 
 Hidden_Layer_2 (Dense)      (None, 512)               131584    
                                                                 
 Hidden_Layer_3 (Dense)      (None, 256)               131328    
                                                                 
 Hidden_Layer_4 (Dense)      (None, 32)                8224      
                                                                 
 OUTPUT (Dense)              (None, 10)                330       
                                                                 
=================================================================
Total params: 472426 (1.80 MB)
Trainable params: 472426 (1.80 MB)
Non-trainable params: 0 (0.00 Byte)

IL=784, HL1=256 (relu), HL2=512 (relu, L2) , HL3=256 (relu), HL4=32 (relu), OL=10 (softmax)

Графік процесу навчання моделей

Model_1_L2_0.025_loss_acc Model_1_L2_0.025_ca Model_1_no_reg_ca Model_1_L2_0.1_ca Model_2_ca Model_2_L2_0.1_ca Model_2_L2_1_ca Model_3_DO_0.2_ca Model_3_DO_0.5_ca Model_2_lrelu_ca_ca Model_2_lrelu_b32_ca Model_2_lrelu_sgd_ca Model_2_lrelu_bnorm_ca

Порівняння моделей

bar_va

Результати навчання (classification report)

Classification report for: 'Model_2'
                 precision    recall  f1-score   support

0 - T-shirt/top       0.79      0.92      0.85      1000
1 - Trouser           0.99      0.97      0.98      1000
2 - Pullover          0.80      0.81      0.80      1000
3 - Dress             0.88      0.89      0.89      1000
4 - Coat              0.76      0.86      0.81      1000
5 - Sandal            0.97      0.97      0.97      1000
6 - Shirt             0.81      0.59      0.68      1000
7 - Sneaker           0.95      0.95      0.95      1000
8 - Bag               0.99      0.96      0.97      1000
9 - Ankle boot        0.95      0.97      0.96      1000

       accuracy                           0.89     10000
      macro avg       0.89      0.89      0.89     10000
   weighted avg       0.89      0.89      0.89     10000

Accuracy: 88.80%

Результати навчання (Confusion Matrix)

Model_2_cm

['0 - T-shirt/top', '1 - Trouser', '2 - Pullover', '3 - Dress', '4 - Coat', '5 - Sandal', '6 - Shirt', '7 - Sneaker', '8 - Bag', '9 - Ankle boot']

Візуалізація результатів предбачення

predict

Висновок

Використано ряд моделей нейронних мереж для заданого набору даних "Fashion-MNIST".

Для порівняння та вибору найкращої моделі були побудовані декілька моделей з різною кількістю шарів, та різними значеннями гіперпараметрів.

Для забезпечення зупинення перенавчання моделі (overfit) застосовано додавання регуляризатора в одному з шарів типу L2, з різними значеннями. Також протестовано моделі на різних активаційних функціях таких як "relu", та "leaky_relu".

Використано оптимізаційні алгоритми "Adam" та "SGD" для пошуку вагів моделі.

Крок навчання використано динамічний залежно від епохи. Залежність може бути двоступенева - до 7 епохи 0.005, далі крок навчання дорівнює 0.001, та експоненціальна.

Для автоматичного зупинення пошуку нових ваг для моделі було застосовано "Early stop" як по метриці "val_loss" (min) patience=5, так і по метриці "val_accuracy" (max) patience=10.

Побудовані графіки залежностей для кожної з моделей, з відображенням найкращої епохи для вагів, котрі були вибранні. Де можна побачити тенденцію моделі до перенавчання.

Завдяки застосуванню "Early stop" загальні характеристики моделей у на діаграмі "Validated accuracy" певним чином вирівнялися, і перенавчання своєчасно було зупинено.

Вибір найкращої моделі базується на її простоті, швидкості та якості за мінімальними значеннями втрат "val_loss", та максимальним значенням якості "val_accuracy".

За вибраною моделлю "Model_2" побудовані випадкові передбачення для 10 зображень з тестової вибірки набору даних та результати побачення де червоний колір означає передбачення було помилкове.

За значеннями результатів confusion matrix бачимо що найгірша ситуація з прогнозуванням категорії 6 - 'Shirt' [Сорочка] яка має найбільшу кількість помилок з категорією 0 'T-shirt/top' [Футболка/Топ] та '4 - Coat' [Пальто], що і показала тестова вибірка.

Моделі та історії розрахунків зберігаються до архівного файлу, і завантажуються з нього для пришвидшення отримання вже розрахованих результатів.

З метою порівняння моделі де валідаційна частина використовується як 10% від тестової частини, а тестова частина датасету виключно для перевірки моделі то для цього була створена окрема версія файлу проєкту.

About

Модуль 9. Підбір гіперпараметрів НМ. Глибоке навчання. Tensorflow. Keras.

License:MIT License


Languages

Language:Jupyter Notebook 100.0%