zawatins-library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub zawa-tin/zawatins-library

✔ test/matrix.test.cpp

Depends on: src/math/modint.hpp, src/math/matrix.hpp

Code

#define PROBLEM "https://judge.yosupo.jp/problem/matrix_product"

#include "../src/math/modint.hpp"
#include "../src/math/matrix.hpp"

#include <iostream>
#include <vector>

int main() {
	// Matrix inputs can be large and this program only uses C++ streams, so
	// drop C stdio synchronization and untie cin from cout for faster I/O.
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);

	// Read A (N x M) and B (M x K); entries are stored as residues mod 998244353.
	int N, M, K; std::cin >> N >> M >> K;
	zawa::matrix<zawa::modint<998244353>> A(N, M), B(M, K);
	for (int i = 0 ; i < N ; i++) {
		for (int j = 0 ; j < M ; j++) {
			int a; std::cin >> a;
			A[i][j] = a;
		}
	}
	for (int i = 0 ; i < M ; i++) {
		for (int j = 0 ; j < K ; j++) {
			int a; std::cin >> a;
			B[i][j] = a;
		}
	}

	// Print C = A * B (N x K), one row per line, entries separated by spaces.
	auto C = A * B;
	for (int i = 0 ; i < N ; i++) {
		for (int j = 0 ; j < K ; j++) {
			std::cout << C[i][j].val() << (j + 1 == K ? '\n' : ' ');
		}
	}
}
#line 1 "test/matrix.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/matrix_product"

#line 1 "src/math/modint.hpp"
namespace zawa {

    /// Integer modulo the compile-time modulus `mod`, stored as a residue in [0, mod).
    ///
    /// Preconditions (not checked):
    ///  * mod >= 2;
    ///  * (mod - 1) * (mod - 1) must fit in long long (mod below roughly 3.03e9),
    ///    because operator*= multiplies two residues in long long;
    ///  * mod must be prime for inv(), operator/, operator/= and negative
    ///    exponents in pow(), which all rely on Fermat's little theorem.
    template<long long mod>
    class modint {
    private:
        long long x; // canonical residue, always in [0, mod)

    public:
        modint() : x(0) {}

        /// Implicit on purpose so plain integers can be assigned to and mixed with
        /// modints. Negative inputs are normalized into [0, mod).
        modint(long long x) : x((x % mod + mod) % mod) {}

        modint& operator+=(modint p) {
            x += p.x;
            if (x >= mod) x -= mod;
            return *this;
        }

        modint& operator-=(modint p) {
            // Add the additive inverse instead of subtracting so x never goes negative.
            x += mod - p.x;
            if (x >= mod) x -= mod;
            return *this;
        }

        modint& operator*=(modint p) {
            x = x * p.x % mod;
            return *this;
        }

        /// Division via the modular inverse; requires prime mod.
        modint& operator/=(modint p) {
            *this *= p.inv();
            return *this;
        }

        modint operator-() const {
            return modint(-x);
        }

        modint operator+(const modint& p) const {
            return modint(*this) += p;
        }

        modint operator-(const modint& p) const {
            return modint(*this) -= p;
        }

        modint operator*(const modint& p) const {
            return modint(*this) *= p;
        }

        modint operator/(const modint& p) const {
            return modint(*this) /= p;
        }

        /// Returns the canonical representative in [0, mod).
        long long val() const {
            return x;
        }

        /// Returns this^p by binary exponentiation (O(log |p|) multiplications).
        /// pow(0) is 1. A negative p yields inv()^|p|, which requires prime mod.
        modint pow(long long p) const {
            modint base = (p < 0 ? inv() : *this);
            // Work on |p| as unsigned: this avoids overflow for p == LLONG_MIN and
            // guarantees termination (right-shifting a negative value never reaches 0).
            unsigned long long e = (p < 0)
                ? 0ULL - static_cast<unsigned long long>(p)
                : static_cast<unsigned long long>(p);
            modint res(1);
            while (e) {
                if (e & 1) res *= base;
                base *= base;
                e >>= 1;
            }
            return res;
        }

        /// Modular inverse computed as x^(mod - 2) (Fermat); requires prime mod.
        /// For x == 0 no inverse exists and the result is meaningless.
        modint inv() const {
            return pow(mod - 2);
        }
    };

}// namespace zawa
#line 2 "src/math/matrix.hpp"

#include <vector>

namespace zawa {

/// Dense r x c matrix stored as a vector of rows.
///
/// T is value-initialized to its additive zero and must support +=, -= and *;
/// id_mul additionally requires T to be assignable from the integer 1.
/// Shapes are NOT validated: callers must pass dimension-compatible operands.
template <class T = long long>
class matrix {
private:
	std::vector<std::vector<T>> dat;

public:
	std::size_t r, c;

