背景和目标
最近在一些C++工程中需要使用到线性代数相关的库,一般来讲可以使用Eigen等库,不过我想从头开始写一个工程,只依赖std标准库,自己写的线性代数库之后进行补充、升级和移植也比较方便。
依赖
本项目只依赖以下C++标准库头文件,已在C++17(__cplusplus == 201703L)下测试通过:
- complex
- vector
- map
- set
- stdexcept
- initializer_list
- iostream(print()输出时使用)
- utility(std::swap)
使用方法
建议搭配CMake使用:在CMakeLists.txt中把Matrix.h所在的目录加入头文件搜索路径,并把Matrix.cpp和你自己的源文件一起编译即可:
include_directories(/dir/path/to)
add_executable(your_program main.cpp Matrix.cpp)
内置函数
Matrix对象可以使用下面两种方法进行初始化
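Matrix a = {{{1,0.5}, 2, std::complex<double>(3,0)}, {4, 8, 6}, {7, 8, 9}};
Matrix b(3, 3);
其中{1,0.5}和std::complex<double>(3,0)都表示复数(分别为1+0.5i和3+0i),所有元素都会被转换为std::complex<double>类型;Matrix b(3, 3)则创建一个元素全为0的3×3矩阵。注意不要写成std::complex(3,0):C++17的类模板参数推导会把它推导为std::complex<int>,无法转换为std::complex<double>。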
可以使用+、-、*、^等运算符,其中*既支持矩阵乘法,也支持矩阵右乘一个复数标量(Matrix * scalar)。^运算符支持负指数,Matrix^(-2) = Matrix.inverse()^2。注意:在C++中^的优先级低于+、-、*和==,与其他运算混用时请加括号,例如(a ^ 2) + b。
矩阵中的元素通过(row, col)进行索引,行号和列号均从0开始;也可以使用[row * Matrix.col() + col]按行优先顺序直接索引,但不推荐(Tips:当矩阵是行向量或列向量,即row()或col()为1时,可以直接用[]索引某个元素)。注意这两种索引方式都不做越界检查。
特殊用法a(std::vector<int> rows, std::vector<int> cols):提取由指定行与指定列交叉构成的子矩阵(重复的下标只保留第一次出现)。例如对上面的矩阵a,
当输入为a({1,2},{1,2})时,返回值为
| (8,0) (6,0) |
| (8,0) (9,0) |
T()函数等价于Transpose(转置),这里为返回一个转置后的矩阵,原矩阵并不会发生变化
inverse()函数返回矩阵的逆(使用带部分主元的高斯-约旦消元法)。当矩阵不是方阵,或在消元过程中找不到模大于1e-10的主元(视为奇异矩阵)时,会直接抛出std::runtime_error。如果你更希望在这种情况下返回一个空矩阵而不是抛出异常,可以自行修改这部分代码。
getMinor(int row, int col)函数返回删除第row行和第col列之后得到的子矩阵,配合det()函数可以求代数余子式:C(row, col) = (-1)^(row+col) * getMinor(row, col).det()。
det()函数返回矩阵的行列式。
头文件
#ifndef MATRIX_H
#define MATRIX_H
#include <complex>
#include <initializer_list>
#include <iostream>
#include <map>
#include <set>
#include <stdexcept>
#include <utility>
#include <vector>
/// Dense, row-major matrix whose elements are std::complex<double>.
class Matrix
{
public:
    /// Creates a rows x cols matrix with every element initialized to (0, 0).
    Matrix(int rows, int cols);
    /// Creates a matrix from nested brace lists, one inner list per row,
    /// e.g. `Matrix m = {{1, 2}, {3, 4}};`.
    /// @throws std::runtime_error on zero rows, an empty first row, or ragged rows.
    Matrix(std::initializer_list<std::initializer_list<std::complex<double>>> init);

    /// Flat row-major access: element (i, j) is at index i * col() + j. Not bounds-checked.
    std::complex<double> &operator[](int index);
    const std::complex<double> &operator[](int index) const;
    /// Element access by zero-based (row, column). Not bounds-checked.
    std::complex<double> &operator()(int i, int j);
    const std::complex<double> &operator()(int i, int j) const;
    /// Extracts the sub-matrix at the intersection of the given rows and columns
    /// (repeated indices keep their first occurrence).
    /// @throws std::runtime_error if a list is empty or an index is out of range.
    Matrix operator()(std::vector<int> rows, std::vector<int> cols);

    /// Element-wise sum; @throws std::runtime_error on shape mismatch.
    Matrix operator+(const Matrix &other) const;
    /// Element-wise difference; @throws std::runtime_error on shape mismatch.
    Matrix operator-(const Matrix &other) const;
    /// Matrix product; @throws std::runtime_error unless col() == other.row().
    Matrix operator*(const Matrix &other) const;
    /// Multiplies every element by `scalar` (matrix * scalar only).
    Matrix operator*(const std::complex<double> &scalar) const;
    /// Integer power of a square matrix; a negative power uses inverse().
    /// NOTE: `^` binds more loosely than +, -, * and == in C++; write (a ^ 2) + b.
    Matrix operator^(int power) const;

    /// Returns a copy with all-zero rows and all-zero columns removed
    /// (a matrix with a single row/column keeps that row/column).
    Matrix all() const;
    /// Returns the transpose; *this is left unchanged.
    Matrix T();
    /// Returns a rows x cols matrix with ones on its main diagonal and zeros elsewhere.
    static Matrix Identity(int rows, int cols);
    /// Returns {X, Y}: flattened coordinate grids of length x.size() * y.size(),
    /// in which x varies slowest.
    static std::vector<std::vector<int>> meshgrid(std::vector<int> x, std::vector<int> y);
    /// Inverse via Gauss-Jordan elimination.
    /// @throws std::runtime_error if not square or (numerically) singular.
    Matrix inverse() const;
    /// Determinant; @throws std::runtime_error if not square.
    std::complex<double> det() const;
    /// Returns the matrix with row `row` and column `col` removed.
    Matrix getMinor(int row, int col) const;
    /// Prints the matrix to std::cout, one row per line.
    void print() const;

    int row() const { return rows; }
    int col() const { return cols; }

private:
    int rows;
    int cols;
    std::vector<std::complex<double>> data; // row-major, size rows * cols
};
#endif // !MATRIX_H
CPP
#include "Matrix.h"
/// Creates a rows x cols matrix with every element value-initialized to (0, 0).
///
/// Zero-sized dimensions are allowed: getMinor() of a 1x1 matrix and all() can
/// produce them.
///
/// @throws std::runtime_error if either dimension is negative. Without this check,
///         a negative product would wrap to a huge std::size_t inside std::vector.
Matrix::Matrix(int rows, int cols) : rows(rows), cols(cols)
{
    if (rows < 0 || cols < 0)
    {
        throw std::runtime_error("Matrix dimensions must be non-negative.");
    }
    // Multiply in std::size_t so that large-but-valid dimensions do not overflow int.
    data.resize(static_cast<std::size_t>(rows) * static_cast<std::size_t>(cols));
}
/// Builds a matrix from nested brace lists, one inner list per row, e.g.
/// `Matrix m = {{1, 2}, {3, 4}};`. Every element is stored as std::complex<double>.
///
/// @throws std::runtime_error if there are no rows, the first row is empty,
///         or the rows do not all have the same length.
Matrix::Matrix(std::initializer_list<std::initializer_list<std::complex<double>>> init)
{
    rows = static_cast<int>(init.size());
    if (rows == 0)
        throw std::runtime_error("Matrix cannot have zero rows.");

    const std::size_t width = init.begin()->size();
    cols = static_cast<int>(width);
    if (cols == 0)
        throw std::runtime_error("Matrix cannot have zero columns.");

    // Appending the rows back to back produces row-major storage.
    data.reserve(init.size() * width);
    for (const std::initializer_list<std::complex<double>> &line : init)
    {
        if (line.size() != width)
            throw std::runtime_error("All rows must have the same number of columns.");
        data.insert(data.end(), line.begin(), line.end());
    }
}
// Flat, row-major element access: element (i, j) lives at index i * cols + j.
// No bounds checking is done; an out-of-range index is undefined behavior.
std::complex<double> &Matrix::operator[](int index)
{
    return data[static_cast<std::size_t>(index)];
}

// Read-only counterpart of the flat accessor above (also unchecked).
const std::complex<double> &Matrix::operator[](int index) const
{
    return data[static_cast<std::size_t>(index)];
}

// (row, column) access, both zero-based. Unchecked: a column index >= cols silently
// addresses an element of a later row, and anything past the end is undefined behavior.
std::complex<double> &Matrix::operator()(int i, int j)
{
    const int offset = i * cols + j; // row-major offset
    return data[offset];
}

// Read-only (row, column) access with the same (unchecked) indexing rule.
const std::complex<double> &Matrix::operator()(int i, int j) const
{
    const int offset = i * cols + j; // row-major offset
    return data[offset];
}
/// Builds flattened coordinate grids for the Cartesian product of x and y.
///
/// @param x  First-axis values; these vary slowest.
/// @param y  Second-axis values; these vary fastest.
/// @return   {X, Y}, each of length x.size() * y.size(). For k = i * y.size() + j,
///           X[k] == x[i] and Y[k] == y[j]. If either input is empty, both
///           vectors are empty.
std::vector<std::vector<int>> Matrix::meshgrid(std::vector<int> x, std::vector<int> y)
{
    std::vector<std::vector<int>> result(2);
    std::vector<int> &X = result[0]; // references stay valid: `result` is never resized
    std::vector<int> &Y = result[1];

    // The final size is known up front, so allocate once per grid.
    X.reserve(x.size() * y.size());
    Y.reserve(x.size() * y.size());
    for (const int xi : x)
    {
        for (const int yj : y)
        {
            X.push_back(xi);
            Y.push_back(yj);
        }
    }
    return result;
}
/// Element-wise sum of two matrices of identical shape.
/// @throws std::runtime_error if the shapes differ.
Matrix Matrix::operator+(const Matrix &other) const
{
    const bool sameShape = rows == other.rows && cols == other.cols;
    if (!sameShape)
        throw std::runtime_error("Matrix dimensions must match for addition.");

    Matrix sum(rows, cols);
    const int count = rows * cols;
    for (int k = 0; k < count; ++k)
        sum.data[k] = data[k] + other.data[k];
    return sum;
}
/// Returns the transpose (a cols x rows matrix); *this is left unchanged.
Matrix Matrix::T()
{
    Matrix transposed(cols, rows);
    for (int r = 0; r < transposed.rows; ++r)
    {
        for (int c = 0; c < transposed.cols; ++c)
        {
            transposed(r, c) = (*this)(c, r);
        }
    }
    return transposed;
}
/// Returns a row x col matrix with 1 on the main diagonal and 0 elsewhere.
/// Non-square shapes are allowed; the diagonal then has min(row, col) entries.
Matrix Matrix::Identity(int row, int col)
{
    Matrix result(row, col);
    // Stop at the shorter dimension. When row > col, result(i, i) with i >= col
    // maps to flat index i * (col + 1), which runs past the end of the storage.
    const int diagonal = row < col ? row : col;
    for (int i = 0; i < diagonal; ++i)
    {
        result(i, i) = 1;
    }
    return result;
}
/// Element-wise difference (*this - other) of two matrices of identical shape.
/// @throws std::runtime_error if the shapes differ.
Matrix Matrix::operator-(const Matrix &other) const
{
    if (!(rows == other.rows && cols == other.cols))
        throw std::runtime_error("Matrix dimensions must match for subtraction.");

    Matrix difference(rows, cols);
    auto lhs = data.cbegin();
    auto rhs = other.data.cbegin();
    for (auto out = difference.data.begin(); out != difference.data.end(); ++out, ++lhs, ++rhs)
        *out = *lhs - *rhs;
    return difference;
}
/// Matrix product: (rows x cols) * (other.rows x other.cols) -> rows x other.cols.
/// @throws std::runtime_error unless cols == other.rows.
Matrix Matrix::operator*(const Matrix &other) const
{
    if (cols != other.rows)
        throw std::runtime_error("Matrix dimensions are not compatible for multiplication.");

    Matrix product(rows, other.cols);
    for (int r = 0; r < rows; ++r)
    {
        for (int c = 0; c < other.cols; ++c)
        {
            // Dot product of row r with column c, summed in increasing k order.
            std::complex<double> acc = 0;
            for (int k = 0; k < cols; ++k)
                acc += (*this)(r, k) * other(k, c);
            product(r, c) = acc;
        }
    }
    return product;
}
/// Returns a copy with every element multiplied by `scalar` (matrix * scalar only).
Matrix Matrix::operator*(const std::complex<double> &scalar) const
{
    Matrix scaled(rows, cols);
    auto out = scaled.data.begin();
    for (const std::complex<double> &value : data)
    {
        *out = value * scalar;
        ++out;
    }
    return scaled;
}
/// Returns the (rows-1) x (cols-1) matrix obtained by deleting row `row` and column `col`.
/// Combine with det() for cofactors: C(row, col) = (-1)^(row+col) * getMinor(row, col).det().
/// @throws std::runtime_error if `row` or `col` is out of range.
Matrix Matrix::getMinor(int row, int col) const
{
    // Without this check an out-of-range index deletes nothing, and the copy loop
    // below writes past the end of the smaller `minor` matrix.
    if (row < 0 || row >= rows || col < 0 || col >= cols)
    {
        throw std::runtime_error("Minor indices out of bounds.");
    }

    Matrix minor(rows - 1, cols - 1);
    for (int i = 0, mi = 0; i < rows; ++i)
    {
        if (i == row)
            continue;
        for (int j = 0, mj = 0; j < cols; ++j)
        {
            if (j == col)
                continue;
            minor(mi, mj) = (*this)(i, j);
            ++mj;
        }
        ++mi;
    }
    return minor;
}
std::complex<double> Matrix::det() const
{
if (rows != cols)
{
throw std::runtime_error("Matrix must be square to compute determinant.");
}
if (rows == 1)
{
return (*this)(0, 0);
}
if (rows == 2)
{
return (*this)(0, 0) * (*this)(1, 1) - (*this)(0, 1) * (*this)(1, 0);
}
std::complex<double> det = 0;
for (int j = 0; j < cols; ++j)
{
det += ((*this)(0, j) * getMinor(0, j).det() * ((j % 2 == 0) ? 1.0 : -1.0));
}
return det;
}
/// Returns the inverse of a square matrix.
///
/// Uses Gauss-Jordan elimination with partial pivoting on the augmented matrix
/// [A | I]. When elimination finishes, the right half holds A^-1.
///
/// @throws std::runtime_error if the matrix is not square, or if the largest
///         available pivot in some column has magnitude below 1e-10
///         (the matrix is then treated as singular).
Matrix Matrix::inverse() const
{
    if (rows != cols)
    {
        throw std::runtime_error("Matrix must be square to find its inverse.");
    }
    const int n = rows;
    // NOTE(review): this threshold is absolute, so it is scale-dependent; e.g. the
    // 1x1 matrix {{1e-11}} is rejected even though it is invertible.
    const double singularTolerance = 1e-10;

    // Build the augmented matrix [A | I].
    Matrix augmented(n, 2 * n);
    for (int i = 0; i < n; ++i)
    {
        for (int j = 0; j < n; ++j)
        {
            augmented(i, j) = (*this)(i, j);
            augmented(i, j + n) = (i == j) ? 1.0 : 0.0;
        }
    }

    for (int i = 0; i < n; ++i)
    {
        // Partial pivoting: take the row at or below i with the largest |entry| in column i.
        int pivot = i;
        for (int j = i + 1; j < n; ++j)
        {
            if (std::abs(augmented(j, i)) > std::abs(augmented(pivot, i)))
            {
                pivot = j;
            }
        }
        if (std::abs(augmented(pivot, i)) < singularTolerance)
        {
            throw std::runtime_error("Matrix is singular and cannot be inverted.");
        }

        if (i != pivot)
        {
            for (int j = 0; j < 2 * n; ++j)
            {
                std::swap(augmented(i, j), augmented(pivot, j));
            }
        }

        // Scale the pivot row so that the pivot becomes 1.
        const std::complex<double> pivotValue = augmented(i, i);
        for (int j = 0; j < 2 * n; ++j)
        {
            augmented(i, j) /= pivotValue;
        }

        // Clear column i in every other row, both above and below the pivot.
        for (int j = 0; j < n; ++j)
        {
            if (j == i)
            {
                continue;
            }
            const std::complex<double> factor = augmented(j, i);
            for (int k = 0; k < 2 * n; ++k)
            {
                augmented(j, k) -= factor * augmented(i, k);
            }
        }
    }

    // The right half of the augmented matrix is now A^-1.
    Matrix inv(n, n);
    for (int i = 0; i < n; ++i)
    {
        for (int j = 0; j < n; ++j)
        {
            inv(i, j) = augmented(i, j + n);
        }
    }
    return inv;
}
/// Integer power of a square matrix A.
///  - power > 0:  A * A * ... * A (power factors)
///  - power == 0: identity matrix of the same size
///  - power < 0:  inverse()^(-power)
/// Uses exponentiation by squaring, so only O(log |power|) matrix products are needed.
/// NOTE: in C++ `^` binds more loosely than +, -, * and ==, so write (A ^ 2) + B.
/// @throws std::runtime_error if the matrix is not square, or (from inverse())
///         if power < 0 and the matrix is singular.
Matrix Matrix::operator^(int power) const
{
    if (rows != cols)
    {
        throw std::runtime_error("Matrix must be square for exponentiation.");
    }
    if (power == 0)
    {
        Matrix identity(rows, cols);
        for (int i = 0; i < rows; ++i)
        {
            identity(i, i) = 1;
        }
        return identity;
    }

    // Take the magnitude as unsigned: -power would overflow (UB) for INT_MIN.
    unsigned int remaining = power < 0 ? 0u - static_cast<unsigned int>(power)
                                       : static_cast<unsigned int>(power);
    Matrix base = (power < 0) ? inverse() : Matrix(*this);

    // `result` is only assigned when the first factor is taken, so A^1 is an exact copy of A.
    Matrix result(rows, cols);
    bool hasFactor = false;
    while (true)
    {
        if (remaining & 1u)
        {
            result = hasFactor ? result * base : base;
            hasFactor = true;
        }
        remaining >>= 1;
        if (remaining == 0)
        {
            break;
        }
        base = base * base;
    }
    return result;
}
/// Writes the matrix to std::cout, one row per line, formatted as "| e00 e01 ... |".
/// Elements use std::complex's stream format, e.g. "(1,0)".
void Matrix::print() const
{
    for (int i = 0; i < rows; ++i)
    {
        std::cout << "| ";
        for (int j = 0; j < cols; ++j)
        {
            std::cout << (*this)(i, j) << " ";
        }
        // '\n' instead of std::endl avoids a stream flush on every row...
        std::cout << "|\n";
    }
    // ...but flush once, so the whole matrix is visible by the time print() returns.
    std::cout << std::flush;
}
// Returns how many distinct values appear in `rows`; duplicates are counted once.
int countUniqueElements(const std::vector<int> &rows)
{
    const std::set<int> distinct(rows.begin(), rows.end());
    const std::size_t howMany = distinct.size();
    return static_cast<int>(howMany);
}
/// Extracts the sub-matrix formed by the intersection of the given rows and columns.
///
/// Repeated indices are kept only once, at the position of their first appearance.
/// For example, rows {2, 0, 2} select original rows 2 and 0, in that order.
///
/// @param rows  Zero-based row indices to keep (the parameter shadows the member `rows`).
/// @param cols  Zero-based column indices to keep (the parameter shadows the member `cols`).
/// @throws std::runtime_error if either list is empty or any index is out of range.
Matrix Matrix::operator()(std::vector<int> rows, std::vector<int> cols)
{
    if (rows.empty() || cols.empty())
        throw std::runtime_error("Rows and columns must be non-empty.");

    // Every index must lie in [0, this->rows) / [0, this->cols).
    for (int r : rows)
        if (r < 0 || r >= this->rows)
            throw std::runtime_error("Row index out of bounds.");
    for (int c : cols)
        if (c < 0 || c >= this->cols)
            throw std::runtime_error("Column index out of bounds.");

    // Drops repeated indices while preserving first-appearance order.
    const auto firstOccurrences = [](const std::vector<int> &indices) {
        std::vector<int> unique;
        std::map<int, int> position;
        for (int idx : indices)
        {
            if (position.emplace(idx, static_cast<int>(unique.size())).second)
                unique.push_back(idx);
        }
        return unique;
    };
    const std::vector<int> pickedRows = firstOccurrences(rows);
    const std::vector<int> pickedCols = firstOccurrences(cols);

    Matrix result(static_cast<int>(pickedRows.size()), static_cast<int>(pickedCols.size()));
    for (std::size_t r = 0; r < pickedRows.size(); ++r)
    {
        for (std::size_t c = 0; c < pickedCols.size(); ++c)
        {
            result(static_cast<int>(r), static_cast<int>(c)) = (*this)(pickedRows[r], pickedCols[c]);
        }
    }
    return result;
}
/// Returns a copy of the matrix with every all-zero row and every all-zero column removed.
///
/// Special cases:
///  - a matrix with exactly one row always keeps that row (likewise for exactly one column);
///  - a row only counts as all-zero if the matrix has at least one column (and vice
///    versa), so shapes with a zero dimension drop nothing along the other axis.
Matrix Matrix::all() const
{
    // True when row r has at least one entry and every entry is exactly zero.
    const auto rowIsZero = [this](int r) {
        for (int c = 0; c < cols; ++c)
            if ((*this)(r, c) != std::complex<double>(0.0))
                return false;
        return cols > 0;
    };
    // True when column c has at least one entry and every entry is exactly zero.
    const auto colIsZero = [this](int c) {
        for (int r = 0; r < rows; ++r)
            if ((*this)(r, c) != std::complex<double>(0.0))
                return false;
        return rows > 0;
    };

    std::vector<int> keptRows;
    if (rows == 1)
    {
        keptRows.push_back(0);
    }
    else
    {
        for (int r = 0; r < rows; ++r)
            if (!rowIsZero(r))
                keptRows.push_back(r);
    }

    std::vector<int> keptCols;
    if (cols == 1)
    {
        keptCols.push_back(0);
    }
    else
    {
        for (int c = 0; c < cols; ++c)
            if (!colIsZero(c))
                keptCols.push_back(c);
    }

    Matrix result(static_cast<int>(keptRows.size()), static_cast<int>(keptCols.size()));
    for (std::size_t i = 0; i < keptRows.size(); ++i)
    {
        for (std::size_t j = 0; j < keptCols.size(); ++j)
        {
            result(static_cast<int>(i), static_cast<int>(j)) = (*this)(keptRows[i], keptCols[j]);
        }
    }
    return result;
}