blogskol

思わず声に出して読みたくなるブログ

PGBATTLE2022 せんべい D

大会サイト : PG BATTLE 2022 - [第5回]企業・学校対抗プログラミングバトル

問題 : https://products.sint.co.jp/hubfs/resource/topsic/pgb2022/2_4.pdf

せんべいでコンテストに参加した人間以外に現在ジャッジが提供されているかは分からず

問題概要

 N 頂点の木が与えられる
各頂点  v には重み  a_v が書かれている
連結頂点集合  S に対して、 F(S)= (\sum_{v\in S} a_v)^K とする
各頂点  v に対して、 \sum_{S\ni v} F(S) を mod  998244353 で出力

※ 問題文から  F(S) の定義を少し変えています

制約

 N\leq 10^3
 K\leq 50
 a_v\leq 10^8

解法

 v=0 についての根付き木に対して木DPを用いた解法を考える
これが解けたなら最後に全方位木DPに直せば良い

 F(S) K 乗と言うのが非常に計算し辛い
一旦  F(S) の代わりに  G(S)=\prod_{v\in S}a_v だった場合について解いてみる
この場合、 dp_v v を根とした部分木についての解とすると、 C_v v の子の集合とすれば、
 dp_v = a_v \prod_{c\in C_v}(dp_c+1)
が成り立つから、解くことが出来る

ここで、 a^K=K![x^K]\exp(ax) であることに注目すると、
 F(S)=K![x^K]\exp((\sum_{v\in S}a_v)x)
と書くことが出来る

更に  \exp((\sum_{v\in S} a_v)x)=\prod_{v\in S}\exp(a_vx) より、 f_v=\exp(a_vx) と置くと、
 F(S)=K![x^K]\prod_{v\in S}f_v
が成り立つ

従って、 G(S) の場合と同様に
 dp_v = f_v \prod_{c\in C_v}(dp_c+1)
とすれば良い(最後に  K! をかけたものを出力する)

コード

FPSライブラリと全方位木ライブラリは記事の最後に付けます

#include <bits/stdc++.h>
using namespace std;
#define REP(i,n) for(int i=0;i<(n);i++)

#include <atcoder/modint>
#include <atcoder/convolution>
using namespace atcoder;
using mint=modint998244353;

#include <FPS>
#include <ReRooting>
using FPS=FormalPowerSeries<mint,55>;

int main(){
  
  int n,k;cin>>n>>k;
  vector<FPS> f(n);
  REP(i,n){
    int a;cin>>a;
    f[i]=FPS::exp(FPS(vector<mint>{0,a}));
  }

  mint fact_k=1;
  REP(i,k)fact_k*=i+1;

  auto merge=[&](const FPS&g,const FPS&h)->FPS{
    // dp_c の merge 
    return (1+g)*(1+h)-1;
  };
  auto score=[&](const FPS&g,const auto&e)->FPS{
    // merge された子を dp_v に合成する
    return f[e.to]*(1+g);
  };

  ReRooting<bool,FPS> tree(n,score,merge,FPS());
  REP(_,n-1){
    int a,b;cin>>a>>b;a--;b--;
    tree.add_edge(a,b,0);
  }
  vector<FPS> ans=tree.build(); 
  // ans[v]:vを根とした時の dp_v
  // ただし全方位木ライブラリの設計上、最後の score を取る操作はまだされてない
  REP(i,n){
    ans[i]=(1+ans[i])*f[i]; // 最後の score を取る操作
    cout<< (ans[i][k]*fact_k).val <<"\n";
  }
}

計算量

FPS は常に  K 次までしか保存しないので、掛け算は  O(K \log K)
掛け算の行われる回数は  O(N) 回なので、全体計算量は  O(NK\log K)
FPS を使わずに  (\sum_{v\in S} a_v)^i を各  i\in{0,1,\dots,K} について持っておく解法もあり、本番はこれで通した
この場合 FPS の掛け算に対応する操作が  O(K^2) かかるため全体計算量は  O(NK^2)

ライブラリのコード

FPS ライブラリ

#define REP_(i,n) for(int i=0;i<(n);i++)
template<typename T,int MX>
struct FormalPowerSeries:vector<T>{
  using FPS=FormalPowerSeries<T,MX>;
  using vector<T>::resize;
  using vector<T>::size;
  using vector<T>::at;
  FormalPowerSeries()=default;
  FormalPowerSeries(int n,T a={}){
    resize(min(MX,n),a);
  }
  FormalPowerSeries(const vector<T>&f){
    int n=min(MX,int(f.size()));
    resize(n);
    REP_(i,n)at(i)=f[i];
  }
  FormalPowerSeries(const vector<pair<T,int>>&sparse){
    int n=0;
    for(const auto&[co,deg]:sparse)n=max(n,deg+1);
    n=min(MX,n);
    assign(n,T(0));
    for(const auto&[co,deg]:sparse)if(deg<n)at(deg)=co;
  }
  FPS operator-()const{
    FPS g=*this;
    for(T&a:g)a=-a;
    return g;
  }
  
