PlusML
Loading...
Searching...
No Matches
include
PlusML
multiclass_logistic_regression.h
1
#ifndef MULTICLASS_LOGISTIC_REGRESSION_H
2
#define MULTICLASS_LOGISTIC_REGRESSION_H
3
4
#include "PlusML/util.h"

#include <cstdint>
#include <forward_list>
#include <tuple>

#include <Eigen/Dense>

#include "binary_logistic_regression.h"
11
12
namespace
plusml {
16
class
EXPORT
MulticlassLogisticRegression
{
17
public
:
18
MulticlassLogisticRegression
() =
delete
;
19
26
MulticlassLogisticRegression
(uint64_t features, uint64_t classes,
bool
bias_enabled =
true
);
27
32
uint64_t Features()
const
;
33
42
void
FitSGD(
const
Eigen::MatrixXf& samples,
43
const
Eigen::MatrixXi& targets,
44
float
learning_rate,
45
uint64_t batch_size,
46
uint64_t epochs);
47
53
Eigen::MatrixXf Predict(Eigen::MatrixXf samples);
54
55
private
:
56
uint64_t features_;
57
uint64_t classes_;
58
59
bool
bias_enabled_;
60
61
std::forward_list<std::tuple<uint64_t, BinaryLogisticRegression>> classifiers_;
62
};
63
64
}
//namespace plusml
65
66
#endif
//MULTICLASS_LOGISTIC_REGRESSION_H
plusml::MulticlassLogisticRegression
Class implementing multiclass logistic regression.
Definition
multiclass_logistic_regression.h:16
Generated by
1.10.0