cp-documentation

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub zawa-tin/cp-documentation

:heavy_check_mark: Mod of Linear をsqrt(M)個の等差数列に分割する
(Src/Number/SqrtHeuristicArithmetic.hpp)

概要

長さ $N$ の数列 $f_i = (Ai + B) \bmod M \ (0 \le i < N)$ が与えられたとき、これを $O(\sqrt{M})$ 個の部分列（連続するとは限らない）に分割する。ただし、各部分列は等差数列である。

各部分列の添え字の列は、ある整数 $l, g, n'$ を用いて $(l, l + g, l + 2g, \dots, l + (n'-1)g)$ と表せる。

このライブラリが動作するには、$0 \le N < M$ かつ $\gcd(A, M) = 1$ という厳しい条件が必要である。詳細（条件を満たさない場合の帰着方法を含む）はコード内のコメントを参照すること。

Depends on

- Src/Template/TypeAlias.hpp

Verified with

Code

#pragma once

#include "../Template/TypeAlias.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <concepts>
#include <ranges>
#include <tuple>
#include <vector>
#include <numeric>

namespace zawa {

// One arithmetic progression returned by SqrtHeuristicArithmetic.
// Term i (0 <= i < n) equals a * i + b and stands for the element of the
// original sequence at index l + gap * i, i.e. the progression covers
//     f(l), f(l + gap), f(l + 2 * gap), ..., f(l + (n - 1) * gap).
template <std::integral T>
struct Arithmetic {
    T a, b;        // common difference (may be negative) and first term
    usize n;       // number of terms
    usize l, gap;  // original index of the first term, and the index stride
    // Returns term i, i.e. a * i + b. Requires i < n.
    T operator()(usize i) const {
        assert(i < n);
        return a * i + b;
    }
    // Returns the index in the original sequence of term i. Requires i < n.
    usize originIndex(usize i) const {
        assert(i < n);
        return gap * i + l;
    }
};

// Partitions the sequence
//     f(i) = (A * i + B) mod M   (i = 0, 1, ..., N - 1, values in [0, M))
// into arithmetic progressions, aiming for O(sqrt(M)) of them. The index sets
// of the returned progressions partition {0, 1, ..., N - 1}, and every
// returned progression p satisfies, for all j < p.n,
//     p(j) == (A * p.originIndex(j) + B) mod M.
//
// require: T is a signed integer type (a progression's step may be negative)
// require: 0 <= N < M, and gcd(A mod M, M) = 1 whenever N >= 2
// require: (M - 1) * (M - 1) fits in the common type of T and usize, because
//          A * i + B is evaluated there without a wider type
// A and B may be negative; they are reduced into [0, M) first.
//
// reference: https://codeforces.com/blog/entry/141889
// reference: https://atcoder.jp/contests/abc429/submissions/70468550
// memo: if gcd(A, M) = g > 1, then (with 0 <= B < M)
//       (A x + B) mod M = g * (((A / g) x + floor(B / g)) mod (M / g)) + (B mod g),
//       and gcd(A / g, M / g) = 1.
// memo: if N >= M, f has period M, so split it into floor(N / M) copies of
//       f(0), ..., f(M - 1) (a permutation of {0, ..., M - 1} when gcd(A, M) = 1)
//       plus {B, A + B, ..., A * (N mod M - 1) + B} mod M.
template <std::integral T>
std::vector<Arithmetic<T>> SqrtHeuristicArithmetic(T A, T B, usize M, usize N) {
    static_assert(std::signed_integral<T>,
                  "SqrtHeuristicArithmetic: T must be signed because steps can be negative");
    assert(M > usize{0});
    assert(N < M);
    // Reduce x into [0, M). A plain `x %= M` converts a negative x to usize
    // first and yields a wrong residue, so negative values use
    //     x mod M = M - 1 - ((-x - 1) mod M),
    // where -x - 1 >= 0 and cannot overflow even for the minimum value of T.
    const auto reduce = [M](T x) -> T {
        if (x >= 0)
            return static_cast<T>(x % M);
        return static_cast<T>(M - 1 - (-(x + 1)) % M);
    };
    A = reduce(A);
    B = reduce(B);
    if (N == 0)
        return {};
    if (N == 1)
        return {Arithmetic<T>{A, B, N, 0, 0}};
    assert(std::gcd(A, M) == 1);
    // probe = floor(sqrt(M)) + 1 (the float estimate is corrected exactly)
    const usize probe = [M] {
        usize r = static_cast<usize>(std::sqrt(static_cast<long double>(M)));
        while (r * r > M)
            r--;
        while (r * r <= M)
            r++;
        return r;
    }();
    // Inspect the first `head` terms. Their values are distinct (gcd(A, M) = 1
    // and head <= N < M), so after sorting, some adjacent pair differs by at
    // most (M - 1) / (head - 1).
    const usize head = std::min(probe, N);
    assert(head >= 2);
    std::vector<std::pair<T, usize>> app;  // (f(i), i) for i < head
    app.reserve(head);
    for (usize i = 0; i < head; i++)
        app.emplace_back(static_cast<T>((A * i + B) % M), i);
    std::ranges::sort(app);
    // Adjacent pair with the smallest value difference; the first one wins ties.
    // app is sorted ascending, so the differences are non-negative.
    usize idx = 1;
    for (usize i = 2; i < app.size(); i++)
        if (app[i].first - app[i - 1].first < app[idx].first - app[idx - 1].first)
            idx = i;
    const T minDiff = app[idx].first - app[idx - 1].first;
    const usize lo = std::min(app[idx].second, app[idx - 1].second);
    const usize hi = std::max(app[idx].second, app[idx - 1].second);
    const usize gap = hi - lo;
    assert(minDiff > 0);
    assert(gap > 0);
    // a = f(hi) - f(lo) is congruent to A * gap (mod M); it is the exact step
    // from f(j) to f(j + gap) as long as no wrap-around modulo M happens.
    const T a = app[idx].second == hi ? minDiff : static_cast<T>(-minDiff);
    const usize step = static_cast<usize>(minDiff);  // |a|
    std::vector<Arithmetic<T>> res;
    for (usize r = 0; r < gap; r++) {
        // Greedily cut the residue class r, r + gap, r + 2 * gap, ... into runs.
        for (usize j = r; j < N;) {
            const usize start = static_cast<usize>((A * j + B) % M);  // f(j)
            // Extra terms available before the index reaches N ...
            const usize byIndex = (N - 1 - j) / gap;
            // ... and before start + k * a leaves [0, M).
            const usize byValue = a > 0 ? (M - 1 - start) / step : start / step;
            const usize extra = std::min(byIndex, byValue);
            assert(j + extra * gap < N);
            assert(a > 0 ? start + extra * step < M : extra * step <= start);
            res.push_back(Arithmetic<T>{a, static_cast<T>(start), extra + 1, j, gap});
            j += gap * (extra + 1);
        }
    }
    return res;
}

} // namespace zawa
#line 2 "Src/Number/SqrtHeuristicArithmetic.hpp"