  FPS &operator+=(const FPS &g){
    if(size()<g.size())resize(g.size());
    REP_(i,g.size())at(i)+=g[i];
    return *this;
  }
  FPS operator+(const FPS &g)const{return FPS(*this)+=g;}
  FPS &operator+=(const T &a){
    if(!size())resize(1);
    at(0)+=a;
    return *this;
  }
  FPS operator+(const T& a)const{return FPS(*this)+=a;}
  friend FPS operator+(const T&a,const FPS&f){return f+a;}
  FPS &operator-=(const FPS &g){
    if(size()<g.size())resize(g.size());
    REP_(i,g.size())at(i)-=g[i];
    return *this;
  }
  FPS operator-(const FPS &g)const{return FPS(*this)-=g;}
  FPS &operator-=(const T &a){
    if(!size())resize(1);
    at(0)-=a;
    return *this;
  }
  FPS operator-(const T& a){return FPS(*this)-=a;}
  friend FPS operator-(const T&a,const FPS&f){return a+(-f);}
  
  FPS operator*(const FPS&g)const{
    return FPS(convolution(*this,g));
  }
  FPS &operator*=(const FPS&g){
    return (*this)=(*this)*g;
  }
  FPS &operator*=(const T &a){
    REP_(i,size())at(i)*=a;
    return *this;
  }
  FPS operator*(const T &a)const{
    FPS res(*this);
    for(T&p:res)p*=a;
    return res;
  }
  friend FPS operator*(const T&a,const FPS&f){return f*a;}
  FPS operator/(const FPS g)const{
    return *this*g.inv();
  }
  FPS &operator/=(const FPS&g){
    return (*this)=(*this)/g;
  }
  FPS &operator/=(const T &a){
    assert(a!=0);
    REP_(i,size())at(i)/=a;
    return *this;
  }
  FPS inv()const{
    assert(size() and at(0)!=0);
    FPS res(1,at(0).inv());
    for(int i=0;(1<<i)<MX;i++)res*=(2-res*(*this));
    return res;
  }
  void strict(int n){
    if(size()>n)resize(n);
  }
  FPS pow(int n)const{
    assert(n>=0);
    if(n==0)return FPS(1,1);
    if(n==1)return *this;
    if(at(0)==1)return exp(n*log(*this));
    FPS res(vector<T>{1}),now=*this;
    while(n){
      if(n&1)res*=now;
      now*=now;
      res>>=1;
    }
    return res;
  }
  FPS operator()(FPS f)const{
    if(size()==1)return FPS(1,at(0));
    if(size()==2)return FPS(at(0)+at(1)*f);
    int n=size()/2;
    FPS s=*this;
    s.strict(n);
    FPS t(size()-n);
    for(int i=n;i<size();i++)t[i-n]=at(i);
    return s(f)+f.pow(n)*t(f);
  }
  static FPS differential(const FPS f){
    if(f.size()<=1)return FPS(0);
    FPS res(f.size()-1);
    REP_(i,f.size()-1)res[i]=(i+1)*f[i+1];
    return res;
  }
  static FPS integral(const FPS f){
    FPS res(f.size()+1,0);
    REP_(i,f.size())res[i+1]=f[i]/(i+1);
    res.strict(MX);
    return res;
  }
  static FPS log(const FPS f){
    assert(f[0]==1);
    return integral(differential(f)/f);
  }
  static FPS exp(const FPS f){
    assert(f[0]==0);
    FPS res(1,1);
    for(int i=0;(1<<i)<MX;i++)res*=f+T(1)-log(res);
    return res;
  }
};
#undef REP_

全方位木ライブラリ

//key_t:辺固有の情報 頂点固有の情報はラムダの方でキャプチャさせる
//sum_t:計算結果
template<typename key_t,typename sum_t=key_t>
struct ReRooting{
  struct Edge {
    int from,to;
    key_t cost;
    sum_t dp1,dp2;
    //dp1:部分木の内容
    //dp2:e=g[idx][i]の時、g[idx][0,i)のdp1の累積が入る
  };
  using F=function<sum_t(sum_t,sum_t)>;
  using G=function<sum_t(sum_t,Edge)>;
 
  vector<vector<Edge>> g;
  const F merge;
  const G score;
  const sum_t id;
  vector<sum_t> dp1,dp2;
  ReRooting(int n,const G &score,const F &merge,const sum_t &id=sum_t{})
    :g(n),merge(merge),score(score),id(id),dp1(n,id),dp2(n,id){}
 
  void add_edge(int u,int v,const key_t &c){
    add_arc(u,v,c);
    add_arc(v,u,c);
  }
  void add_arc(int u,int v,const key_t &c){
    g[u].emplace_back(Edge{u,v,c,id,id});
  }
private:
  void dfs1(int idx,int pre){
    for(auto &e:g[idx])if(e.to!=pre){
      dfs1(e.to,idx);
      e.dp1=score(dp1[e.to],e);
      dp1[idx]=merge(dp1[idx],e.dp1);
    }
  }
 
  void dfs2(int idx,int pre,sum_t top){
    sum_t now=id;
    for(auto &e:g[idx]){
      e.dp2=now;
      if(e.to==pre)e.dp1=score(top,e);
      now=merge(now,e.dp1);
    }
    dp2[idx]=now;
    now=id;
    reverse(g[idx].begin(),g[idx].end());
    for(auto &e:g[idx]){
      if(e.to!=pre)
        dfs2(e.to,idx,merge(e.dp2,now));
      now=merge(now,e.dp1);
    }
  }
public:
  vector<sum_t> build(){
    dfs1(0,-1);
    dfs2(0,-1,id);
    return dp2;
  }
};
//使用例:https://atcoder.jp/contests/abc222/submissions/26517686