SVM原理与C++的Eigen库实现

具体的SVM详解参考https://blog.csdn.net/c406495762/article/details/78072313, 讲的特别详细, 如下代码也是基于该链接中的讲解而实现的

//model.h
#include <iostream>
#include "Eigen/Eigen"
#include<vector>
#include<string>
#include<fstream>
#include<sstream>
#include<iterator>
#include<algorithm>
#include<regex>
#include<set>
#include<unordered_map>
#include <assert.h>
#include <random>
#include <python2.7/Python.h>
#include <stdlib.h>
using namespace Eigen;
using std::pair;
using std::vector;
using std::cout;
using std::endl;
using std::ios;
using std::ifstream;
using std::string;
using std::regex;
using std::iterator;
using std::stringstream;
using std::sregex_token_iterator;
using std::set;
using std::istringstream;
using std::istream_iterator;
using std::unordered_map;
using std::make_pair;
using std::begin;
using std::end;
using std::min;
using std::max;
using std::abs;

using _MAT_VEC=pair<MatrixXf,VectorXf>;
using _PARAM=pair<VectorXf,float>;

#define filename "data/sample"
#define C 0.6f
#define Threshold 0.001f
#define Max_iter 40
#define Alpha_threshold 0.00001f
#define RBF_var 1.3




#ifndef DATA_HANDLE_
#define DATA_HANDLE_
inline _MAT_VEC load_data();
inline pair<_MAT_VEC,_MAT_VEC> train_test_split(const _MAT_VEC,float);
_PARAM SMO(const MatrixXf&,const VectorXf&);
inline VectorXf cal_weight(const MatrixXf&,const VectorXf&,const VectorXf&);
VectorXf kernel_RBF(MatrixXf,VectorXf,float);
inline void python_plot(const VectorXf&,float b);
#endif
//model.cpp
#pragma once
#include "model.h"

inline _MAT_VEC load_data(){
    ifstream ifile(filename,ios::in);
    if(!ifile.is_open()){
        cout<<"failed to open: "<<filename<<endl;
    }
    string line;
    vector<vector<float> > tempX;
    vector<float> tempY;
    while(getline(ifile,line)){
        istringstream iss( line );
        vector<float> nums{istream_iterator<float>( iss ), std::istream_iterator<float>()};
        tempY.push_back(nums[nums.size()-1]);
        nums.pop_back();
        tempX.push_back(nums);
    }
    assert(tempX.size()==tempY.size());
    int matC=tempX[0].size(),matR=tempX.size();
    MatrixXf X(matR,matC);
    VectorXf Y(matR);
    for(auto row=0;row<matR;++row){
        Y[row]=tempY[row];
        float *arr=tempX[row].data();
        X.row(row)=Map<VectorXf>(arr,matC);
    }
    return make_pair(X,Y);
}

inline pair<_MAT_VEC,_MAT_VEC> train_test_split(const _MAT_VEC &raw_data,float percent=0.8){
    int new_row=raw_data.first.rows()*percent;
    return make_pair(
        make_pair(raw_data.first.topRows(new_row),raw_data.second.topRows(new_row)),
        make_pair(raw_data.first.topRows(raw_data.first.rows()-new_row),raw_data.second.topRows(raw_data.first.rows()-new_row))
        );

}

int random_choice(int min,int max,int current){
    std::random_device seeder;
    std::mt19937 engine(seeder());
    std::uniform_int_distribution<int> dist(min, max);
    int rand=dist(engine);
    while(rand==current){
        rand=dist(engine);
    }
    return rand;
}

