PlusML
Loading...
Searching...
No Matches
linear_regression.h
1#ifndef LINEAR_REGRESSION_H
2#define LINEAR_REGRESSION_H
3
4#include "PlusML/gradient.h"
5#include "PlusML/util.h"
6
7#include <cstdint>
8#include <Eigen/Dense>
9
10namespace plusml {
/// @brief Basic linear regression class.
///
/// Only declarations are visible in this header; the member functions are
/// defined out of line. Notes below marked "presumably" are inferred from
/// names and must be confirmed against the implementation.
///
/// NOTE(review): EXPORT is not defined in this header — presumably a
/// symbol-visibility macro supplied by one of the PlusML includes; confirm.
14class EXPORT LinearRegression {
15public:
   /// Default construction is disabled: a model must be created through the
   /// constructor below, which requires a feature count.
16 LinearRegression() = delete;
17
   /// @brief Constructs a model.
   /// @param features Presumably the number of input features per sample — confirm.
   /// @param bias_enabled Presumably toggles a bias (intercept) term — confirm.
   ///        Defaults to true when omitted.
   ///
   /// Marked explicit, so an integer is never implicitly converted to a model.
25 explicit LinearRegression(uint64_t features, bool bias_enabled = true);
26
   /// @brief Feature-count accessor; does not modify the model (const).
   /// @return Presumably the value passed as `features` at construction — confirm.
31 uint64_t Features() const;
32
   /// @brief Parameter accessor; does not modify the model (const).
   /// @return A matrix returned by value, so the caller receives its own copy.
   ///         Presumably the current weights; shape is not visible here — confirm.
37 Eigen::MatrixXf Parameters() const;
38
   /// @brief Replaces the model parameters.
   /// @param parameters New parameter matrix, taken by const reference.
   ///        NOTE(review): the expected shape is not visible from this header — confirm.
43 void SetParameters(const Eigen::MatrixXf& parameters);
44
   /// @brief Fits the model; presumably by mini-batch stochastic gradient
   ///        descent, judging by the name and parameters — confirm.
   /// @param samples Input samples (layout not visible here — TODO confirm).
   /// @param targets Target values (layout not visible here — TODO confirm).
   /// @param grad Loss gradient implementation, taken by const reference.
   /// @param learning_rate Presumably the step size applied to each update — confirm.
   /// @param batch_size Presumably the number of samples per update — confirm.
   /// @param epochs Presumably the number of passes over the data — confirm.
54 void FitSGD(const Eigen::MatrixXf& samples,
55 const Eigen::MatrixXf& targets,
56 const LossGradient& grad,
57 float learning_rate,
58 uint64_t batch_size,
59 uint64_t epochs);
60
   /// @brief Computes predictions for the given samples.
   /// @param samples Input samples, taken by value (the matrix is copied on each call).
   /// @return Matrix of predictions (shape not visible here — TODO confirm).
   ///
   /// NOTE(review): unlike FitSGD, this takes the matrix by value and is not a
   /// const member. Possibly the implementation modifies its local copy —
   /// confirm. Any signature change must be made together with the
   /// out-of-line definition.
66 Eigen::MatrixXf Predict(Eigen::MatrixXf samples);
67
68private:
   /// Feature count; presumably set from the constructor argument — confirm.
69 uint64_t features_;
   /// Bias flag; presumably set from the constructor argument — confirm.
70 bool bias_enabled_;
   /// Presumably the model weights exposed via Parameters()/SetParameters() — confirm.
71 Eigen::MatrixXf w_;
72};
73} //namespace plusml
74
75#endif //LINEAR_REGRESSION_H
plusml::LinearRegression: Basic linear regression class.
Definition: linear_regression.h:14
LossGradient: Base class for loss gradient implementations.
Definition: loss_gradient.h:11