使用C++实现一个处理复数的矩阵库

暮雨٩(๑˃̵ᴗ˂̵๑)۶终将落下 发布于 2024-08-02 886 次阅读


背景和目标

最近在一些C++工程中需要使用到线性代数相关的库,一般来讲可以使用Eigen等库,不过我想从头开始写一个工程,只依赖std标准库,自己写的线性代数库之后进行补充、升级和移植也比较方便。

依赖

本项目只依赖以下C++标准库头文件,并在C++17(__cplusplus 为 201703L)下测试通过:

  • complex
  • vector
  • map
  • set
  • stdexcept
  • initializer_list
  • iostream(print() 输出用)

食用

建议配合CMake使用:在CMakeLists.txt中添加下面的内容——把Matrix.h所在的目录加入头文件搜索路径,并将Matrix.cpp与你自己的源文件一起编译,即可使用。

# Compile Matrix.cpp together with your own sources (main.cpp stands for your code,
# which must provide main(); Matrix.cpp alone cannot be linked into an executable).
add_executable(your_program main.cpp Matrix.cpp)
# The include path must be the *directory* that contains Matrix.h, not the header file itself.
target_include_directories(your_program PRIVATE /dir/path/to)

内置函数

Matrix对象可以使用下面两种方法进行初始化

// Spell out <double>: in C++17 std::complex(3, 0) deduces std::complex<int>, which
// does not implicitly convert to std::complex<double> and would not compile here.
Matrix a = {{{1, 0.5}, 2, std::complex<double>(3, 0)}, {4, 8, 6}, {7, 8, 9}};
Matrix b(3, 3); // 3x3 matrix with every element initialised to 0

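其中{1,0.5}表示复数1+0.5i,std::complex<double>(3,0)表示复数3+0i;所有元素最终都会以std::complex<double>类型存储。注意在C++17中若省略模板参数写成std::complex(3,0),类模板实参推导会得到std::complex<int>,它无法隐式转换为std::complex<double>,会导致编译错误,因此请显式写出<double>。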

可以使用+、-、*、^等运算符,其中^表示矩阵的整数次幂,且支持负指数:Matrix^(-2) 等价于 Matrix.inverse()^2。注意在C++中^的优先级低于+、-、*以及==,与其他运算混用时请加括号,例如 (a^2) + b。

矩阵中的元素通过(row,col)进行索引,行、列下标均从0开始。也可以按行优先顺序使用[row*Matrix.col()+col]进行索引,但不推荐(提示:当矩阵是行向量或列向量,即行数或列数为1时,使用[]可以直接索引某个元素)。这两种索引方式都不做越界检查。

特殊用法a(std::vector<int> rows, std::vector<int> cols):可提取由指定行和列组成的子矩阵(重复的下标只保留第一次出现;下标列表为空或下标越界时抛出std::runtime_error),比如

\begin{equation} a= \begin{pmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{pmatrix} \end{equation}

当输入为a({1,2},{1,2})时,返回值为

\begin{equation} result = \begin{pmatrix} 5&6\\ 8&9 \end{pmatrix} \end{equation}

T()函数返回转置(Transpose)后的矩阵,原矩阵不会发生变化。注意这是普通转置,不会对元素取共轭;若需要共轭转置,请再对每个元素调用std::conj。

inverse()函数使用高斯-约旦消元法(部分主元)返回矩阵的逆;当矩阵不是方阵或不可逆时,会抛出std::runtime_error。由于返回值是Matrix对象而不是指针,无法直接返回NULL;如果你希望在不可逆时返回某个替代结果(例如空矩阵),可以修改源码中抛出该错误的地方。

getMinor(int row, int col)函数返回删除第row行和第col列之后得到的子矩阵;对其调用det()可得到余子式,再乘以(-1)^(row+col)即为代数余子式。

det()函数返回矩阵的行列式,要求矩阵为方阵,否则抛出std::runtime_error。

头文件


#ifndef MATRIX_H
#define MATRIX_H

#include <complex>
#include <initializer_list>
#include <iostream>
#include <map>
#include <set>
#include <stdexcept>
#include <vector>

/**
 * Dense matrix of std::complex<double> elements stored in row-major order
 * (element (i, j) lives at flat index i * col() + j).
 *
 * All indices are zero-based. Errors are reported with std::runtime_error.
 */
class Matrix
{
public:
    /// Creates a rows x cols matrix whose elements are all 0.
    Matrix(int rows, int cols);

    /// Creates a matrix from nested brace lists, one inner list per row, e.g.
    /// Matrix a = {{1, 2}, {3, 4}};
    /// @throws std::runtime_error if there are no rows, the first row is empty,
    ///         or the rows do not all have the same length.
    Matrix(std::initializer_list<std::initializer_list<std::complex<double>>> init);

    /// Flat (row-major) element access; not bounds-checked.
    std::complex<double> &operator[](int index);
    const std::complex<double> &operator[](int index) const;

    /// Element access by (row i, column j); not bounds-checked.
    std::complex<double> &operator()(int i, int j);
    const std::complex<double> &operator()(int i, int j) const;

    /// Returns the sub-matrix formed by the listed rows and columns. Duplicate
    /// indices are kept only at their first occurrence.
    /// @throws std::runtime_error if a list is empty or an index is out of range.
    Matrix operator()(std::vector<int> rows, std::vector<int> cols);

    /// Element-wise sum. @throws std::runtime_error if the shapes differ.
    Matrix operator+(const Matrix &other) const;
    /// Element-wise difference. @throws std::runtime_error if the shapes differ.
    Matrix operator-(const Matrix &other) const;
    /// Matrix product. @throws std::runtime_error if col() != other.row().
    Matrix operator*(const Matrix &other) const;
    /// Multiplies every element by scalar (only matrix * scalar is provided).
    Matrix operator*(const std::complex<double> &scalar) const;
    /// Integer matrix power: 0 gives the identity, a negative power raises inverse()
    /// to -power. Note that ^ binds more loosely than +, -, * and == in C++, so
    /// write (a ^ 2) + b rather than a ^ 2 + b.
    /// @throws std::runtime_error if the matrix is not square (or singular for power < 0).
    Matrix operator^(int power) const;

    /// Returns a copy with all-zero rows and all-zero columns removed; a dimension
    /// of size 1 is always kept.
    Matrix all() const;
    /// Returns the plain transpose (elements are NOT conjugated); *this is unchanged.
    Matrix T();
    /// Returns a rows x cols matrix with ones on the main diagonal; intended
    /// primarily for square shapes.
    static Matrix Identity(int rows, int cols);
    /// Flattened "ij" grid of x and y: result[0][k] and result[1][k] are the x and y
    /// coordinates of point k, where k = i * y.size() + j.
    static std::vector<std::vector<int>> meshgrid(std::vector<int> x, std::vector<int> y);

    /// Matrix inverse.
    /// @throws std::runtime_error if the matrix is not square or is singular.
    Matrix inverse() const;
    /// Determinant. @throws std::runtime_error if the matrix is not square.
    std::complex<double> det() const;
    /// Returns a copy with row `row` and column `col` removed.
    Matrix getMinor(int row, int col) const;
    /// Prints the matrix to std::cout, one "| ... |" line per row.
    void print() const;

    /// Number of rows.
    int row() const { return rows; }
    /// Number of columns.
    int col() const { return cols; }

private:
    int rows;
    int cols;
    std::vector<std::complex<double>> data; // row-major element storage
};

#endif // !MATRIX_H

CPP

#include "Matrix.h"

/**
 * Creates a rows x cols matrix with every element value-initialised to 0.
 *
 * Zero-sized dimensions are allowed (e.g. all() may produce them).
 *
 * @throws std::runtime_error if rows or cols is negative. Without this check two
 *         negative dimensions multiply to a positive element count and yield a
 *         matrix whose row()/col() are negative.
 */
Matrix::Matrix(int rows, int cols) : rows(rows), cols(cols)
{
    if (rows < 0 || cols < 0)
    {
        throw std::runtime_error("Matrix dimensions must be non-negative.");
    }
    // Multiply in size_t so large shapes cannot overflow int while sizing storage.
    data.resize(static_cast<std::size_t>(rows) * static_cast<std::size_t>(cols));
}

/**
 * Builds a matrix from nested brace lists; each inner list is one row.
 *
 * The first row fixes the column count; every later row must match it.
 *
 * @throws std::runtime_error on an empty outer list, an empty first row,
 *         or rows of differing lengths.
 */
Matrix::Matrix(std::initializer_list<std::initializer_list<std::complex<double>>> init)
    : rows(static_cast<int>(init.size())), cols(0)
{
    if (rows == 0)
    {
        throw std::runtime_error("Matrix cannot have zero rows.");
    }

    cols = static_cast<int>(init.begin()->size());
    if (cols == 0)
    {
        throw std::runtime_error("Matrix cannot have zero columns.");
    }

    const std::size_t width = static_cast<std::size_t>(cols);
    data.reserve(static_cast<std::size_t>(rows) * width);

    // Append the rows one after another, which is exactly row-major order.
    for (const auto &line : init)
    {
        if (line.size() != width)
        {
            throw std::runtime_error("All rows must have the same number of columns.");
        }
        data.insert(data.end(), line.begin(), line.end());
    }
}

// Flat access into the row-major storage: element (i, j) is at i * col() + j.
// No bounds checking is performed, mirroring std::vector::operator[].
std::complex<double> &Matrix::operator[](int index)
{
    return data[static_cast<std::size_t>(index)];
}

// Read-only counterpart of the flat accessor above.
const std::complex<double> &Matrix::operator[](int index) const
{
    return data[static_cast<std::size_t>(index)];
}

// Element (i, j) with zero-based indices; not bounds-checked.
std::complex<double> &Matrix::operator()(int i, int j)
{
    // Row-major layout: skip i full rows of `cols` elements, then step j along the row.
    const int offset = i * cols + j;
    return data[offset];
}

// Read-only element (i, j); same layout and (lack of) checking as above.
const std::complex<double> &Matrix::operator()(int i, int j) const
{
    const int offset = i * cols + j;
    return data[offset];
}

/**
 * Builds a flattened "ij"-ordered grid from two coordinate lists.
 *
 * For every x[i] (outer loop) and y[j] (inner loop) one grid point is emitted,
 * so point k = i * y.size() + j has coordinates (result[0][k], result[1][k]).
 *
 * @param x coordinates along the first axis
 * @param y coordinates along the second axis
 * @return {X, Y}, each of length x.size() * y.size()
 */
std::vector<std::vector<int>> Matrix::meshgrid(std::vector<int> x, std::vector<int> y)
{
    // Fill the two output vectors in place instead of building locals and copying
    // them into the result, and reserve up front so push_back never reallocates.
    std::vector<std::vector<int>> grid(2);
    std::vector<int> &X = grid[0];
    std::vector<int> &Y = grid[1];
    const std::size_t points = x.size() * y.size();
    X.reserve(points);
    Y.reserve(points);
    for (int xi : x)
    {
        for (int yj : y)
        {
            X.push_back(xi);
            Y.push_back(yj);
        }
    }
    return grid;
}

/**
 * Element-wise sum of two matrices of identical shape.
 *
 * @throws std::runtime_error if the shapes differ.
 */
Matrix Matrix::operator+(const Matrix &other) const
{
    const bool sameShape = (rows == other.rows) && (cols == other.cols);
    if (!sameShape)
    {
        throw std::runtime_error("Matrix dimensions must match for addition.");
    }

    Matrix sum(rows, cols);
    for (std::size_t k = 0; k < data.size(); ++k)
    {
        sum.data[k] = data[k] + other.data[k];
    }
    return sum;
}

/**
 * Returns the plain transpose: a cols x rows matrix whose (c, r) element is this
 * matrix's (r, c) element. Elements are not conjugated and *this is not modified.
 */
Matrix Matrix::T()
{
    Matrix transposed(cols, rows);
    for (int c = 0; c < cols; ++c)
    {
        for (int r = 0; r < rows; ++r)
        {
            transposed(c, r) = (*this)(r, c);
        }
    }
    return transposed;
}

/**
 * Returns a row x col matrix with ones on the main diagonal and zeros elsewhere.
 *
 * The diagonal has min(row, col) entries. The previous loop ran up to `row`, which
 * for row > col wrote past the end of the storage (e.g. Identity(3, 2) wrote index 6
 * of a 6-element buffer).
 */
Matrix Matrix::Identity(int row, int col)
{
    Matrix result(row, col);
    const int diagonalLength = (row < col) ? row : col;
    for (int i = 0; i < diagonalLength; ++i)
    {
        result(i, i) = 1;
    }
    return result;
}

/**
 * Element-wise difference (*this - other) of two matrices of identical shape.
 *
 * @throws std::runtime_error if the shapes differ.
 */
Matrix Matrix::operator-(const Matrix &other) const
{
    const bool sameShape = (rows == other.rows) && (cols == other.cols);
    if (!sameShape)
    {
        throw std::runtime_error("Matrix dimensions must match for subtraction.");
    }

    Matrix difference(rows, cols);
    for (std::size_t k = 0; k < data.size(); ++k)
    {
        difference.data[k] = data[k] - other.data[k];
    }
    return difference;
}

/**
 * Matrix product: (rows x cols) * (cols x other.cols) -> rows x other.cols.
 *
 * @throws std::runtime_error if the inner dimensions do not agree.
 */
Matrix Matrix::operator*(const Matrix &other) const
{
    if (cols != other.rows)
    {
        throw std::runtime_error("Matrix dimensions are not compatible for multiplication.");
    }

    const int inner = cols;
    Matrix product(rows, other.cols);
    for (int r = 0; r < rows; ++r)
    {
        for (int c = 0; c < other.cols; ++c)
        {
            // Dot product of row r of *this with column c of other, summed in k order.
            std::complex<double> acc(0.0, 0.0);
            for (int k = 0; k < inner; ++k)
            {
                acc += (*this)(r, k) * other(k, c);
            }
            product(r, c) = acc;
        }
    }
    return product;
}

/**
 * Returns a copy of the matrix with every element multiplied by `scalar`.
 * Only the matrix * scalar order is provided.
 */
Matrix Matrix::operator*(const std::complex<double> &scalar) const
{
    Matrix scaled(rows, cols);
    for (std::size_t k = 0; k < data.size(); ++k)
    {
        scaled.data[k] = data[k] * scalar;
    }
    return scaled;
}

/**
 * Returns the (rows-1) x (cols-1) matrix obtained by deleting row `row` and
 * column `col`. Its det() is the (row, col) minor; multiplying that by
 * (-1)^(row+col) gives the cofactor.
 *
 * @throws std::runtime_error if `row` or `col` is out of range. Without the check an
 *         out-of-range index lets the copy loop write past the end of the result.
 */
Matrix Matrix::getMinor(int row, int col) const
{
    if (row < 0 || row >= rows || col < 0 || col >= cols)
    {
        throw std::runtime_error("Minor index out of bounds.");
    }

    Matrix minor(rows - 1, cols - 1);
    for (int i = 0, mi = 0; i < rows; ++i)
    {
        if (i == row)
            continue;
        for (int j = 0, mj = 0; j < cols; ++j)
        {
            if (j == col)
                continue;
            minor(mi, mj) = (*this)(i, j);
            ++mj;
        }
        ++mi;
    }
    return minor;
}

std::complex<double> Matrix::det() const
{
    if (rows != cols)
    {
        throw std::runtime_error("Matrix must be square to compute determinant.");
    }

    if (rows == 1)
    {
        return (*this)(0, 0);
    }

    if (rows == 2)
    {
        return (*this)(0, 0) * (*this)(1, 1) - (*this)(0, 1) * (*this)(1, 0);
    }

    std::complex<double> det = 0;
    for (int j = 0; j < cols; ++j)
    {
        det += ((*this)(0, j) * getMinor(0, j).det() * ((j % 2 == 0) ? 1.0 : -1.0));
    }
    return det;
}

/**
 * Inverse of a square matrix via Gauss-Jordan elimination with partial pivoting
 * on the augmented matrix [A | I].
 *
 * The matrix is treated as singular when the best available pivot is no larger
 * than 1e-10 times the largest element magnitude of the input. The test therefore
 * scales with the data. The previous fixed absolute cut-off of 1e-10 rejected
 * perfectly invertible matrices such as 1e-12 * I.
 *
 * @return the inverse (a 0x0 matrix for a 0x0 input).
 * @throws std::runtime_error if the matrix is not square or is (numerically) singular.
 */
Matrix Matrix::inverse() const
{
    if (rows != cols)
    {
        throw std::runtime_error("Matrix must be square to find its inverse.");
    }

    const int n = rows;

    // Scale-relative singularity threshold (an all-zero matrix gives 0, which the
    // `<=` pivot test below still rejects).
    double largestMagnitude = 0.0;
    for (const auto &value : data)
    {
        const double magnitude = std::abs(value);
        if (magnitude > largestMagnitude)
        {
            largestMagnitude = magnitude;
        }
    }
    const double tolerance = 1e-10 * largestMagnitude;

    // Build the augmented matrix [A | I]; Matrix(n, 2n) is already zero-filled.
    Matrix augmented(n, 2 * n);
    for (int i = 0; i < n; ++i)
    {
        for (int j = 0; j < n; ++j)
        {
            augmented(i, j) = (*this)(i, j);
        }
        augmented(i, i + n) = 1.0;
    }

    for (int i = 0; i < n; ++i)
    {
        // Partial pivoting: pick the row with the largest magnitude in column i.
        int pivot = i;
        for (int j = i + 1; j < n; ++j)
        {
            if (std::abs(augmented(j, i)) > std::abs(augmented(pivot, i)))
            {
                pivot = j;
            }
        }

        if (std::abs(augmented(pivot, i)) <= tolerance)
        {
            throw std::runtime_error("Matrix is singular and cannot be inverted.");
        }

        if (i != pivot)
        {
            for (int j = 0; j < 2 * n; ++j)
            {
                std::swap(augmented(i, j), augmented(pivot, j));
            }
        }

        // Scale the pivot row so the pivot becomes 1.
        const std::complex<double> pivotValue = augmented(i, i);
        for (int j = 0; j < 2 * n; ++j)
        {
            augmented(i, j) /= pivotValue;
        }

        // Clear column i in every other row (above and below: Gauss-Jordan).
        for (int j = 0; j < n; ++j)
        {
            if (j != i)
            {
                const std::complex<double> factor = augmented(j, i);
                for (int k = 0; k < 2 * n; ++k)
                {
                    augmented(j, k) -= factor * augmented(i, k);
                }
            }
        }
    }

    // The right half of the augmented matrix now holds A^-1.
    Matrix inv(n, n);
    for (int i = 0; i < n; ++i)
    {
        for (int j = 0; j < n; ++j)
        {
            inv(i, j) = augmented(i, j + n);
        }
    }

    return inv;
}

/**
 * Integer matrix power.
 *
 * power == 0 gives the identity; power < 0 raises inverse() to |power|.
 * Uses binary exponentiation (O(log |power|) products instead of |power| - 1).
 * The magnitude is taken in unsigned arithmetic so power == INT_MIN does not
 * overflow the way `-power` would.
 *
 * @throws std::runtime_error if the matrix is not square, or if power < 0 and the
 *         matrix is singular (thrown by inverse()).
 */
Matrix Matrix::operator^(int power) const
{
    if (rows != cols)
    {
        throw std::runtime_error("Matrix must be square for exponentiation.");
    }

    unsigned int exponent = (power < 0) ? 0u - static_cast<unsigned int>(power)
                                        : static_cast<unsigned int>(power);

    Matrix base = *this;
    if (power < 0)
    {
        base = inverse();
    }

    Matrix result = Identity(rows, cols);
    while (exponent != 0u)
    {
        if ((exponent & 1u) != 0u)
        {
            result = result * base;
        }
        exponent >>= 1u;
        if (exponent != 0u)
        {
            base = base * base; // skip the final, unused squaring
        }
    }
    return result;
}

/**
 * Writes the matrix to std::cout, one "| e0 e1 ... |" line per row, using the
 * standard operator<< for std::complex (printed as "(re,im)").
 *
 * Rows end with '\n' instead of std::endl so the stream is flushed once at the
 * end rather than after every row; the printed text is unchanged.
 */
void Matrix::print() const
{
    for (int i = 0; i < rows; ++i)
    {
        std::cout << "| ";
        for (int j = 0; j < cols; ++j)
        {
            std::cout << (*this)(i, j) << " ";
        }
        std::cout << "|\n";
    }
    std::cout.flush();
}

/**
 * Returns how many distinct values appear in `rows`.
 *
 * @param rows list of (possibly repeated) integers
 * @return number of distinct values; 0 for an empty list
 */
int countUniqueElements(const std::vector<int> &rows)
{
    // std::set drops duplicates on insertion, so its size is the distinct count.
    const std::set<int> distinct(rows.begin(), rows.end());
    return static_cast<int>(distinct.size());
}

/**
 * Extracts the sub-matrix formed by the given row and column indices.
 *
 * Indices keep the order of their first occurrence; later duplicates are ignored,
 * so a({1, 1, 2}, {0}) yields a 2x1 matrix made of rows 1 and 2.
 *
 * @param rows zero-based row indices to keep
 * @param cols zero-based column indices to keep
 * @throws std::runtime_error if either list is empty or any index is out of range
 *         (row indices are validated before column indices).
 */
Matrix Matrix::operator()(std::vector<int> rows, std::vector<int> cols)
{
    if (rows.empty() || cols.empty())
    {
        throw std::runtime_error("Rows and columns must be non-empty.");
    }

    // True if any index lies outside [0, limit).
    const auto anyOutside = [](const std::vector<int> &indices, int limit) {
        for (int idx : indices)
        {
            if (idx < 0 || idx >= limit)
            {
                return true;
            }
        }
        return false;
    };
    if (anyOutside(rows, this->rows))
    {
        throw std::runtime_error("Row index out of bounds.");
    }
    if (anyOutside(cols, this->cols))
    {
        throw std::runtime_error("Column index out of bounds.");
    }

    // Distinct indices in order of first appearance.
    const auto firstOccurrences = [](const std::vector<int> &indices) {
        std::vector<int> unique;
        std::set<int> seen;
        for (int idx : indices)
        {
            if (seen.insert(idx).second)
            {
                unique.push_back(idx);
            }
        }
        return unique;
    };
    const std::vector<int> pickedRows = firstOccurrences(rows);
    const std::vector<int> pickedCols = firstOccurrences(cols);

    Matrix result(static_cast<int>(pickedRows.size()), static_cast<int>(pickedCols.size()));
    for (std::size_t r = 0; r < pickedRows.size(); ++r)
    {
        for (std::size_t c = 0; c < pickedCols.size(); ++c)
        {
            result(static_cast<int>(r), static_cast<int>(c)) = (*this)(pickedRows[r], pickedCols[c]);
        }
    }
    return result;
}

/**
 * Returns a copy with every all-zero row and every all-zero column removed.
 *
 * Special cases, preserved exactly:
 * - a single-row matrix always keeps its row, and a single-column matrix always
 *   keeps its column, even if they are entirely zero;
 * - rows of a zero-column matrix, and columns of a zero-row matrix, are kept.
 * The column test looks at the original matrix, independent of the row filter.
 */
Matrix Matrix::all() const
{
    const std::complex<double> zero(0.0);

    // Rows that contain at least one non-zero element.
    std::vector<int> keptRows;
    if (this->rows == 1)
    {
        keptRows.push_back(0);
    }
    else
    {
        for (int r = 0; r < this->rows; ++r)
        {
            bool keep = (this->cols == 0);
            for (int c = 0; c < this->cols && !keep; ++c)
            {
                keep = ((*this)(r, c) != zero);
            }
            if (keep)
            {
                keptRows.push_back(r);
            }
        }
    }

    // Columns that contain at least one non-zero element.
    std::vector<int> keptCols;
    if (this->cols == 1)
    {
        keptCols.push_back(0);
    }
    else
    {
        for (int c = 0; c < this->cols; ++c)
        {
            bool keep = (this->rows == 0);
            for (int r = 0; r < this->rows && !keep; ++r)
            {
                keep = ((*this)(r, c) != zero);
            }
            if (keep)
            {
                keptCols.push_back(c);
            }
        }
    }

    Matrix result(static_cast<int>(keptRows.size()), static_cast<int>(keptCols.size()));
    for (std::size_t r = 0; r < keptRows.size(); ++r)
    {
        for (std::size_t c = 0; c < keptCols.size(); ++c)
        {
            result(static_cast<int>(r), static_cast<int>(c)) = (*this)(keptRows[r], keptCols[c]);
        }
    }
    return result;
}

Hi~ o(* ̄▽ ̄*)ブ这里是feng-arch,会不定时在网站上发布技术贴~ ~~ 当然,要是很久都没有更新,那就是摆烂去了o(*////▽////*)q 有问题欢迎发送邮件至 feng-arch@outlook.com
最后更新于 2024-08-13