背景

给定 n, k (n\le 10^9, k \le 10^3) 以及 \lbrace a_1, a_2, \dots, a_k\rbrace, \lbrace f_0, f_1, \dots, f_{k-1} \rbrace,

且数列f的递推关系式满足:

f_n = \sum_{i=1}^k f_{n-i}a_i

求 f_n (答案对给定的模数取模).

一点声明

以上就是所谓的常系数齐次线性递推.

它的递推式必须有以下性质:

  1. 各项系数均为常数, 即与下标无关等.
  2. 递推是线性且齐次的, 即不会有 f_n=f_{n-1}+1 (含常数项, 非齐次) 或者 f_n=f_{n-1}^2 (非线性) 的形式

矩阵快速幂优化

该标题下讨论的是一般性矩阵,

也就是所有矩阵均适用.

给出做法前, 先撇开一切, 我们先来考虑这样的一个问题:

  • 给定n×n的矩阵M, 求M^k

  • n\leq 50, k\leq 10 ^ {50000}

普通的做法 (矩阵快速幂) 复杂度是 O(n^3\log k) 的.

但是注意到 n 十分小, 而 \log k 却很大, 利用 \text{Cayley-Hamilton} 定理可以优化到 O(n^4 + n^2\log k), 下面给出做法:

  • 定义矩阵M的特征多项式为:

f(x) = \det(xE-M) = x ^ n + c_1x^{n-1} + c_2 x ^{n-2} + \cdots + c_{n-1}x + c_n

这里的 x 不局限于复数域中的数, 也可以代入矩阵等其他对象.

再设 g(x)=x^{k}, 要求的答案就是 g(M)=M^k.

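由 \text{Cayley-Hamilton} 定理, f(M) = 0 (注意不能直接把 x=M 代入写成 \det(M-M)=0: 行列式里的 x 是标量, 这个"证明"是错误的, 定理需要另行证明). 因此对任意多项式 \lambda(x), g(M)-\lambda(M) \cdot f(M) 仍然等于答案.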

由此联想到多项式取模, 我们只要求出g\bmod f即可.

M^k = (g\bmod f) (M)

那么和多项式取模有点不同的是g(x)=x^k, \deg(g)比较大

故考虑多项式快速幂–只是每次多项式乘法都要在模f(x)的意义下进行!

暴力多项式乘法总复杂度是n^2\log k的, FFT优化是n\log n \cdot \log k.

但是, f(x) 怎么求呢? 难道要按行列式的定义枚举全部 n! 个排列吗?

可以代入 n+1 个不同的值 x_0, \dots, x_n, 用高斯消元分别求出 \det(x_iE-M) (每个 O(n^3)), 最后插值出多项式 f 即可!

总复杂度是 O(n^4) 的 (在模意义下插值时, 注意代入的值两两之差要存在乘法逆元).

线性递推

终于回归正题

考虑我们的特征多项式:

M 为左乘转移矩阵.

\begin{aligned} f(x)&=\det(x E-M)\\ &=\begin{vmatrix}{x} & {-1} & {0} & {0} & {\cdots} & {0} \\ {0} & {x} & {-1} & {0} & {\cdots} & {0} \\ {0} & {0} & {x} & {-1} & {\cdots} & {0} \\ {\vdots} & {\vdots} & {\vdots} & {\vdots} & {\ddots} & {\vdots} \\ {0} & {0} & {0} & {0} & {\cdots} & {-1} \\ {-a_{k}} & {-a_{k-1}} & {-a_{k-2}} & {-a_{k-3}} & {\cdots} & {x-a_{1}}\end{vmatrix} \end{aligned}

考虑对它进行高斯消元求出特征多项式:

\begin{aligned} &=x\cdot\left|\begin{array}{cccccc}{x} & {-1} & {0} & {0} & {\cdots} & {0} \\ {0} & {x} & {-1} & {0} & {\cdots} & {0} \\ {0} & {0} & {x} & {-1} & {\cdots} & {0} \\ {\vdots} & {\vdots} & {\vdots} & {\vdots} & {\ddots} & {\vdots} \\ {0} & {0} & {0} & {0} & {\cdots} & {-1} \\ {-\frac{a_{k}}{x}-a_{k-1}} & {-a_{k-2}} & {-a_{k-3}} & {-a_{k-4}} & {\cdots} & {x-a_{1}}\end{array}\right| \\ &=x^{2} \left|\begin{array}{cccccc}{x} & {-1} & {0} & {0} & {\cdots} & {0} \\ {0} & {x} & {-1} & {0} & {\cdots} & {0} \\ {0} & {0} & {x} & {-1} & {\cdots} & {0} \\ {\vdots} & {\vdots} & {\vdots} & {\vdots} & {\ddots} & {\vdots} \\ {0} & {0} & {0} & {0} & {\cdots} & {-1} \\ {-\frac{a_{k}}{x^{2}}-\frac{a_{k-1}}{x}-a_{k-2}} & {-a_{k-3}} & {-a_{k-4}} & {-a_{k-5}} & {\cdots} & {x-a_{1}}\end{array}\right| \\ &=…\\ &=x^{k-1} \cdot\left(x-\sum_{i=1}^{k} \frac{a_{i}}{x^{i-1}}\right)\\ &=x^{k}-\sum_{i=1}^{k} a_{i} x^{k-i}\\ \end{aligned}

发现由于它的特殊性质, 并不需要 O(k^4) 的复杂度就可以知道它的特征多项式了.

进一步地, 设初始的 k 维列向量为 A = (f_0, f_1, \dots, f_{k-1})^T.

那么我们要求: M ^n\cdot A, 可能有人说可以利用上面的矩阵乘法

但是仔细想想, 要算出 (x^n \bmod f)(M) 需要先求出 M^0, M^1, \dots, M^{k-1} 这些 k\times k 矩阵, 这样还是逃不开 O(k^4) 的复杂度…

为什么会这样, 由于这样操作得到的是整个M^n, 简单来说, 就是我们知道的太多了.

我们不需要知道整个k维向量, 其实只是要第一个数字, 即:\Big(M^n \cdot A\Big)[0]

那么开始转换:

\begin{aligned} &\left( M^{n} \cdot A \right) [0] \\ =&\Big(\left(x^{n} \bmod f(x)\right)(M) \cdot A\Big)[0] \end{aligned}

设 x^n \bmod f(x) = \sum_{i=0}^{k-1} c_i x^i, 则结果为:

\begin{aligned}&{ \sum_{i=0}^{k-1} c_{i} M^{i} \cdot A } \\ &=\sum_{i=0}^{k-1} c_{i} (M^{i}A)\end{aligned}

而 M 的作用就是把数列向后平移一位: 把 A 看作数列 f 的前 k 项, 则 (M^iA)[\lambda] = f_{\lambda+i}, 特别地 (M^iA)[0] = f_i.

由于我们只需要第 0 个分量, 答案就是 \sum_{i=0}^{k-1} c_i f_i, 只用到给定的 f_0, \dots, f_{k-1}. (若想得到整个向量 M^nA, 则需要预处理数列的前 2k-1 项, 才能得到所有 M^iA\ (i < k).)

那这样就做完了, 只要求出 c[\ ] 就好了, 而这个刚刚已经叙述过 (在模 f(x) 意义下做多项式快速幂).

复杂度为 O(k^2 \log n) (暴力乘法) 或 O(k\log k \cdot \log n) (FFT/NTT 优化).

Code

提示: 做线性递推时如果用 O(k^2) 的暴力卷积, 由于特征多项式首项系数为 1, 取模可以直接从高次到低次逐项消去, 同样是 O(k^2) (详见代码中的 Mul).

#pragma GCC optimize(2)
#include <bits/stdc++.h>
using namespace std;

const int N = 2050;
const int mod = 1e9 + 7;
# define pb push_back

// Filled by main(): n = index of the wanted term, k = recurrence order,
// x = scratch variable each input value is read into,
// a = coefficients a_1..a_k (stored 0-based), f = initial terms f_0..f_{k-1}.
int n, k, x;
std::vector<int> a, f;

// a = (a + b) mod `mod`. Both operands must already lie in [0, mod):
// a single conditional subtraction is all that is performed.
inline void upd(int &a, int b) {
    a += b;
    if (a >= mod) a -= mod;
}
// (a * b) mod `mod`, with the product computed in 64 bits to avoid overflow.
inline int mul(int a, int b) {
    return (long long) a * b % mod;
}

// Evaluates f_n for f_n = sum_{i=1..k} a_i f_{n-i} in O(k^2 log n):
// computes x^n mod (x^k - sum_i a_i x^{k-i}) = sum_i c_i x^i by binary
// exponentiation with schoolbook multiplication, then f_n = sum_i c_i f_i.
namespace sol {
    // Polynomial whose index-i entry is the coefficient of x^i.
    using Poly = std::vector<int>;

    int ret, k;  // ret: result of the last calc(); k: recurrence order used by Mul()
    Poly mo;     // mo[j - 1] = a_j, so that x^k == sum_{j=1..k} mo[j-1] x^{k-j}

    // Returns a * b reduced modulo the characteristic polynomial
    // x^k - sum_{j=1..k} mo[j-1] x^{k-j}; the result has at most k entries.
    // Reads the namespace state `k` and `mo` set by calc().
    inline Poly Mul(const Poly &a, const Poly &b) {
        if (a.empty() || b.empty()) return Poly();
        Poly c(a.size() + b.size() - 1, 0);
        for (std::size_t i = 0; i < a.size(); ++i) {
            if (!a[i]) continue;  // a zero term contributes nothing
            for (std::size_t j = 0; j < b.size(); ++j)
                upd(c[i + j], mul(a[i], b[j]));
        }
        // O(k^2) reduction, highest power first:
        // c_i x^i = c_i x^{i-k} x^k == sum_j c_i mo[j-1] x^{i-j}.
        for (int i = (int) c.size() - 1; i >= k; --i) {
            if (!c[i]) continue;
            for (int j = 1; j <= k; ++j)
                upd(c[i - j], mul(mo[j - 1], c[i]));
        }
        if ((int) c.size() > k) c.resize(k);
        return c;
    }

    // Returns x^n reduced modulo the characteristic polynomial (see Mul).
    // For n <= 0 the constant polynomial {1, 0} is returned unreduced; calc()
    // only reads its first k entries, so this is still correct for n == 0.
    Poly poly_pow(int n) {
        Poly res{1, 0};   // 1
        Poly base{0, 1};  // x
        while (n > 0) {
            if (n & 1) res = Mul(res, base);
            n >>= 1;
            if (n > 0) base = Mul(base, base);  // skip the squaring nobody would use
        }
        return res;
    }

    // Returns f_n, given g = f_0..f_{k-1} (initial terms) and a = a_1..a_k
    // (recurrence coefficients), all already reduced into [0, mod).
    // `order` is the recurrence order k. The result is also stored in sol::ret.
    int calc(int n, int order, Poly g, Poly a) {
        ret = 0;
        k = order;
        mo = std::move(a);
        const Poly s = poly_pow(n);
        // f_n = sum_i c_i f_i where x^n mod charpoly = sum_i c_i x^i.
        const int terms = std::min(k, (int) s.size());
        for (int i = 0; i < terms; ++i)
            upd(ret, mul(g[i], s[i]));
        return ret;
    }
}

int main() {
    // Input: n k, then a_1..a_k, then f_0..f_{k-1}.
    if (scanf("%d%d", &n, &k) != 2) return 1;
    // Map every value into [0, mod): the arithmetic in sol assumes reduced
    // operands, and a single "+ mod" is not enough once |x| >= mod.
    for (int i = 0; i < k; ++i) {
        if (scanf("%d", &x) != 1) return 1;
        a.pb((x % mod + mod) % mod);
    }
    for (int i = 0; i < k; ++i) {
        if (scanf("%d", &x) != 1) return 1;
        f.pb((x % mod + mod) % mod);
    }
    printf("%d\n", sol :: calc(n, k, f, a));
    return 0;
}

FFT (NTT) 优化版本: 多项式乘法用 NTT, 多项式取模用牛顿迭代求逆

#pragma GCC optimize(2)
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef std::vector<int> poly;  // coefficient of x^i stored at index i

// Array bound for the NTT tables (rev[], o[] in sol): must be at least the
// largest transform length used — recompute it if the input limits change.
const int N = 1 << 17 | 10;
const int mod = 998244353;

namespace {
  // (a * b) mod `mod`; operands in [0, mod), product taken in 64 bits.
  inline int mul(int a, int b) { return (LL) a * b % mod; }
  // (a + b) mod `mod`; operands must be in [0, mod).
  inline int add(int a, int b) { return (a += b) < mod ? a : a - mod; }
  // (a - b) mod `mod`; operands must be in [0, mod).
  inline int sub(int a, int b) { return (a -= b) < 0 ? a + mod : a; }
  // a^b mod `mod` by binary exponentiation; b is expected to be >= 0.
  inline int Pow(int a, int b) {
    int r = 1;
    while (b > 0) {
      if (b & 1) r = mul(r, a);
      a = mul(a, a);
      b >>= 1;
    }
    return r;
  }
  // Reads a decimal int from stdin, skipping non-digit characters; a '-' seen
  // while skipping makes the value negative. Leaves x == 0 on EOF instead of
  // spinning forever. getchar() is kept in an int so EOF is distinguishable and
  // isdigit() only ever sees unsigned-char values or EOF.
  inline void read(int &x) {
    x = 0;
    int sign = 1;
    int c = getchar();
    for (; !isdigit(c); c = getchar()) {
      if (c == EOF) return;
      if (c == '-') sign = -1;
    }
    for (; isdigit(c); c = getchar())
      x = x * 10 + c - '0';
    x *= sign;
  }
  // Like read(), but reduces the value modulo `mod` digit by digit, so
  // arbitrarily long (and negative) numbers end up in [0, mod).
  inline void Mread(int &x) {
    x = 0;
    bool negative = false;
    int c = getchar();
    for (; !isdigit(c); c = getchar()) {
      if (c == EOF) return;
      if (c == '-') negative = true;
    }
    for (; isdigit(c); c = getchar())
      x = add(mul(x, 10), c - '0');
    if (negative) x = sub(0, x);
  }
}

inline void print(poly &a) {
  for(int i = (int) a.size() - 1; ~i; --i)
    printf("%d ", a[i]);
  puts("");
}

int n, k, f[N], a[N];
namespace sol {
# define pi pair<poly, poly>
# define Rev(x) (reverse((x).begin(), (x).end()))

  int lim, rev[N], o[N >> 1];
  int fac[N], ifac[N], inv[N];

  inline void init(int n) {
    for(lim = 1; lim < n; lim <<= 1);
    for(int i = 0, s = lim >> 1; i < lim; ++i)
      rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? s : 0);
    o[0] = 1;
    int c = Pow(3, (mod - 1) / lim);
    for(int i = 1; i < lim >> 1; ++i)
      o[i] = mul(o[i - 1], c);
  }
  void Dft(poly &a, int n = lim) {
    static int t;
    for(int i = 0; i < n; ++i)
      if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int m = 1, tmp = n >> 1; m < n; m <<= 1, tmp >>= 1)
      for(int i = 0; i < n; i += m << 1) {
        int *l = &a[0] + i, *r = l + m;
        for(int k = 0; k < m; ++k, ++l, ++r) {
          t = mul(*r, o[tmp * k]);
          *r = sub(*l, t);
          *l = add(*l, t);
        }
      }
  }
  void Idft(poly &a, int n = lim) {
    reverse(a.begin() + 1, a.end());
    Dft(a);
    int Iv = Pow(n, mod - 2);
    for(int i = 0; i < (int) a.size(); ++i)
      a[i] = mul(a[i], Iv);
  }
  poly Mul(poly a, poly b) {
    int n = a.size(), m = b.size();
    init(n + m - 1);
    a.resize(lim), Dft(a);
    b.resize(lim), Dft(b);
    for(int i = 0; i < lim; ++i)
      a[i] = mul(a[i], b[i]);
    Idft(a);
    a.resize(n + m - 1);
    return a;
  }
  poly Inv(poly a, int O) {
    if(!a[0]) { cerr << "no Inv!\n"; }
    poly b(1, Pow(a[0], mod - 2)), c;
    for(int i = 2; (i >> 1) < O; i <<= 1) {
      init(i << 1);
      c = a, c.resize(i);
      b.resize(i << 1), Dft(b);
      c.resize(i << 1), Dft(c);
      for(int j = 0; j < (i << 1); ++j)
        b[j] = mul(b[j], sub(2, mul(b[j], c[j])));
      Idft(b);
      b.resize(i);
    }
    b.resize(O);
    return b;
  }
  pi Div(poly a, poly b) {
    pi r;
    int n = a.size(), m = b.size();
    if(n < m) return pi(poly(1, 0), a);
    Rev(a);
    // -------- be careful ! -------
    static poly c;
    {
      Rev(b);
      c = Inv(b, n - m + 1);
      Rev(b);
    }
    r.first = Mul(a, c);
    r.first.resize(n - m + 1);
    Rev(a), Rev(r.first);
    r.second = Mul(b, r.first);
    r.second.resize(m - 1);
    for(int i = 0; i < m - 1; ++i)
      r.second[i] = sub(a[i], r.second[i]);
    return r;
  }

  poly PO(int n, poly mo) {
    // print(mo);
    poly a(2), r(1, 1);
    a[0] = 0;
    a[1] = 1;
    while(n) {
      if(n & 1) {
        r = Mul(r, a);
        r = Div(r, mo).second;
      }
      a = Mul(a, a);
      a = Div(a, mo).second;
      n >>= 1;
    }
    return r;
  }
  void main() {
    read(n), read(k);
    for(int i = 1; i <= k; ++i) Mread(a[i]);
    for(int i = 0; i < k; ++i) Mread(f[i]);
    poly g(k + 1);  
    g[k] = 1;
    for(int i = 0; i < k; ++i)
      g[i] = sub(0, a[k - i]);
    g = PO(n, g);
    int ans = 0;
    for(int i = 0; i < (int) g.size(); ++i) 
      ans = add(ans, mul(f[i], g[i]));
    printf("%d\n", ans);
  }
}

int main() {
  // Everything (reading input, computing, printing) happens in sol::main();
  // falling off the end of main returns 0.
  sol::main();
}