مقارنة البحث العشوائي والبحث الشبكي لتقدير فرط المعلمات#

قارن بين البحث العشوائي والبحث الشبكي لتحسين فرط معلمات SVM الخطي مع التدريب SGD. يتم البحث عن جميع المعلمات التي تؤثر على التعلم في نفس الوقت (باستثناء عدد المعلمات، والذي يمثل مفاضلة بين الوقت والجودة).

يستكشف البحث العشوائي والبحث الشبكي نفس مساحة المعلمات بالضبط. النتيجة في إعدادات المعلمات متشابهة جدًا، في حين أن وقت التشغيل للبحث العشوائي أقل بشكل كبير.

قد يكون الأداء أسوأ قليلاً بالنسبة للبحث العشوائي، ومن المحتمل أن يكون ذلك بسبب تأثير الضوضاء ولن ينتقل إلى مجموعة اختبار محفوظة.

لاحظ أنه في الممارسة العملية، لن يتم البحث عن هذا العدد الكبير من المعلمات المختلفة في نفس الوقت باستخدام البحث الشبكي، ولكن سيتم اختيار المعلمات التي تعتبر الأكثر أهمية فقط.

استغرق البحث العشوائي RandomizedSearchCV 0.65 ثانية لـ 15 إعدادات المعلمات المرشحة.
نموذج مع الترتيب: 1
متوسط درجة التحقق: 0.989 (الانحراف المعياري: 0.007)
المعلمات: {'alpha': np.float64(0.015412781495004139), 'average': False, 'l1_ratio': np.float64(0.08311249263060239)}

نموذج مع الترتيب: 2
متوسط درجة التحقق: 0.989 (الانحراف المعياري: 0.009)
المعلمات: {'alpha': np.float64(0.018348295456200724), 'average': True, 'l1_ratio': np.float64(0.8413861191973544)}

نموذج مع الترتيب: 3
متوسط درجة التحقق: 0.985 (الانحراف المعياري: 0.016)
المعلمات: {'alpha': np.float64(0.01563876852642779), 'average': True, 'l1_ratio': np.float64(0.6736596308357894)}

استغرق البحث الشبكي GridSearchCV 3.73 ثانية لـ 60 إعدادات المعلمات المرشحة.
نموذج مع الترتيب: 1
متوسط درجة التحقق: 0.993 (الانحراف المعياري: 0.007)
المعلمات: {'alpha': np.float64(0.01), 'average': False, 'l1_ratio': np.float64(0.4444444444444444)}

نموذج مع الترتيب: 2
متوسط درجة التحقق: 0.991 (الانحراف المعياري: 0.014)
المعلمات: {'alpha': np.float64(0.1), 'average': False, 'l1_ratio': np.float64(0.0)}

نموذج مع الترتيب: 3
متوسط درجة التحقق: 0.991 (الانحراف المعياري: 0.008)
المعلمات: {'alpha': np.float64(0.01), 'average': False, 'l1_ratio': np.float64(0.6666666666666666)}

نموذج مع الترتيب: 3
متوسط درجة التحقق: 0.991 (الانحراف المعياري: 0.010)
المعلمات: {'alpha': np.float64(1.0), 'average': False, 'l1_ratio': np.float64(0.0)}

# المؤلفون: مطوري سكايلرن
# معرف الترخيص: BSD-3-Clause

from time import time

import numpy as np
import scipy.stats as stats

from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

# الحصول على بعض البيانات
X, y = load_digits(return_X_y=True, n_class=3)

# بناء مصنف
clf = SGDClassifier(loss="hinge", penalty="elasticnet", fit_intercept=True)


# دالة مساعدة للإبلاغ عن أفضل النتائج
def report(results, n_top=3):
    for i in range(1, n_top + 1):
        candidates = np.flatnonzero(results["rank_test_score"] == i)
        for candidate in candidates:
            print("نموذج مع الترتيب: {0}".format(i))
            print(
                "متوسط درجة التحقق: {0:.3f} (الانحراف المعياري: {1:.3f})".format(
                    results["mean_test_score"][candidate],
                    results["std_test_score"][candidate],
                )
            )
            print("المعلمات: {0}".format(results["params"][candidate]))
            print("")


# تحديد المعلمات والتوزيعات للعينة منها
param_dist = {
    "average": [True, False],
    "l1_ratio": stats.uniform(0, 1),
    "alpha": stats.loguniform(1e-2, 1e0),
}

# تشغيل البحث العشوائي
n_iter_search = 15
random_search = RandomizedSearchCV(
    clf, param_distributions=param_dist, n_iter=n_iter_search
)

start = time()
random_search.fit(X, y)
print(
    "استغرق البحث العشوائي RandomizedSearchCV %.2f ثانية لـ %d إعدادات المعلمات المرشحة."
    % ((time() - start), n_iter_search)
)
report(random_search.cv_results_)

# استخدام شبكة كاملة عبر جميع المعلمات
param_grid = {
    "average": [True, False],
    "l1_ratio": np.linspace(0, 1, num=10),
    "alpha": np.power(10, np.arange(-2, 1, dtype=float)),
}

# تشغيل البحث الشبكي
grid_search = GridSearchCV(clf, param_grid=param_grid)
start = time()
grid_search.fit(X, y)

print(
    "استغرق البحث الشبكي GridSearchCV %.2f ثانية لـ %d إعدادات المعلمات المرشحة."
    % (time() - start, len(grid_search.cv_results_["params"]))
)
report(grid_search.cv_results_)

Total running time of the script: (0 minutes 4.398 seconds)

Related examples

استراتيجية إعادة الضبط المخصصة للبحث الشبكي مع التحقق المتقاطع

استراتيجية إعادة الضبط المخصصة للبحث الشبكي مع التحقق المتقاطع

عمليات التقسيم المتتالية

عمليات التقسيم المتتالية

مقارنة بين البحث الشبكي وتقليص الخيارات المتتابع

مقارنة بين البحث الشبكي وتقليص الخيارات المتتابع

مثال على خط أنابيب لاستخراج ميزات النص وتقييمها

مثال على خط أنابيب لاستخراج ميزات النص وتقييمها

Gallery generated by Sphinx-Gallery