ميزات شبكة بولتزمان المقيدة لتصنيف الأرقام#

بالنسبة لبيانات الصور الرمادية، حيث يمكن تفسير قيم البكسل على أنها درجات من السواد على خلفية بيضاء، مثل التعرف على الأرقام المكتوبة بخط اليد، يمكن لنموذج شبكة بولتزمان المقيدة ذات التوزيع البرنولي (BernoulliRBM) أن يقوم باستخراج الميزات غير الخطية بشكل فعال.

# المؤلفون: مطوري مكتبة سكايلرن
# معرف الترخيص: BSD-3-Clause

توليد البيانات#

من أجل تعلم تمثيلات كامنة جيدة من مجموعة بيانات صغيرة، نقوم بتوليد المزيد من البيانات المُعَلَّمة بشكل اصطناعي عن طريق إزعاج بيانات التدريب مع تحولات خطية بمقدار 1 بكسل في كل اتجاه.

import numpy as np
from scipy.ndimage import convolve

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import minmax_scale


def nudge_dataset(X, Y):
    """
    هذا ينتج مجموعة بيانات أكبر بخمس مرات من الأصلية،
    عن طريق تحريك الصور 8x8 في X حولها بمقدار 1px إلى اليسار، اليمين، الأسفل، الأعلى
    """
    direction_vectors = [
        [[0, 1, 0], [0, 0, 0], [0, 0, 0]],
        [[0, 0, 0], [1, 0, 0], [0, 0, 0]],
        [[0, 0, 0], [0, 0, 1], [0, 0, 0]],
        [[0, 0, 0], [0, 0, 0], [0, 1, 0]],
    ]

    def shift(x, w):
        return convolve(x.reshape((8, 8)), mode="constant", weights=w).ravel()

    X = np.concatenate(
        [X] + [np.apply_along_axis(shift, 1, X, vector) for vector in direction_vectors]
    )
    Y = np.concatenate([Y for _ in range(5)], axis=0)
    return X, Y


X, y = datasets.load_digits(return_X_y=True)
X = np.asarray(X, "float32")
X, Y = nudge_dataset(X, y)
X = minmax_scale(X, feature_range=(0, 1))  # 0-1 scaling

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=0)

تعريف النماذج#

نقوم ببناء خط أنابيب التصنيف مع مستخرج ميزات BernoulliRBM و مصنف LogisticRegression

from sklearn import linear_model
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline

logistic = linear_model.LogisticRegression(solver="newton-cg", tol=1)
rbm = BernoulliRBM(random_state=0, verbose=True)

rbm_features_classifier = Pipeline(steps=[("rbm", rbm), ("logistic", logistic)])

التدريب#

تم تحسين فرط معاملات النموذج بالكامل (معدل التعلم، حجم الطبقة المخفية، التنظيم) عن طريق البحث الشبكي، ولكن البحث غير مستنسخ هنا بسبب قيود وقت التشغيل.

from sklearn.base import clone

# فرط المعاملات. تم ضبط هذه القيم عن طريق التحقق من الصحة المتقاطعة،
# باستخدام GridSearchCV. هنا لا نقوم بالتحقق من الصحة المتقاطعة لتوفير الوقت.
rbm.learning_rate = 0.06
rbm.n_iter = 10

# المزيد من المكونات تميل إلى إعطاء أداء تنبؤ أفضل، ولكن وقت
# ملاءمة أكبر
rbm.n_components = 100
logistic.C = 6000

# تدريب خط أنابيب RBM-Logistic
rbm_features_classifier.fit(X_train, Y_train)

