DL之RBM：利用RBM提取特征，提高手写数字图片识别的准确率

DL之RBM：利用RBM提取特征，提高手写数字图片识别的准确率


输出结果

实现代码

import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn import metrics,linear_model
from sklearn.neural_network import BernoulliRBM
from sklearn.datasets import load_digits
from sklearn.pipeline import Pipeline    

# Load scikit-learn's bundled handwritten-digit dataset: each row of X is the
# flattened pixel vector of one image and y holds the matching class labels.
digits = load_digits()
# Work on a float copy so the in-place scaling below neither mutates
# digits.data nor depends on the dataset already having a float dtype
# (in-place true division on an integer array raises TypeError).
X = digits.data.astype(np.float64)
y = digits.target

# Min-max scale every pixel intensity into [0, 1]: BernoulliRBM treats its
# visible units as binary / probability values, so inputs must lie in that range.
X -= X.min()
X /= X.max()

# Fix the shuffle seed so the reported metrics are reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Two-stage model: the RBM learns hidden-unit features from the scaled pixels,
# and logistic regression then classifies the images from those features.
rbm = BernoulliRBM(
    n_components=200,   # number of hidden units, i.e. learned features
    learning_rate=0.06,
    n_iter=20,          # training sweeps over the training set
    random_state=0,
    verbose=True,
)
# Large C means very weak regularization. Raise max_iter so the solver can
# converge instead of stopping at its default iteration cap with a
# ConvergenceWarning, which would leave the classifier under-fitted.
logistic = linear_model.LogisticRegression(C=6000.0, max_iter=10000)
classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)])

# Fitting the pipeline trains the RBM on X_train and then fits the logistic
# regression on the RBM-transformed training data.
classifier.fit(X_train, y_train)

# Report per-class precision / recall / F1 of the RBM + logistic pipeline on
# the held-out test split. (The original call was missing its closing
# parentheses and ended in a stray "(0)", which made the script a SyntaxError.)
print()
print("Logistic regression using RBM features:\n%s\n" % (
    metrics.classification_report(y_test, classifier.predict(X_test))))

相关推荐