zoukankan      html  css  js  c++  java
  • 采用libsvm进行mnist训练

     1 #coding:utf8
     2 import cPickle
     3 import gzip
     4 import numpy as np
     5 from sklearn.svm import libsvm
     6 
     7 
     8 class SVM(object):
     9     def __init__(self, kernel='rbf', degree=3, gamma='auto',
    10                  coef0=0.0, tol=1e-3, C=1.0,nu=0., epsilon=0.,shrinking=True, probability=False,
    11                   cache_size=200, class_weight=None, max_iter=-1):
    12         self.kernel = kernel
    13         self.degree = degree
    14         self.gamma = gamma
    15         self.coef0 = coef0
    16         self.tol = tol
    17         self.C = C
    18         self.nu = nu
    19         self.epsilon = epsilon
    20         self.shrinking = shrinking
    21         self.probability = probability
    22         self.cache_size = cache_size
    23         self.class_weight = class_weight
    24         self.max_iter = max_iter
    25 
    26     def fit(self, X, y):
    27         X= np.array(X, dtype=np.float64, order='C')
    28         cls, y = np.unique(y, return_inverse=True)
    29         weight = np.ones(cls.shape[0], dtype=np.float64, order='C')
    30         self.class_weight_=weight
    31         self.classes_ = cls
    32         y= np.asarray(y, dtype=np.float64, order='C')
    33         sample_weight = np.asarray([])
    34         solver_type =0
    35         self._gamma = 1.0 / X.shape[1]
    36         kernel = self.kernel
    37         seed = np.random.randint(np.iinfo('i').max)
    38         self.support_, self.support_vectors_, self.n_support_, 
    39             self.dual_coef_, self.intercept_, self.probA_, 
    40             self.probB_, self.fit_status_ = libsvm.fit(
    41                 X, y,
    42                 svm_type=solver_type, sample_weight=sample_weight,
    43                 class_weight=self.class_weight_, kernel=kernel, C=self.C,
    44                 nu=self.nu, probability=self.probability, degree=self.degree,
    45                 shrinking=self.shrinking, tol=self.tol,
    46                 cache_size=self.cache_size, coef0=self.coef0,
    47                 gamma=self._gamma, epsilon=self.epsilon,
    48                 max_iter=self.max_iter, random_seed=seed)
    49         self.shape_fit_ = X.shape
    50         self._intercept_ = self.intercept_.copy()
    51         self._dual_coef_ = self.dual_coef_
    52         self.intercept_ *= -1
    53         self.dual_coef_ = -self.dual_coef_
    54         return self
    55 
    56     def predict(self, X):
    57         X= np.array(X,dtype=np.float64, order='C')
    58         svm_type = 0
    59         return libsvm.predict(
    60             X, self.support_, self.support_vectors_, self.n_support_,
    61             self._dual_coef_, self._intercept_,
    62             self.probA_, self.probB_, svm_type=svm_type, kernel=self.kernel,
    63             degree=self.degree, coef0=self.coef0, gamma=self._gamma,
    64             cache_size=self.cache_size)
    65 
    66 def load_data():
    67     f = gzip.open('../data/mnist.pkl.gz', 'rb')
    68     training_data, validation_data, test_data = cPickle.load(f)
    69     f.close()
    70     return (training_data, validation_data, test_data)
    71 
    72 def svm_test():
    73     training_data, validation_data, test_data = load_data()
    74     clf = SVM(kernel='linear')   # 'linear', 'poly', 'rbf', 'sigmoid', 'precomputed'
    75     clf.fit(training_data[0][:10000], training_data[1][:10000])
    76     predictions = [int(a) for a in clf.predict(test_data[0][:10000])]
    77     num_correct = sum(int(a == y) for a, y in zip(predictions, test_data[1][:10000]))
    78     print "Baseline classifier using an SVM."
    79     print "%s of %s values correct." % (num_correct, len(test_data[1][:10000]))   # 0.9172  'rbf'=0.9214
    80 
    81 if __name__ == "__main__":
    82     svm_test()
  • 相关阅读:
    使用sql语句查询表结构
    plsql出现录相机后卡屏解决方法
    oracle的“ORA-01480:STR绑定值的结尾Null字符缺失”错误
    oracle创建表空间并对用户赋权
    Scrapy安装错误(error: Microsoft Visual C++ 14.0 is required. Get it with "Microsoft Visual C++ Build Tools": http://landinghub.visualstudio.com/visual-cpp-build-tools)
    震惊你不知道的python
    django.core.exceptions.ImproperlyConfigured: Error loading MySQLdb module: No module named 'MySQLdb'
    python3 ImportError: No module named 'ConfigParser'
    python import报错
    No migrations to apply(django不能创建数据库中的表的问题)
  • 原文地址:https://www.cnblogs.com/qw12/p/5743865.html
Copyright © 2011-2022 走看看