// // Created by xinyang on 19-4-19. // // 为了一时方便,使用循环和Eigen自行编写的CNN前向传播类。 // 没有显著的性能损失。 // 但类定义了网络结构,同时实现的操作较少,可扩展性较差 #ifndef _CLASSIFIER_H_ #define _CLASSIFIER_H_ #include #include #include #include #include using namespace std; using namespace Eigen; class Classifier { private: bool state; // 标志分类器是否正确初始化 // 所有网络参数 vector> conv1_w, conv2_w, conv3_w; vector conv1_b, conv2_b, conv3_b; MatrixXd fc1_w, fc2_w; VectorXd fc1_b, fc2_b; // 读取网络参数的函数 vector> load_conv_w(const string &file); vector load_conv_b(const string &file); MatrixXd load_fc_w(const string &file); VectorXd load_fc_b(const string &file); // 目前支持的所有操作 MatrixXd softmax(const MatrixXd &input); MatrixXd relu(const MatrixXd &input); MatrixXd leaky_relu(const MatrixXd &input, float alpha); vector> apply_bias(const vector> &input, const vector &bias); vector> relu(const vector> &input); vector> leaky_relu(const vector> &input, float alpha); vector> max_pool(const vector> &input, int size); vector> mean_pool(const vector> &input, int size); vector> pand(const vector> &input, int val); MatrixXd conv(const MatrixXd &filter, const MatrixXd &input); vector> conv2(const vector> &filter, const vector> &input); MatrixXd flatten(const vector> &input); public: explicit Classifier(const string &folder); ~Classifier() = default; MatrixXd calculate(const vector> &input); explicit operator bool() const; int operator()(const cv::Mat &image); }; #endif /* _CLASSIFIER_H */