#line 2 "Src/Template/TypeAlias.hpp"

#include <cstdint>
#include <cstddef>

namespace zawa {

// Unsigned fixed-width integer aliases.
using u8 = ::std::uint8_t;
using u16 = ::std::uint16_t;
using u32 = ::std::uint32_t;
using u64 = ::std::uint64_t;

// Signed fixed-width integer aliases (i128 relies on the compiler extension __int128_t).
using i16 = ::std::int16_t;
using i32 = ::std::int32_t;
using i64 = ::std::int64_t;
using i128 = __int128_t;

// Unsigned type used for sizes and indices.
using usize = ::std::size_t;

} // namespace zawa
#line 4 "Src/Number/SqrtHeuristicArithmetic.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <concepts>
#include <ranges>
#include <tuple>
#include <vector>
#include <numeric>

namespace zawa {

// One arithmetic progression returned by SqrtHeuristicArithmetic.
// Term i (0 <= i < n) equals a * i + b and stands for the element of the
// original sequence at index l + gap * i, i.e. the progression covers
//     f(l), f(l + gap), f(l + 2 * gap), ..., f(l + (n - 1) * gap).
template <std::integral T>
struct Arithmetic {
    T a, b;        // common difference (may be negative) and first term
    usize n;       // number of terms
    usize l, gap;  // original index of the first term, and the index stride
    // Returns term i, i.e. a * i + b. Requires i < n.
    T operator()(usize i) const {
        assert(i < n);
        return a * i + b;
    }
    // Returns the index in the original sequence of term i. Requires i < n.
    usize originIndex(usize i) const {
        assert(i < n);
        return gap * i + l;
    }
};

// Partitions the sequence
//     f(i) = (A * i + B) mod M   (i = 0, 1, ..., N - 1, values in [0, M))
// into arithmetic progressions, aiming for O(sqrt(M)) of them. The index sets
// of the returned progressions partition {0, 1, ..., N - 1}, and every
// returned progression p satisfies, for all j < p.n,
//     p(j) == (A * p.originIndex(j) + B) mod M.
//
// require: T is a signed integer type (a progression's step may be negative)
// require: 0 <= N < M, and gcd(A mod M, M) = 1 whenever N >= 2
// require: (M - 1) * (M - 1) fits in the common type of T and usize, because
//          A * i + B is evaluated there without a wider type
// A and B may be negative; they are reduced into [0, M) first.
//
// reference: https://codeforces.com/blog/entry/141889
// reference: https://atcoder.jp/contests/abc429/submissions/70468550
// memo: if gcd(A, M) = g > 1, then (with 0 <= B < M)
//       (A x + B) mod M = g * (((A / g) x + floor(B / g)) mod (M / g)) + (B mod g),
//       and gcd(A / g, M / g) = 1.
// memo: if N >= M, f has period M, so split it into floor(N / M) copies of
//       f(0), ..., f(M - 1) (a permutation of {0, ..., M - 1} when gcd(A, M) = 1)
//       plus {B, A + B, ..., A * (N mod M - 1) + B} mod M.
template <std::integral T>
std::vector<Arithmetic<T>> SqrtHeuristicArithmetic(T A, T B, usize M, usize N) {
    static_assert(std::signed_integral<T>,
                  "SqrtHeuristicArithmetic: T must be signed because steps can be negative");
    assert(M > usize{0});
    assert(N < M);
    // Reduce x into [0, M). A plain `x %= M` converts a negative x to usize
    // first and yields a wrong residue, so negative values use
    //     x mod M = M - 1 - ((-x - 1) mod M),
    // where -x - 1 >= 0 and cannot overflow even for the minimum value of T.
    const auto reduce = [M](T x) -> T {
        if (x >= 0)
            return static_cast<T>(x % M);
        return static_cast<T>(M - 1 - (-(x + 1)) % M);
    };
    A = reduce(A);
    B = reduce(B);
    if (N == 0)
        return {};
    if (N == 1)
        return {Arithmetic<T>{A, B, N, 0, 0}};
    assert(std::gcd(A, M) == 1);
    // probe = floor(sqrt(M)) + 1 (the float estimate is corrected exactly)
    const usize probe = [M] {
        usize r = static_cast<usize>(std::sqrt(static_cast<long double>(M)));
        while (r * r > M)
            r--;
        while (r * r <= M)
            r++;
        return r;
    }();
    // Inspect the first `head` terms. Their values are distinct (gcd(A, M) = 1
    // and head <= N < M), so after sorting, some adjacent pair differs by at
    // most (M - 1) / (head - 1).
    const usize head = std::min(probe, N);
    assert(head >= 2);
    std::vector<std::pair<T, usize>> app;  // (f(i), i) for i < head
    app.reserve(head);
    for (usize i = 0; i < head; i++)
        app.emplace_back(static_cast<T>((A * i + B) % M), i);
    std::ranges::sort(app);
    // Adjacent pair with the smallest value difference; the first one wins ties.
    // app is sorted ascending, so the differences are non-negative.
    usize idx = 1;
    for (usize i = 2; i < app.size(); i++)
        if (app[i].first - app[i - 1].first < app[idx].first - app[idx - 1].first)
            idx = i;
    const T minDiff = app[idx].first - app[idx - 1].first;
    const usize lo = std::min(app[idx].second, app[idx - 1].second);
    const usize hi = std::max(app[idx].second, app[idx - 1].second);
    const usize gap = hi - lo;
    assert(minDiff > 0);
    assert(gap > 0);
    // a = f(hi) - f(lo) is congruent to A * gap (mod M); it is the exact step
    // from f(j) to f(j + gap) as long as no wrap-around modulo M happens.
    const T a = app[idx].second == hi ? minDiff : static_cast<T>(-minDiff);
    const usize step = static_cast<usize>(minDiff);  // |a|
    std::vector<Arithmetic<T>> res;
    for (usize r = 0; r < gap; r++) {
        // Greedily cut the residue class r, r + gap, r + 2 * gap, ... into runs.
        for (usize j = r; j < N;) {
            const usize start = static_cast<usize>((A * j + B) % M);  // f(j)
            // Extra terms available before the index reaches N ...
            const usize byIndex = (N - 1 - j) / gap;
            // ... and before start + k * a leaves [0, M).
            const usize byValue = a > 0 ? (M - 1 - start) / step : start / step;
            const usize extra = std::min(byIndex, byValue);
            assert(j + extra * gap < N);
            assert(a > 0 ? start + extra * step < M : extra * step <= start);
            res.push_back(Arithmetic<T>{a, static_cast<T>(start), extra + 1, j, gap});
            j += gap * (extra + 1);
        }
    }
    return res;
}

} // namespace zawa
Back to top page