_PARAM SMO(const MatrixXf &features,const VectorXf& labels){
    int rows=features.rows(),cols=features.cols();
    int b=0,iter_count=0;
    VectorXf alphas(rows);
    alphas.setZero();
    while(iter_count<=Max_iter){
        int pair_alpha_changed_count=0;
        for(int i=0;i<rows;++i){

            //计算拉格朗日表示fX_i和损失E_i
            float fX_i=(alphas.cwiseProduct(labels)).transpose()*(features*(features.row(i).transpose()))+b;
            //float fX_i=(alphas.cwiseProduct(labels)).transpose()*(kernel_RBF(features,VectorXf(features.row(i)),1.3f))+b;
            float E_i=fX_i-labels(i);
            if((labels(i)*E_i<-Threshold && alphas(i)<C) || (labels(i)*E_i>Threshold && alphas(i)>0)){
                
                //随机挑选j并计算拉格朗日fX_j和E_j
                int j=random_choice(0,rows-1,i);
                float fX_j=(alphas.cwiseProduct(labels)).transpose()*(features*(features.row(j).transpose()))+b;
                //float fX_j=(alphas.cwiseProduct(labels)).transpose()*(kernel_RBF(features,VectorXf(features.row(j)),1.3f))+b;
                float E_j=fX_j-labels(j);

                //保留i和j所对应的旧的alphas
                float alphas_old_i=alphas(i);
                float alphas_old_j=alphas(j);

                //计算上下界, 如果上下界相同则重新选择
                float zero=0;
                float L=(labels(i)!=labels(j))?max(zero,alphas(j)-alphas(i)):max(zero,alphas(i)+alphas(j)-C);
                float H=(labels(i)!=labels(j))?min(C,C+alphas(j)-alphas(i)):min(C,alphas(i)+alphas(j));
                if(L==H) continue;

                //计算步长eta, 大于等于0说明不是支持向量
                float eta=static_cast<float>(2.0f*features.row(i)*(features.row(j).transpose()))-static_cast<float>(features.row(i)*(features.row(i).transpose()))-static_cast<float>(features.row(j)*(features.row(j).transpose()));
                if(eta>=0) continue;

                //更新alphas_j并对alphas_j加窗
                alphas(j)-=labels(j)*(E_i-E_j)/eta;
                alphas(j)=alphas(j)>H?H:(alphas(j)<L?L:alphas(j));

                //如果alphas_j变化太小则不更新  
                if(abs(alphas(j)-alphas_old_j)<Alpha_threshold) continue;
                
                //更新alphas_i
                alphas(i)+=static_cast<float>(labels.row(j)*labels.row(i))*(alphas_old_j-alphas(j));
                
                //更新b_1和b_2
                float b_1 = b - E_i- labels(i)*(alphas(i)-alphas_old_i)*features.row(i)*features.row(i).transpose() - labels(j)*(alphas(j)-alphas_old_j)*features.row(i)*features.row(j).transpose();
                float b_2 = b - E_j- labels(i)*(alphas(i)-alphas_old_i)*features.row(i)*features.row(j).transpose() - labels(j)*(alphas(j)-alphas_old_j)*features.row(j)*features.row(j).transpose();

                //更新b
                b=(0<alphas(i)&&C>alphas(i))?b_1:
                    ((0<alphas(j)&&C>alphas(j))?b_2:
                        (b_1+b_2)/2);
                
                ++pair_alpha_changed_count;
            }
        }
        cout<<"第 "<<iter_count<<" 次迭代. 在这次迭代中, 共有 "<<pair_alpha_changed_count<<" 个SMO对被改变"<<endl;
        if(!pair_alpha_changed_count) ++iter_count;
        else iter_count=0;
    }
    return make_pair(alphas,b);
}

inline VectorXf cal_weight(const MatrixXf &features,const VectorXf &labels,const VectorXf &alphas){
    int rows=features.rows(),cols=features.cols();
    VectorXf weight=labels.cwiseProduct(alphas).replicate(1,cols).cwiseProduct(features).colwise().sum().transpose();
    return weight;
}

//径向基核函数
VectorXf kernel_RBF(MatrixXf features,VectorXf line,float var=RBF_var){
    features.rowwise()-=line.transpose();
    ArrayXf kV=ArrayXf((features*features.transpose()).diagonal()/(pow(var,2))*(-1));
    return kV.exp();
}