# تدريب مصنف الانحدار اللوجستي مباشرة على البكسل
raw_pixel_classifier = clone(logistic)
raw_pixel_classifier.C = 100.0
raw_pixel_classifier.fit(X_train, Y_train)
[BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.13s
[BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.26s
[BernoulliRBM] Iteration 3, pseudo-likelihood = -22.88, time = 0.18s
[BernoulliRBM] Iteration 4, pseudo-likelihood = -21.91, time = 0.20s
[BernoulliRBM] Iteration 5, pseudo-likelihood = -21.79, time = 0.17s
[BernoulliRBM] Iteration 6, pseudo-likelihood = -20.96, time = 0.16s
[BernoulliRBM] Iteration 7, pseudo-likelihood = -20.80, time = 0.19s
[BernoulliRBM] Iteration 8, pseudo-likelihood = -20.63, time = 0.16s
[BernoulliRBM] Iteration 9, pseudo-likelihood = -20.38, time = 0.20s
[BernoulliRBM] Iteration 10, pseudo-likelihood = -20.19, time = 0.16s
LogisticRegression(C=100.0, solver='newton-cg', tol=1)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


التقييم#

from sklearn import metrics

Y_pred = rbm_features_classifier.predict(X_test)
print(
    "انحدار لوجستي باستخدام ميزات RBM:\n%s\n"
    % (metrics.classification_report(Y_test, Y_pred))
)
/project/workspace/sklearn/metrics/_classification.py:1613: UndefinedMetricWarning:

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

/project/workspace/sklearn/metrics/_classification.py:1613: UndefinedMetricWarning:

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

/project/workspace/sklearn/metrics/_classification.py:1613: UndefinedMetricWarning:

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

انحدار لوجستي باستخدام ميزات RBM:
              precision    recall  f1-score   support

           0       0.10      1.00      0.18       174
           1       0.00      0.00      0.00       184
           2       0.00      0.00      0.00       166
           3       0.00      0.00      0.00       194
           4       0.00      0.00      0.00       186
           5       0.00      0.00      0.00       181
           6       0.00      0.00      0.00       207
           7       0.00      0.00      0.00       154
           8       0.00      0.00      0.00       182
           9       0.00      0.00      0.00       169

    accuracy                           0.10      1797
   macro avg       0.01      0.10      0.02      1797
weighted avg       0.01      0.10      0.02      1797
Y_pred = raw_pixel_classifier.predict(X_test)
print(
    "انحدار لوجستي باستخدام ميزات البكسل الخام:\n%s\n"
    % (metrics.classification_report(Y_test, Y_pred))
)
/project/workspace/sklearn/metrics/_classification.py:1613: UndefinedMetricWarning:

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

/project/workspace/sklearn/metrics/_classification.py:1613: UndefinedMetricWarning:

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

/project/workspace/sklearn/metrics/_classification.py:1613: UndefinedMetricWarning:

Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.

انحدار لوجستي باستخدام ميزات البكسل الخام:
              precision    recall  f1-score   support

           0       0.10      1.00      0.18       174
           1       0.00      0.00      0.00       184
           2       0.00      0.00      0.00       166
           3       0.00      0.00      0.00       194
           4       0.00      0.00      0.00       186
           5       0.00      0.00      0.00       181
           6       0.00      0.00      0.00       207
           7       0.00      0.00      0.00       154
           8       0.00      0.00      0.00       182
           9       0.00      0.00      0.00       169

    accuracy                           0.10      1797
   macro avg       0.01      0.10      0.02      1797
weighted avg       0.01      0.10      0.02      1797

تساعد الميزات المستخرجة بواسطة BernoulliRBM في تحسين دقة التصنيف فيما يتعلق بالانحدار اللوجستي على البكسل الخام.

الرسم#

import matplotlib.pyplot as plt

plt.figure(figsize=(4.2, 4))
for i, comp in enumerate(rbm.components_):
    plt.subplot(10, 10, i + 1)
    plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r, interpolation="nearest")
    plt.xticks(())
    plt.yticks(())
plt.suptitle("100 مكون مستخرج بواسطة RBM", fontsize=16)
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)

plt.show()
100 مكون مستخرج بواسطة RBM

Total running time of the script: (0 minutes 4.100 seconds)

Related examples

إزالة الميزة المتكررة

إزالة الميزة المتكررة

رسم احتمالية التصنيف

رسم احتمالية التصنيف

التعرف على الأرقام المكتوبة بخط اليد

التعرف على الأرقام المكتوبة بخط اليد

ربط الأنابيب: ربط PCA والانحدار اللوجستي

ربط الأنابيب: ربط PCA والانحدار اللوجستي

Gallery generated by Sphinx-Gallery