PlusML
Loading...
Searching...
No Matches
multiclass_logistic_regression.h
1#ifndef MULTICLASS_LOGISTIC_REGRESSION_H
2#define MULTICLASS_LOGISTIC_REGRESSION_H
3
4#include "PlusML/util.h"
5
6#include <cstdint>
7#include <forward_list>
8#include <Eigen/Dense>
9
10#include "binary_logistic_regression.h"
11
12namespace plusml {
17public:
19
26 MulticlassLogisticRegression(uint64_t features, uint64_t classes, bool bias_enabled = true);
27
32 uint64_t Features() const;
33
42 void FitSGD(const Eigen::MatrixXf& samples,
43 const Eigen::MatrixXi& targets,
44 float learning_rate,
45 uint64_t batch_size,
46 uint64_t epochs);
47
53 Eigen::MatrixXf Predict(Eigen::MatrixXf samples);
54
55private:
56 uint64_t features_;
57 uint64_t classes_;
58
59 bool bias_enabled_;
60
61 std::forward_list<std::tuple<uint64_t, BinaryLogisticRegression>> classifiers_;
62};
63
64} //namespace plusml
65
66#endif //MULTICLASS_LOGISTIC_REGRESSION_H
Class implementing multiclass logistic regression.
Definition multiclass_logistic_regression.h:16