PlusML
Loading...
Searching...
No Matches
binary_svm.h
1#ifndef BINARY_SVM_H
2#define BINARY_SVM_H
3
4#include "PlusML/util.h"
5#include <Eigen/Dense>
6
7namespace plusml {
11class EXPORT BinarySVM {
12public:
13 BinarySVM() = delete;
14
22 explicit BinarySVM(uint64_t features, bool bias_enabled = true);
23
28 uint64_t Features() const;
29
34 Eigen::MatrixXf Parameters() const;
35
40 void SetParameters(const Eigen::MatrixXf& parameters);
41
51 void FitSGD(const Eigen::MatrixXf& samples,
52 const Eigen::MatrixXf& targets,
53 float learning_rate,
54 float l2_alpha,
55 uint64_t batch_size,
56 uint64_t epochs);
57
64 Eigen::MatrixXf Predict(Eigen::MatrixXf samples, bool sign = true);
65
66private:
67 uint64_t features_;
68 bool bias_enabled_;
69 Eigen::MatrixXf w_;
70};
71} //namespace plusml
72
73#endif //BINARY_SVM_H
Basic binary classification SVM.
Definition binary_svm.h:11