原题传送门
首先得推一推式子
令 $sum_i=\sum_{j=1}^{i}a_j$
$$ans=\frac{(v_1-\frac{sum_n}{m})^2+(v_2-\frac{sum_n}{m})^2+\cdots+(v_m-\frac{sum_n}{m})^2}{m}$$
其中 $v_i$ 为第 $i$ 段的和,所以 $\sum_{i=1}^{m}v_i=sum_n$
$$=\frac{v_1^2+v_2^2+\cdots+v_m^2}{m}+\frac{sum_n^2}{m^2}-2\frac{sum_n^2}{m^2}$$
$$=\frac{v_1^2+v_2^2+\cdots+v_m^2}{m}-\frac{sum_n^2}{m^2}$$
$$m^2ans=m(v_1^2+v_2^2+\cdots+v_m^2)-sum_n^2$$
现在任务是求 $\min(v_1^2+v_2^2+\cdots+v_m^2)$
上dp
$dp_{i,j}$ 表示前 $i$ 个分成 $j$ 段的最小平方和
$$dp_{i,j}=\min_{j-1\le k<i}\left(dp_{k,j-1}+(sum_i-sum_k)^2\right)$$
时间复杂度较大,用斜率优化
假设两个决策 $x,y\ (x<y)$,如果 $y$ 更优必须满足
$$dp_{x,j-1}+sum_x^2-2sum_isum_x>dp_{y,j-1}+sum_y^2-2sum_isum_y$$
化简得(注意 $x<y$ 时 $sum_x-sum_y<0$,两边同除以它不等号要变向)
$$\frac{dp_{x,j-1}-dp_{y,j-1}+sum_x^2-sum_y^2}{2(sum_x-sum_y)}<sum_i$$
维护一个斜率的下凸包就好了
其实还可以滚动掉一维,但是反正能过,就懒得滚动了
Code:
#include <bits/stdc++.h>
// Capacity of every array below (valid indices are 0 .. maxn - 1).
const int maxn = 3010;
// 64-bit integer used for all sums and DP values.
typedef long long LL;
using namespace std;
// sum[i]   : prefix sum of the first i input values (sum[0] == 0).
LL sum[maxn];
// dp[i][j] : minimum sum of squared group totals when the first i values
//            are split into j consecutive groups.
LL dp[maxn][maxn];
// n : number of input values, m : number of groups.
LL n, m;
// q : array-backed monotone queue of candidate split indices.
LL q[maxn];
// Reads the next decimal integer from stdin.
// Skips every character that is not a digit; if a '-' is seen while
// skipping, the result is negated. Consumes the digits and the single
// character that follows them.
// Returns the parsed value, or 0 if end of input is reached before any digit.
inline int read(){
    int value = 0, sign = 1;
    // Keep the character as int: getchar() reports end of input as EOF, which
    // a plain char cannot represent reliably, and isdigit() is only defined
    // for unsigned-char values and EOF.
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) {
        if (c == '-') sign = -1;
    }
    for (; c != EOF && isdigit(c); c = getchar()) {
        value = value * 10 + (c - '0');
    }
    return value * sign;
}
double slope(int opt, int x, int y){
return 1.0 * (dp[x][opt] - dp[y][opt] + sum[x] * sum[x] - sum[y] * sum[y]) / 2 / (sum[x] - sum[y]);
}
// Reads n, m and n values, fills dp[i][j] (minimum sum of squared group totals
// for the first i values split into j groups) with a monotone-queue slope
// optimisation, and prints m * dp[n][m] - sum[n]^2, i.e. m^2 times the variance.
int main(){
    n = read();
    m = read();

    // Prefix sums; a single group costs the square of the whole prefix.
    for (int i = 1; i <= n; ++i){
        sum[i] = sum[i - 1] + read();
        dp[i][1] = sum[i] * sum[i];
    }

    for (int groups = 2; groups <= m; ++groups){
        const int prev = groups - 1;   // DP layer the transitions come from
        int head = 0, tail = 0;
        q[0] = prev;                   // smallest split point usable with `prev` groups

        for (int i = groups; i <= n; ++i){
            const long long cur = sum[i];

            // Drop front candidates that the next candidate already beats for `cur`.
            while (head < tail && slope(prev, q[head], q[head + 1]) < cur) ++head;

            const long long gap = cur - sum[q[head]];
            dp[i][groups] = dp[q[head]][prev] + gap * gap;

            // Keep the slopes along the queue non-decreasing before appending i.
            while (head < tail && slope(prev, q[tail], i) < slope(prev, q[tail - 1], q[tail])) --tail;
            ++tail;
            q[tail] = i;
        }
    }

    printf("%lld\n", m * dp[n][m] - sum[n] * sum[n]);
    return 0;
}