题目描述

设数论函数$f$满足$n^2-3n+2=\sum_{d \mid n}f(d)$

求$\sum_{i=1}^{n}f(i)$

答案对$10^9+7$取模

题解

回忆杜教筛的一个表达形式(其中$\times$表示狄利克雷卷积,$F$、$H$分别表示$f$、$h$的前缀和)

$$
f \times g=h \Rightarrow g(1)F(n)=H(n)-\sum_{d=2}^{n}g(d)F(\lfloor \frac{n}{d} \rfloor)
$$

设$g(n)=1,h(n)=n^2-3n+2$,则有$f \times g = h$,即$f \times 1 = h$

所以

$$
\begin{align}
F(n)
&=\sum_{i=1}^{n}(i^2-3i+2)-\sum_{d=2}^{n}F(\lfloor \frac{n}{d} \rfloor)\\
&=\frac{n(n+1)(2n+1)}{6} - \frac{3n(1+n)}{2} + 2n - \sum_{d=2}^{n}F(\lfloor \frac{n}{d} \rfloor)
\end{align}
$$

因为$f \times 1 = h$,由莫比乌斯反演可得$f=h \times \mu$,即$f(n)=\sum_{d \mid n}h(d)\mu(\frac{n}{d})$(注意$h(1)=0$,所以$h$和$f$都不是积性函数,不能直接用线性筛求出$f$)

于是可以预处理较小范围内的$f$:对于每一个$\mu(i)$,枚举$j$,将$f(i \cdot j)$加上$\mu(i)h(j)$,再求前缀和得到$F$(代码中预处理到$10^6$)

对于更大的$n$,用上面的递推式配合整除分块和记忆化进行递归计算即可

代码

```cpp
#include "bits/stdc++.h"
using namespace std;

// Debug helper: prints the name of the enclosing function and the current line number.
#define DEBUG printf("Passing [%s] in LINE %d\n",__FUNCTION__,__LINE__)

typedef long long ll;
typedef pair<int, int> pii;

// N   : size of the precomputed table (main() fills indices 1..N).
// mod : modulus applied to every arithmetic result.
const int N = 1000000;
const int mod = 1000000007;

// inv6 and inv2 are assigned in main() as pw(6, mod - 2) and pw(2, mod - 2);
// n holds the value read by sol(). ans is not referenced by the code in this file.
ll ans;
ll inv6;
ll inv2;
ll n;

// mu  : Moebius function values written by the sieve in main().
// pri : primes found by the sieve, stored at indices 1..tot.
// vis : set to 1 for numbers the sieve marked as composite.
// f   : after main()'s precomputation, prefix sums of f over 1..N.
// S and dn are not referenced by the code in this file.
int mu[N + 10];
int pri[N + 10];
int tot;
int S[N + 10];
int dn[N + 10];
int vis[N + 10];
int f[N + 10];

// Modular exponentiation by repeated squaring: returns a^b reduced modulo `mod`.
// The base is multiplied as given; it is not reduced before its first use.
ll pw(ll a, ll b) {
    ll result = 1;
    while (b) {
        // Multiply in the current power of the base when the low exponent bit is set.
        if (b & 1) result = result * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return result;
}

// Evaluates h(n) = n^2 - 3n + 2 with the square and the linear term each reduced modulo `mod`.
// The final `%` is applied to a signed difference, so the returned residue can be negative.
int h(ll n) {
    const ll square = n * n % mod;
    const ll triple = 3 * n % mod;
    return (square - triple + 2) % mod;
}

// Prefix sum of h: H(n) = sum_{i=1..n} (i^2 - 3i + 2)
//                       = n(n+1)(2n+1)/6 - 3n(n+1)/2 + 2n   (mod `mod`).
// Divisions are done by multiplying with the global inverses inv6 and inv2.
// The result comes from signed remainders, so it can be negative.
int H(ll n) {
    const ll sumOfSquares = n * (n + 1) % mod * (2 * n + 1) % mod * inv6 % mod;
    const ll tripleSum = 3 * n % mod * (1 + n) % mod * inv2 % mod;
    const ll twiceN = 2 * n % mod;
    return ((sumOfSquares - tripleSum) % mod + twiceN) % mod;
}

// Memo table for F(n) with n > N.
map<ll, ll> val;

// Prefix sum F(n) = f(1) + ... + f(n) modulo `mod` (the residue may be negative).
// Values n <= N come from the precomputed table f[]; larger n use
//   F(n) = H(n) - sum_{d=2..n} F(floor(n / d)),
// where all d sharing the same quotient floor(n / d) are handled as one block.
ll F(ll n) {
    if (n <= N) return f[n];

    map<ll, ll>::iterator cached = val.find(n);
    if (cached != val.end()) return cached->second;

    ll total = H(n);
    ll lo = 2;
    while (lo <= n) {
        const ll quotient = n / lo;
        const ll hi = n / quotient;  // largest d with floor(n / d) == quotient
        total = (total - F(quotient) * (hi - lo + 1) % mod) % mod;
        lo = hi + 1;
    }

    val[n] = total;
    return total;
}

void sol() {
scanf("%lld", &n);
printf("%d\n", (F(n) % mod + mod) % mod);
}

int main() {
inv6 = pw(6, mod - 2), inv2 = pw(2, mod - 2);
mu[1] = 1;
for(int i = 2 ; i <= N ; ++ i) {
if(!vis[i]) pri[++ tot] = i, mu[i] = -1;
for(int j = 1 ; j <= tot && i * pri[j] <= N ; ++ j) {
vis[i * pri[j]] = 1;
if(i % pri[j] == 0) break;
mu[i * pri[j]] = -mu[i];
}
}
for(int i = 1 ; i <= N ; ++ i) {
for(int j = 1 ; i * j <= N ; ++ j) {
f[i * j] = ((ll) f[i * j] + mu[j] * h(i) % mod) % mod;
}
f[i] = ((ll) f[i - 1] + f[i]) % mod;
}
int T; scanf("%d", &T); while(T --) sol();
}