//调用python代码作图
void python_plot(VectorXf &weight,float b){
    setenv("PYTHONPATH",".",1); //将python路径设为当前工作路径
    Py_Initialize();

    PyObject* myModuleString = PyString_FromString((char*)"svm");
    PyObject* myModule = PyImport_Import(myModuleString);

    PyObject* myFunction = PyObject_GetAttrString(myModule,(char*)"plot_points");

    //通过元组传入参数
    PyObject *pArgs = PyTuple_New(3);
    PyTuple_SetItem(pArgs,0, PyFloat_FromDouble(static_cast<double>(weight(0))));
    PyTuple_SetItem(pArgs,1, PyFloat_FromDouble(static_cast<double>(weight(1))));
    PyTuple_SetItem(pArgs,2, PyFloat_FromDouble(static_cast<double>(b)));

    //调用函数
    PyObject_CallObject(myFunction, pArgs);
    Py_Finalize();
}
//main.cpp
#include "Eigen/Dense"
#include "model.h"
#include "model.cpp"
int main(){
    _MAT_VEC Train;
    _MAT_VEC Test;
    _MAT_VEC total_data=load_data();
    {
        auto whole_data=train_test_split(total_data);
        Train=move(whole_data.first);
        Test=move(whole_data.second);
    }
    assert(Train.first.rows()==Train.second.rows());
    _PARAM alpha_b=SMO(Train.first,Train.second);
    
    VectorXf weight=cal_weight(Train.first,Train.second,alpha_b.first);

    ArrayXf arr((Test.first*weight+alpha_b.second*VectorXf::Ones(Test.first.rows())).cwiseProduct(Test.second));

    cout<<"权重大小w为: "<<weight.transpose()<<"偏置项b为: "<<alpha_b.second<<endl;

    cout<<"训练样本大小: "<<Train.first.rows()<<endl<<"测试样本大小: "<<Test.first.rows()<<endl;
    cout<<"测试样本正确的数量: "<<(arr>=0).count()<<endl;

    python_plot(weight,alpha_b.second);
}
#svm.py
# -*- coding:UTF-8 -*-


def plot_points(a1,a2,b):
    import matplotlib.pyplot as plt
    import numpy as np
    import types

    fileName=''
    features = []; labels = []
    fr = open(fileName)
    for line in fr.readlines():
        lineArr = line.strip().split('\t')
        features.append([float(lineArr[0]), float(lineArr[1])])
        labels.append(float(lineArr[2]))

    data_plus = []
    data_minus = []
    for i in range(len(features)):
        if labels[i] > 0:
            data_plus.append(features[i])
        else:
            data_minus.append(features[i])
    data_plus_np = np.array(data_plus)
    data_minus_np = np.array(data_minus)
    plt.scatter(np.transpose(data_plus_np)[0], np.transpose(data_plus_np)[1], s=30, alpha=0.7,color='red')
    plt.scatter(np.transpose(data_minus_np)[0], np.transpose(data_minus_np)[1], s=30, alpha=0.7,color='blue')
    x1 = max(features)[0]
    x2 = min(features)[0]
    y1, y2 = (-b- a1*x1)/a2, (-b - a1*x2)/a2
    plt.plot([x1, x2], [y1, y2])
    plt.title("Sample data and the svm linear discriminant")
    plt.show()

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 157,298评论 4 360
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 66,701评论 1 290
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 107,078评论 0 237
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 43,687评论 0 202
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,018评论 3 286
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,410评论 1 211
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,729评论 2 310
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,412评论 0 194
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,124评论 1 239
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,379评论 2 242
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 31,903评论 1 257
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,268评论 2 251
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 32,894评论 3 233
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,014评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,770评论 0 192
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,435评论 2 269
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,312评论 2 260