	/// Builds a matrix from a list of rows. The column count is taken from the
	/// first row (0 for an empty list); every row is assumed to have that length.
	matrix(const std::vector<std::vector<T>>& dat)
		: dat(dat), r(dat.size()), c(dat.empty() ? 0 : dat[0].size()) {}

	/// r x c matrix whose entries are all value-initialized (T()).
	matrix(std::size_t r, std::size_t c) : dat(r, std::vector<T>(c)), r(r), c(c) {}

	// Copy/move construction and assignment are the compiler-generated
	// member-wise ones (rule of zero).

	/// Row access: m[i][j] is the entry in row i, column j.
	std::vector<T>& operator[](std::size_t i) {
		return dat[i];
	}
	const std::vector<T>& operator[](std::size_t i) const {
		return dat[i];
	}

	/// Element-wise in-place addition; mat must have the same shape.
	matrix& operator+=(const matrix<T>& mat) {
		for (std::size_t i = 0 ; i < r ; i++) {
			for (std::size_t j = 0 ; j < c ; j++) {
				dat[i][j] += mat[i][j];
			}
		}
		return *this;
	}
	/// Element-wise sum, returned as a new matrix.
	matrix operator+(const matrix<T>& mat) const {
		matrix res(*this);
		res += mat;
		return res;
	}

	/// Element-wise in-place subtraction; mat must have the same shape.
	matrix& operator-=(const matrix<T>& mat) {
		for (std::size_t i = 0 ; i < r ; i++) {
			for (std::size_t j = 0 ; j < c ; j++) {
				dat[i][j] -= mat[i][j];
			}
		}
		return *this;
	}
	/// Element-wise difference, returned as a new matrix (by value, so it never
	/// refers to a destroyed temporary).
	matrix operator-(const matrix<T>& mat) const {
		matrix res(*this);
		res -= mat;
		return res;
	}

	/// Matrix product (r x c) * (mat.r x mat.c) -> (r x mat.c); requires c == mat.r.
	/// Uses i-k-j loop order so the innermost loop walks rows contiguously;
	/// each entry still accumulates its terms in increasing k order.
	matrix operator*(const matrix<T>& mat) const {
		matrix res(r, mat.c);
		for (std::size_t i = 0 ; i < r ; i++) {
			std::vector<T>& out = res.dat[i];
			for (std::size_t k = 0 ; k < c ; k++) {
				const T& lhs = dat[i][k];
				const std::vector<T>& row = mat.dat[k];
				for (std::size_t j = 0 ; j < mat.c ; j++) {
					out[j] += lhs * row[j];
				}
			}
		}
		return res;
	}
	/// In-place product *this = *this * mat. Safe when mat is *this itself,
	/// because the product is built in a separate matrix before assignment.
	matrix& operator*=(const matrix<T>& mat) {
		*this = *this * mat;
		return *this;
	}

	/// Returns this^p for a square matrix by binary exponentiation;
	/// p <= 0 yields the identity.
	matrix pow(long long p) const;
};

/// Returns the n x n identity matrix (1 on the diagonal, T() elsewhere).
template <class T>
matrix<T> id_mul(std::size_t n) {
	matrix<T> res(n, n);
	for (std::size_t i = 0 ; i < n ; i++) {
		res[i][i] = 1;
	}
	return res;
}

template <class T>
matrix<T> matrix<T>::pow(long long p) const {
	matrix<T> res = id_mul<T>(r);
	matrix<T> base(*this);
	while (p > 0) {
		if (p & 1) {
			res *= base;
		}
		p >>= 1;
		// Skip the final squaring: its result would never be used.
		if (p > 0) {
			base *= base;
		}
	}
	return res;
}

} // namespace zawa
#line 5 "test/matrix.test.cpp"

#include <iostream>
#line 8 "test/matrix.test.cpp"

int main() {
	// Matrix inputs can be large and this program only uses C++ streams, so
	// drop C stdio synchronization and untie cin from cout for faster I/O.
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);

	// Read A (N x M) and B (M x K); entries are stored as residues mod 998244353.
	int N, M, K; std::cin >> N >> M >> K;
	zawa::matrix<zawa::modint<998244353>> A(N, M), B(M, K);
	for (int i = 0 ; i < N ; i++) {
		for (int j = 0 ; j < M ; j++) {
			int a; std::cin >> a;
			A[i][j] = a;
		}
	}
	for (int i = 0 ; i < M ; i++) {
		for (int j = 0 ; j < K ; j++) {
			int a; std::cin >> a;
			B[i][j] = a;
		}
	}

	// Print C = A * B (N x K), one row per line, entries separated by spaces.
	auto C = A * B;
	for (int i = 0 ; i < N ; i++) {
		for (int j = 0 ; j < K ; j++) {
			std::cout << C[i][j].val() << (j + 1 == K ? '\n' : ' ');
		}
	}
}
Back to top page