diff --git a/include/smath.hpp b/include/smath.hpp index 25a9bac..369c714 100644 --- a/include/smath.hpp +++ b/include/smath.hpp @@ -547,6 +547,159 @@ template struct Quaternion : Vec<4, T> { } }; +template + requires std::is_arithmetic_v +struct Mat : std::array, C> { + using Base = std::array, C>; + using Base::operator[]; + + constexpr Mat() noexcept { + for (auto &col : *this) + col = Vec{}; + } + + constexpr explicit Mat(T const &diag) noexcept + requires(R == C) + { + for (std::size_t c = 0; c < C; ++c) { + (*this)[c] = Vec{}; + (*this)[c][c] = diag; + } + } + + template + requires(sizeof...(Cols) == C && + (std::same_as, Vec> && ...)) + constexpr Mat(Cols const &...cols) noexcept : Base{cols...} {} + + constexpr auto col(std::size_t j) noexcept -> Vec & { + return (*this)[j]; + } + constexpr auto col(std::size_t j) const noexcept -> Vec const & { + return (*this)[j]; + } + + constexpr auto operator()(std::size_t row, std::size_t col) noexcept -> T & { + return (*this)[col][row]; + } + constexpr auto operator()(std::size_t row, std::size_t col) const noexcept + -> T const & { + return (*this)[col][row]; + } + + constexpr auto operator-() const noexcept -> Mat { + Mat r{}; + for (std::size_t c = 0; c < C; ++c) + r[c] = -(*this)[c]; + return r; + } + + constexpr auto operator+=(Mat const &rhs) noexcept -> Mat & { + for (std::size_t c = 0; c < C; ++c) + (*this)[c] += rhs[c]; + return *this; + } + constexpr auto operator-=(Mat const &rhs) noexcept -> Mat & { + for (std::size_t c = 0; c < C; ++c) + (*this)[c] -= rhs[c]; + return *this; + } + friend constexpr auto operator+(Mat lhs, Mat const &rhs) noexcept -> Mat { + lhs += rhs; + return lhs; + } + friend constexpr auto operator-(Mat lhs, Mat const &rhs) noexcept -> Mat { + lhs -= rhs; + return lhs; + } + + constexpr auto operator*=(T const &s) noexcept -> Mat & { + for (std::size_t c = 0; c < C; ++c) + (*this)[c] *= s; + return *this; + } + constexpr auto operator/=(T const &s) noexcept -> Mat & { + for (std::size_t c = 0; c < C; ++c) + (*this)[c] /= s; + return *this; + } + friend constexpr auto operator*(Mat lhs, T const &s) noexcept -> Mat { + lhs *= s; + return lhs; + } + friend constexpr auto operator*(T const &s, Mat rhs) noexcept -> Mat { + rhs *= s; + return rhs; + } + friend constexpr auto operator/(Mat lhs, T const &s) noexcept -> Mat { + lhs /= s; + return lhs; + } + + constexpr auto operator==(Mat const &rhs) const noexcept -> bool { + for (std::size_t c = 0; c < C; ++c) + if (!((*this)[c] == rhs[c])) + return false; + return true; + } + constexpr auto operator!=(Mat const &rhs) const noexcept -> bool { + return !(*this == rhs); + } + + static constexpr T EPS_DEFAULT = T(1e-6); + template + requires std::is_floating_point_v + constexpr auto approx_equal(Mat const &rhs, + U eps = EPS_DEFAULT) const noexcept -> bool { + for (std::size_t c = 0; c < C; ++c) + if (!(*this)[c].approx_equal(rhs[c], eps)) + return false; + return true; + } + + constexpr auto transposed() const noexcept -> Mat { + Mat r{}; + for (std::size_t c = 0; c < C; ++c) + for (std::size_t r_idx = 0; r_idx < R; ++r_idx) + r(r_idx, c) = (*this)(c, r_idx); + return r; + } + + static constexpr auto identity() noexcept -> Mat + requires(R == C) + { + Mat m{}; + for (std::size_t i = 0; i < R; ++i) + m(i, i) = T(1); + return m; + } +}; + +template +constexpr Vec operator*(Mat const &m, + Vec const &v) noexcept { + Vec out{}; + for (std::size_t c = 0; c < C; ++c) + out += m.col(c) * v[c]; + return out; +} + +// Matrix * Matrix +template +constexpr Mat operator*(Mat const &a, + Mat const &b) noexcept { + Mat out{}; + for (std::size_t k = 0; k < K; ++k) { + for (std::size_t r = 0; r < R; ++r) { + T sum = T(0); + for (std::size_t c = 0; c < C; ++c) + sum += a(r, c) * b(r, k); + out(r, k) = sum; + } + } + return out; +} + } // namespace smath template