bzoj2286:[Sdoi2011]消耗战(树形DP+虚树)

Author Avatar
Sakits 3月 31, 2018

题解

  虚树的模板题辣…
  对每个询问建出虚树,显然这题只需要留下祖先没有询问点的点,这样树形DP会简单很多。
  设f(i)f(i)为切断ii的子树的最小代价,costicost_iii到根节点路径上边的最小边权,则有:

f(i)=min(f(son),costi)f(i)=min(\sum f(son),cost_i)

  我也不知道为什么要写这破水题的题解…不然我的虚树学习笔记好像没有例题放啊(捂脸
  居然在第二页,应该是树剖比较快吧嘿嘿嘿

代码

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;
const int maxn=500010, inf=1e9;
struct poi{int too, dis, pre;}e[maxn], e2[maxn];
int n, m, Q, x, y, z, tot, tot2, tott;
int size[maxn], dep[maxn], son[maxn], dfn[maxn], las[maxn], last[maxn], last2[maxn], fa[maxn], top[maxn], cost[maxn], a[maxn], st[maxn];
ll f[maxn];
inline void read(int &k)
{
    int f=1; k=0; char c=getchar();
    while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar();
    while(c<='9' && c>='0') k=k*10+c-'0', c=getchar();
    k*=f;
}
inline void add1(int x, int y, int z){e[++tot]=(poi){y, z, last[x]}; last[x]=tot;}
inline void add2(int x, int y, int z){if(x==y) return; e2[++tot2]=(poi){y, z, last2[x]}; last2[x]=tot2;}
inline bool cmp(int a, int b){return dfn[a]<dfn[b];}
void dfs1(int x)
{
    size[x]=1; dep[x]=dep[fa[x]]+1;
    for(int i=last[x], too;i;i=e[i].pre)
    if((too=e[i].too)!=fa[x]) 
    {
        cost[too]=min(cost[x], e[i].dis);
        fa[too]=x; dfs1(too);
        size[x]+=size[too];
        if(size[too]>size[son[x]]) son[x]=too;
    }
}
void dfs2(int x, int tp)
{
    top[x]=tp; dfn[x]=++tott;
    if(son[x]) dfs2(son[x], tp);
    for(int i=last[x], too;i;i=e[i].pre)
    if((too=e[i].too)!=fa[x] && too!=son[x]) dfs2(too, too);
    las[x]=tott;
}
inline int lca(int x, int y)
{
    int f1=top[x], f2=top[y];
    while(f1!=f2)
    {
        if(dep[f1]<dep[f2]) swap(x, y), swap(f1, f2);
        x=fa[f1]; f1=top[x];
    }
    return dep[x]<dep[y]?x:y;
}
void dp(int x)
{
    ll sum=0;
    for(int i=last2[x], too;i;i=e2[i].pre) dp(too=e2[i].too), sum+=f[too];
    f[x]=last2[x]?min(sum, (ll)cost[x]):cost[x]; last2[x]=0;
    if(x==1) f[x]=sum;
}
int main()
{
    read(n);
    for(int i=1;i<n;i++) 
    read(x), read(y), read(z), add1(x, y, z), add1(y, x, z);
    cost[1]=inf; dfs1(1); dfs2(1, 1); read(Q);
    for(int i=1;i<=Q;i++)
    {
        read(m); for(int j=1;j<=m;j++) read(a[j]);
        sort(a+1, a+1+m, cmp);
        int cnt=1, top=1; st[1]=1; tot2=0;
        for(int j=2;j<=m;j++) 
        if(dfn[a[cnt]]>dfn[a[j]] || las[a[cnt]]<dfn[a[j]]) a[++cnt]=a[j];
        for(int j=1;j<=cnt;j++)
        {
            int fa=lca(a[j], st[top]);
            while(1)
            {
                if(dfn[st[top-1]]<=dfn[fa])
                {
                    add2(fa, st[top--], 0);
                    if(dfn[st[top-1]]<dfn[fa]) st[++top]=fa;
                    break;
                }
                add2(st[top-1], st[top], 0); top--;
            }
            st[++top]=a[j];
        }
        while(top>1) add2(st[top-1], st[top], 0), top--;
        dp(1); printf("%lld\n", f[1]);
    }
}