PlusML
Loading...
Searching...
No Matches
multiclass_svm.h
1#ifndef MULTICLASS_SVM_H
2#define MULTICLASS_SVM_H
3
#include <cstdint>
#include <forward_list>
#include <tuple>

#include "PlusML/util.h"
#include <Eigen/Dense>

#include "PlusML/binary_svm.h"
10
11namespace plusml {
/// Strategy used by MulticlassSVM to split a multiclass problem into binary ones.
/// Kept as an unscoped enum so that enumerators remain usable without qualification.
enum ClassificationMode {
  kOneVsAll = 0,  ///< One-vs-all decomposition.
  kOneVsOne = 1,  ///< One-vs-one decomposition.
};
16
22class EXPORT MulticlassSVM {
23public:
24 MulticlassSVM() = delete;
25
33 MulticlassSVM(uint64_t features, uint64_t classes, ClassificationMode mode, bool bias_enabled = true);
34
39 uint64_t Features() const;
40
50 void FitSGD(const Eigen::MatrixXf& samples,
51 const Eigen::MatrixXi& targets,
52 float learning_rate,
53 float l2_alpha,
54 uint64_t batch_size,
55 uint64_t epochs);
56
62 Eigen::MatrixXi Predict(Eigen::MatrixXf samples);
63
64private:
65 ClassificationMode mode_;
66
67 uint64_t features_;
68 uint64_t classes_;
69
70 bool bias_enabled_;
71
72 std::forward_list<std::tuple<uint64_t, uint64_t, BinarySVM>> classifiers_;
73};
74} // plusml
75
76#endif //MULTICLASS_SVM_H
Class implementing a multiclass SVM classifier.
Definition multiclass_svm.h:22