PlusML — binary_logistic_regression.h
1#ifndef BINARY_LOGISTIC_REGRESSION_H
2#define BINARY_LOGISTIC_REGRESSION_H
3
4#include "PlusML/util.h"
5#include <cstdint>
6#include <Eigen/Dense>
7
8namespace plusml {
13public:
14 BinaryLogisticRegression() = delete;
15
23 explicit BinaryLogisticRegression(uint64_t features, bool bias_enabled = true);
24
29 uint64_t Features() const;
30
35 Eigen::MatrixXf Parameters() const;
36
41 void SetParameters(const Eigen::MatrixXf& parameters);
42
51 void FitSGD(const Eigen::MatrixXf& samples,
52 const Eigen::MatrixXf& targets,
53 float learning_rate,
54 uint64_t batch_size,
55 uint64_t epochs);
56
62 Eigen::MatrixXf Predict(Eigen::MatrixXf samples);
63
64private:
65 uint64_t features_;
66 bool bias_enabled_;
67 Eigen::MatrixXf w_;
68};
69} //namespace plusml
70
71#endif //BINARY_LOGISTIC_REGRESSION_H
BinaryLogisticRegression: basic binary logistic regression (defined in binary_logistic_regression.h, line 12).