Definite Random Walks – HackerRank Solution in Java, Python 3, Python 2, C, and C++: best and optimal solutions, all you need.
Solutions to the Hard Algorithms and Data Structures problems on HackerRank:
Here are all the solutions to the Hard, Advanced, and Expert Data Structures algorithm problems on HackerRank. Leave a comment for similar posts.
C++ Definite Random Walks HackerRank Solution
#pragma GCC diagnostic ignored "-Wunused-result"
#include <cstdio>
#include <vector>
#include <cassert>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <ctime>
#include <random>
const int MOD = 998244353;
// Product of two residues modulo MOD.
// Both operands must lie strictly between -MOD and MOD. The product is formed
// in 64 bits, and the result keeps the sign of the truncated remainder, so it
// may be negative.
int modMul(int a, int b) {
    assert(a > -MOD && a < MOD);
    assert(b > -MOD && b < MOD);
    const long long product = static_cast<long long>(a) * b;
    return static_cast<int>(product % MOD);
}
// Sum of two residues modulo MOD.
// Both operands must lie strictly between -MOD and MOD. The result is the
// truncated remainder of the sum, so it may be negative.
int modAdd(int a, int b) {
    assert(a > -MOD && a < MOD);
    assert(b > -MOD && b < MOD);
    const int sum = a + b;
    return sum % MOD;
}
// Raises `a` to the non-negative power `n` modulo MOD by binary exponentiation
// (square-and-multiply, least significant exponent bit first).
int modPow(int a, int n) {
    assert(a > -MOD && a < MOD);
    assert(n >= 0);
    int result = 1;
    for (int remaining = n; remaining > 0; remaining >>= 1) {
        if (remaining & 1) {
            result = modMul(result, a);
        }
        a = modMul(a, a);
    }
    return result;
}
// Multiplicative inverse of a non-zero residue, computed as a^(MOD-2)
// (Fermat's little theorem; relies on MOD being prime).
int modInv(int a) {
    assert(a != 0 && a > -MOD && a < MOD);
    const int exponent = MOD - 2;
    return modPow(a, exponent);
}
// Reference (brute-force) solver: simulates the walk turn by turn.
//   nTurns - number of turns to simulate
//   next   - next[i] is the position reached from i in one step
//   prob   - prob[k] is the weight (as a residue) of moving exactly k steps
// Starts from the uniform distribution 1/nPos and returns the distribution
// after nTurns turns. Cost is nTurns * nPos * nFaces.
std::vector<int> solveSlow(int nTurns, const std::vector<int> &next, const std::vector<int> &prob) {
    const int nPos = (int)next.size();
    const int nFaces = (int)prob.size();
    std::vector<int> cur(nPos, modInv(nPos));
    std::vector<int> prev(nPos, 0);
    for (int turn = 0; turn < nTurns; turn++) {
        // The previous result becomes the input; the output starts from zero.
        cur.swap(prev);
        std::fill(cur.begin(), cur.end(), 0);
        for (int from = 0; from < nPos; from++) {
            const int weight = prev[from];
            int at = from;
            for (int face = 0; face < nFaces; face++) {
                cur[at] = modAdd(cur[at], modMul(prob[face], weight));
                at = next[at];
            }
        }
    }
    return cur;
}
// For every node of the functional graph `next`, computes
//   period[v]    - length of the cycle that v eventually reaches
//   prePeriod[v] - number of steps from v to the first node on that cycle
//                  (0 for nodes that are on a cycle themselves)
// Both output vectors are resized to next.size().
void getPeriodAndPrePeriod(const std::vector<int> &next, std::vector<int> &period, std::vector<int> &prePeriod) {
    const int UNSEEN = -1;  // never reached
    const int DONE = -2;    // period / prePeriod already final
    const int count = (int)next.size();
    period.assign(count, -1);
    prePeriod.assign(count, -1);
    // While a node is on the walk in progress, mark[] holds its index on that walk.
    std::vector<int> mark(count, UNSEEN);
    for (int root = 0; root < count; root++) {
        if (mark[root] != UNSEEN) {
            continue;
        }
        // Follow `next` until we hit something already seen.
        std::vector<int> trail;
        int node = root;
        while (mark[node] == UNSEEN) {
            mark[node] = (int)trail.size();
            trail.push_back(node);
            node = next[node];
        }
        int tailBase = 0;
        if (mark[node] == DONE) {
            // Ran into a node finished earlier: the whole trail is a tail in front of it.
            tailBase = prePeriod[node];
        } else {
            // Ran into the current trail: everything from `node` onwards is a new cycle.
            assert(mark[node] >= 0);
            const int cycleLength = (int)trail.size() - mark[node];
            period[node] = cycleLength;
            prePeriod[node] = 0;
            for (;;) {
                const int member = trail.back();
                trail.pop_back();
                mark[member] = DONE;
                if (member == node) {
                    break;
                }
                period[member] = cycleLength;
                prePeriod[member] = 0;
            }
        }
        // Whatever is left on the trail leads into `node`; distances grow as we unwind.
        const int inherited = period[node];
        for (int distance = 1; !trail.empty(); distance++) {
            const int member = trail.back();
            trail.pop_back();
            mark[member] = DONE;
            period[member] = inherited;
            prePeriod[member] = tailBase + distance;
        }
    }
}
// True when n is a positive power of two (exactly one bit set).
inline bool isPow2(int n) {
    if (n <= 0) {
        return false;
    }
    return (n & (n - 1)) == 0;
}
const int MAX_LEVEL = 19;
int sinCos[1 + MAX_LEVEL][1 << (MAX_LEVEL - 1)];
// https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm#Pseudocode
// Recursive radix-2 transform over residues modulo MOD.
//   src, srcPos, srcStride - the n inputs are src[srcPos + k * srcStride]
//   n, logN                - transform length and its base-2 logarithm
//   dst, dstPos            - output is written to dst[dstPos .. dstPos + n);
//                            dstPos is advanced by n
// Twiddle factors for a length-2^logN transform are read from sinCos[logN].
void fft(const std::vector<int> &src, const int srcPos, const int n, const int logN, const int srcStride, std::vector<int> &dst, int &dstPos) {
    if (n == 1) {
        // A single sample is its own transform.
        dst[dstPos] = src[srcPos];
        dstPos++;
        return;
    }
    const int half = n >> 1;
    const int blockStart = dstPos;
    // Even-indexed samples fill the first half of the block, odd-indexed the second.
    fft(src, srcPos, half, logN - 1, srcStride << 1, dst, dstPos);
    fft(src, srcPos + srcStride, half, logN - 1, srcStride << 1, dst, dstPos);
    assert(dstPos == blockStart + n);
    for (int k = 0; k < half; k++) {
        assert(1 <= logN && logN <= MAX_LEVEL);
        assert(0 <= k && k < (1 << (MAX_LEVEL - 1)));
        const int lo = blockStart + k;
        const int hi = lo + half;
        // Butterfly: combine the even part with the twiddled odd part.
        const long long twisted = 1LL * sinCos[logN][k] * dst[hi];
        const int even = dst[lo];
        dst[lo] = int((even + twisted) % MOD);
        dst[hi] = int((even - twisted) % MOD);
    }
}
// Transforms `src` (whose size must be a power of two) into `dst`.
// `dst` must be a different vector; it is resized to src.size().
// The twiddle table sinCos is filled lazily on the first call.
void fft(const std::vector<int> &src, std::vector<int> &dst) {
    const bool tableReady = sinCos[MAX_LEVEL][0] != 0;
    if (!tableReady) {
        // Top level: successive powers of the generator MOD - 121229115
        // (presumably a primitive 2^MAX_LEVEL-th root of unity modulo MOD - not verified here).
        const int topCount = 1 << (MAX_LEVEL - 1);
        sinCos[MAX_LEVEL][0] = 1;
        for (int k = 1; k < topCount; k++) {
            sinCos[MAX_LEVEL][k] = modMul(sinCos[MAX_LEVEL][k - 1], MOD - 121229115);
        }
        // Each lower level keeps every second entry of the level above it.
        for (int level = MAX_LEVEL - 1; level >= 1; level--) {
            const int count = 1 << (level - 1);
            for (int k = 0; k < count; k++) {
                sinCos[level][k] = sinCos[level + 1][k << 1];
            }
        }
    }
    assert(isPow2((int)src.size()));
    assert(&src != &dst);
    dst.resize(src.size(), 0);
    const int n = (int)src.size();
    const int logN = (int)std::round(std::log2((double)src.size()));
    int written = 0;
    fft(src, 0, n, logN, 1, dst, written);
}
// Inverse transform: runs the forward transform, divides every value by the
// length, then reverses elements 1..n-1 (element 0 stays in place).
void inverseFft(const std::vector<int> &src, std::vector<int> &dst) {
    fft(src, dst);
    const int scale = modInv((int)dst.size());
    for (int &value : dst) {
        value = modMul(value, scale);
    }
    std::reverse(dst.begin() + 1, dst.end());
}
// Cyclic convolution of two equally sized vectors via the transform:
// transform both, multiply point by point, transform back into `res`.
// The caller pads the inputs so that the cyclic result equals the plain product.
void mulPoly(const std::vector<int> &a, const std::vector<int> &b, std::vector<int> &res) {
    assert(a.size() == b.size());
    std::vector<int> spectrumA;
    std::vector<int> spectrumB;
    fft(a, spectrumA);
    fft(b, spectrumB);
    const int n = (int)a.size();
    std::vector<int> pointwise(a.size());
    for (int k = 0; k < n; k++) {
        pointwise[k] = modMul(spectrumA[k], spectrumB[k]);
    }
    inverseFft(pointwise, res);
}
// Polynomial product of `left` and `right` with coefficients modulo MOD.
// Returns left.size() + right.size() - 1 coefficients, or an empty vector when
// either operand is empty (the zero polynomial). Small products use the
// schoolbook loop; larger ones go through the transform-based mulPoly.
std::vector<int> mul(const std::vector<int> &left, const std::vector<int> &right) {
    // Guard: size() + size() - 1 would wrap around for empty operands.
    if (left.empty() || right.empty()) {
        return std::vector<int>();
    }
    const size_t outSize = left.size() + right.size() - 1;
    // naive nPos=5k -> 2.42s
    // fft nPos=5k -> 0.21s
    // naive nPos=1k -> 0.1s
    // fft nPos=1k -> 0.11s
    if (left.size() * right.size() > 1000 * 1000) {
        // Smallest power of two that holds the full product without wrap-around.
        size_t fftSize = 1;
        while (fftSize < outSize) {
            fftSize <<= 1;
        }
        std::vector<int> paddedLeft = left;
        std::vector<int> paddedRight = right;
        paddedLeft.resize(fftSize);
        paddedRight.resize(fftSize);
        std::vector<int> res(fftSize);
        mulPoly(paddedLeft, paddedRight, res);
        // Drop the zero padding so both code paths return the same length.
        res.resize(outSize);
        return res;
    }
    std::vector<int> res(outSize, 0);
    for (int iL = 0; iL < (int)left.size(); iL++) {
        for (int iR = 0; iR < (int)right.size(); iR++) {
            res[iL + iR] = modAdd(res[iL + iR], modMul(left[iL], right[iR]));
        }
    }
    return res;
}
// Multiplies two distributions laid out on a "tail + cycle" shape.
// Indices [0, prePeriod) are the tail and [prePeriod, prePeriod + period) the
// cycle. Every product coefficient beyond the window is folded back onto the
// cycle position it lands on. Both inputs and the result have
// prePeriod + period entries.
std::vector<int> wrappedMul(const std::vector<int> &left, const std::vector<int> &right, int prePeriod, int period) {
    const int window = prePeriod + period;
    assert((int)left.size() == window);
    assert((int)right.size() == window);
    std::vector<int> product = mul(left, right);
    const int total = (int)product.size();
    for (int src = window; src < total; src++) {
        const int folded = prePeriod + (src - prePeriod) % period;
        product[folded] = modAdd(product[folded], product[src]);
    }
    assert(total >= window);
    product.resize(window);
    return product;
}
// nTurns-fold wrappedMul power of `prob` (nTurns >= 1), by recursive
// square-and-multiply: odd exponents peel off one factor, even exponents
// square the base and halve the exponent.
std::vector<int> wrappedPow(const std::vector<int> &prob, int prePeriod, int period, int nTurns) {
    assert(nTurns >= 1);
    if (nTurns == 1) {
        return prob;
    }
    if (nTurns % 2 == 1) {
        const std::vector<int> rest = wrappedPow(prob, prePeriod, period, nTurns - 1);
        return wrappedMul(prob, rest, prePeriod, period);
    }
    const std::vector<int> squared = wrappedMul(prob, prob, prePeriod, period);
    return wrappedPow(squared, prePeriod, period, nTurns / 2);
}
// Relabels the nodes of the functional graph `next` so that following `next`
// mostly moves to a neighbouring label.
//   next - in: successor of every node; out: the same graph under the new labels
//   perm - must already have next.size() entries; out: perm[old] = new label
// Labels are handed out from nPos-1 downwards: each walk along `next` is
// recorded, then unwound so the last node reached gets the largest free label.
void topologicalSort(std::vector<int> &next, std::vector<int> &perm) {
    const int UNSEEN = -1;
    const int PENDING = -2;  // on the current walk, label not assigned yet
    const int nPos = (int)next.size();
    assert(nPos == (int)perm.size());
    perm.assign(nPos, UNSEEN);
    int label = nPos - 1;
    std::vector<int> trail;
    for (int root = 0; root < nPos; root++) {
        if (perm[root] != UNSEEN) {
            continue;
        }
        for (int v = root; perm[v] == UNSEEN; v = next[v]) {
            perm[v] = PENDING;
            trail.push_back(v);
        }
        for (auto it = trail.rbegin(); it != trail.rend(); ++it) {
            assert(perm[*it] == PENDING);
            assert(label >= 0);
            perm[*it] = label;
            label--;
        }
        trail.clear();
    }
    assert(label == -1);
    // Rewrite the graph in terms of the new labels.
    std::vector<int> relabeled(nPos);
    for (int v = 0; v < nPos; v++) {
        relabeled[perm[v]] = perm[next[v]];
    }
    next.swap(relabeled);
}
// Undoes the relabelling produced by topologicalSort: after the call,
// ans[old] holds the value that was stored under perm[old].
void topologicalUnsort(std::vector<int> &ans, std::vector<int> &perm) {
    const int nPos = (int)ans.size();
    assert(nPos == (int)perm.size());
    std::vector<int> restored;
    restored.reserve(nPos);
    for (int old = 0; old < nPos; old++) {
        restored.push_back(ans[perm[old]]);
    }
    ans.swap(restored);
}
std::vector<int> solveFastSmall(int nTurns, std::vector<int> next, const std::vector<int> &prob) {
int nPos = (int)next.size();
std::vector<int> perm(nPos);
topologicalSort(next, perm); // to make it processor cache friendly
int nFaces = (int)prob.size();
std::vector<int> period;
std::vector<int> prePeriod;
getPeriodAndPrePeriod(next, period, prePeriod);
std::vector<int> periodToMaxPrePeriod(1 + nPos, -1);
for (int i = 0; i < nPos; i++) {
int p = period[i];
assert(p > 0);
int pp = prePeriod[i];
assert(pp >= 0);
if (pp > periodToMaxPrePeriod[p]) {
periodToMaxPrePeriod[p] = pp;
}
}
assert(periodToMaxPrePeriod[0] == -1);
std::vector<long long> ansL(nPos, 0);
for (int p = 1; p <= nPos; p++) {
int pp = periodToMaxPrePeriod[p];
if (pp == -1) {
continue;
}
// 350 different periods -> nPos = (1 + 350) * 350 / 2 = 61 425 -> in this problem less than 350 different periods
std::vector<int> wrappedProb(pp + p, 0);
for (int i = 0; i < nFaces; i++) {
if (i < pp) {
wrappedProb[i] = prob[i];
} else {
int dst = (i - pp) % p + pp;
wrappedProb[dst] = modAdd(wrappedProb[dst], prob[i]);
}
}
wrappedProb = wrappedPow(wrappedProb, pp, p, nTurns);
// great loop or great chain - let's fill in one swoop.
std::vector<int> ones(pp + p, 1);
std::vector<int> wrappedProbCum = wrappedMul(wrappedProb, ones, pp, p);
std::vector<bool> wasStart(nPos, false);
int start = -1;
for (int i = 0; i < nPos; i++) {
if (period[i] == p && prePeriod[i] == pp) {
start = i;
break;
}
}
assert(start >= 0);
for (int i = 0; i < (int)wrappedProbCum.size(); i++) {
assert(!wasStart[start]);
wasStart[start] = true;
ansL[start] += wrappedProbCum[i];
start = next[start];
}
for (int i = 0; i < nPos; i++) {
if (period[i] == p && !wasStart[i]) {
int pos = i;
for (int j = 0; j < (int)wrappedProb.size(); j++) {
ansL[pos] += wrappedProb[j];
pos = next[pos];
}
}
}
}
std::vector<int> ansI(nPos);
for (int i = 0; i < nPos; i++) {
ansI[i] = int(ansL[i] % MOD);
}
int startValue = modInv(nPos);
for (int i = 0; i < nPos; i++) {
ansI[i] = modMul(ansI[i], startValue);
}
topologicalUnsort(ansI, perm);
return ansI;
}
// Stress test: one cycle through all 60000 positions, 100000 equally likely
// faces, 1000 turns. Prints a completion message and terminates the process.
void maxTest() {
    const int nPos = 60000;
    const int nFaces = 100000;
    const int nTurns = 1000;
    std::vector<int> next(nPos);
    for (int v = 0; v < nPos; v++) {
        next[v] = (v + 1) % nPos;
    }
    const std::vector<int> prob(nFaces, modInv(nFaces));
    const std::vector<int> res = solveFastSmall(nTurns, next, prob);
    assert(res[0] != -1);
    printf("maxTest() done\n");
    // Naive poly multiplication: nPos=10k -> 10s. So 60k ~ 360s.
    // FFT nPos=60k -> 23s
    // multiplication in the end -> 14.67s
    // addition without % -> 6.7s (FFT just 0.32s, everything else in just adding result)
    // with max path elimination -> 0.32s (another worst case);
    std::exit(0);
}
// Stress test: two cycles of 30000 positions each with shuffled labels,
// 100000 equally likely faces, 1000 turns. Prints the elapsed CPU time and
// terminates the process.
void maxTest2() {
    const clock_t startedAt = clock();
    const int nPos = 60000;
    const int nFaces = 100000;
    const int nTurns = 1000;
    const int half = nPos / 2;
    std::vector<int> next(nPos);
    for (int v = 0; v < nPos; v++) {
        // Step forward inside the half that v belongs to.
        next[v] = (v + 1) % half + v / half * half;
    }
    {
        // Hide the structure behind a random relabelling (default-seeded, so repeatable).
        std::vector<int> shuffled(nPos);
        for (int v = 0; v < nPos; v++) {
            shuffled[v] = v;
        }
        std::mt19937 gen;
        std::shuffle(shuffled.begin(), shuffled.end(), gen);
        std::vector<int> relabeled(nPos);
        for (int v = 0; v < nPos; v++) {
            relabeled[shuffled[v]] = shuffled[next[v]];
        }
        next = relabeled;
    }
    const std::vector<int> prob(nFaces, modInv(nFaces));
    const std::vector<int> res = solveFastSmall(nTurns, next, prob);
    assert(res[0] != -1);
    printf("maxTest2() done %.2f\n", 1.0 * (clock() - startedAt) / CLOCKS_PER_SEC);
    // i7: 1.66s
    // HR: 1.93s Time Limit is 4s. This test is not the worst because some get tle.
    // shuffled i7: 8.97s
    // shuffled + topSorted i7: 1.67
    std::exit(0);
}
// Entry point. Reads from stdin:
//   nPos nFaces nTurns
//   nPos successors (1-based)
//   nFaces step probabilities as residues modulo MOD
// and prints one non-negative residue per position.
// Returns 1 with a message on stderr when the input is malformed, instead of
// running the solver on uninitialised sizes or out-of-range indices.
int main() {
    // maxTest();
    // maxTest2();
    int nPos = 0, nFaces = 0, nTurns = 0;
    if (scanf("%d %d %d", &nPos, &nFaces, &nTurns) != 3 || nPos <= 0 || nFaces <= 0 || nTurns <= 0) {
        fprintf(stderr, "invalid input: expected three positive integers\n");
        return 1;
    }
    std::vector<int> next(nPos);
    for (int i = 0; i < nPos; i++) {
        if (scanf("%d", &next[i]) != 1 || next[i] < 1 || next[i] > nPos) {
            fprintf(stderr, "invalid input: bad successor for position %d\n", i + 1);
            return 1;
        }
        next[i]--;  // switch to 0-based positions
    }
    std::vector<int> prob(nFaces);
    for (int i = 0; i < nFaces; i++) {
        if (scanf("%d", &prob[i]) != 1 || prob[i] < 0 || prob[i] >= MOD) {
            fprintf(stderr, "invalid input: bad probability %d\n", i + 1);
            return 1;
        }
    }
    // std::vector<int> res = solveSlow(nTurns, next, prob);
    std::vector<int> res = solveFastSmall(nTurns, next, prob);
    for (int x : res) {
        // The solver may return negative representatives; normalise to [0, MOD).
        printf("%d\n", (x + MOD) % MOD);
    }
    return 0;
}
Java Definite Random Walks HackerRank Solution
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.HashMap;
import java.util.InputMismatchException;
import java.util.Map;
public class G2 {
InputStream is;
PrintWriter out;
String INPUT = "";
// String INPUT = "7 3 1\r\n" +
// "4 2 7 2 4 3 5 \r\n" +
// "0 1 0 ";
// String INPUT = "4 2 1\r\n" +
// "1 4 2 3 \r\n" +
// "748683265 249561089";
// 1 2->4->3->2
// String INPUT = "4 2 2\r\n" +
// "2 3 1 3 \r\n" +
// "0 1 ";
// 1/4 1/4 1/2
// 1/2 1/4 1/4
// String INPUT = "3 1 2\r\n" +
// "1 3 1 \r\n" +
// "1 ";
// String INPUT = "4 5 1\r\n" +
// "2 3 2 4\r\n" +
// "332748118 332748118 332748118 0 0";
/**
 * Reads one instance and prints one value per node.
 * Input order: n (nodes), m (number of step probabilities), K (turns), the n
 * successors f (1-based, converted to 0-based), then the m probabilities ps.
 * Each printed value is rets[i] % mod multiplied by the inverse of n modulo mod.
 *
 * Phases, as marked by the commented-out trace calls:
 *   0 - K-fold convolution power of ps via make()
 *   1 - contributions between nodes of the same tree hanging off a cycle node
 *   2 - per-node depth information
 *   3 - contributions that end on cycle nodes
 *   4 - output
 */
void solve()
{
int n = ni(), m = ni(), K = ni();
int[] f = na(n);
for(int i = 0;i < n;i++)f[i]--;
long[] ps = new long[m];
for(int i = 0;i < m;i++)ps[i] = nl();
int mod = 998244353;
// tr("phase 0");
// Power of the step distribution with tail n+1 and period 1 (see make()).
long[] made = make(ps, K, n+1, 1);
int H = (int)Math.sqrt(n)*8; // naive height limit
int B = (int)Math.sqrt(n)*8; // cycle split period
SplitResult sres = split(f);
// tclus[v] = the cycle node that v's chain of successors runs into (cycle nodes map to themselves).
// ord is walked backwards so that f[cur] is already labelled when cur is reached.
int[] tclus = new int[n];
Arrays.fill(tclus, -1);
for(int i = n-1;i >= 0;i--){
int cur = sres.ord[i];
if(sres.incycle[cur]){
tclus[cur] = cur;
}else{
tclus[cur] = tclus[f[cur]];
}
}
// tr("phase 1");
long[] rets = new long[n];
// maps[c] lists, in increasing order, every node whose tclus is c.
int[][] maps = makeBuckets(tclus, n);
for(int i = 0;i < n;i++){
if(maps[i].length > 0){
int[] map = maps[i];
// Parent pointers expressed as positions inside map; -1 for the cycle node (the root).
int[] lpar = new int[map.length];
int p = 0;
for(int x : maps[i]){
if(sres.incycle[x]){
lpar[p++] = -1;
}else{
lpar[p++] = Arrays.binarySearch(map, f[x]);
}
}
long[] res = solve(parentToG(lpar), lpar, made, H, Arrays.binarySearch(map, i));
// Only non-cycle nodes take the tree result; cycle nodes are filled in phase 3.
for(int j = 0;j < res.length;j++){
if(!sres.incycle[map[j]]){
rets[map[j]] += res[j];
}
}
}
}
// tr("phase 2");
// maxdep[v] = length of the longest chain of non-cycle nodes that ends in v.
int[] maxdep = new int[n];
for(int i = 0;i < n;i++){
int cur = sres.ord[i];
if(!sres.incycle[cur]){
maxdep[f[cur]] = Math.max(maxdep[f[cur]], maxdep[cur]+1);
}
}
// tdep[v] = number of steps from v to its cycle node (0 on the cycle).
int[] tdep = new int[n];
for(int i = n-1;i >= 0;i--){
int cur = sres.ord[i];
if(!sres.incycle[cur]){
tdep[cur] = tdep[f[cur]]+1;
}
}
// tr("phase 3");
// for(int j = made.length-1-1;j >= 0;j--){
// made[j] += made[j+1];
// if(made[j] >= mod)made[j] -= mod;
// }
boolean[] ved = new boolean[n];
int[] cycle = new int[n];
// Cache of folded distributions keyed by (tail, cycle length).
Map<Long, long[]> cache = new HashMap<>();
for(int i = 0;i < n;i++){
if(sres.incycle[i] && !ved[i]){
// Collect the cycle through i into cycle[0..p) and the deepest tree attached to it.
int p = 0;
ved[i] = true;
cycle[p++] = i;
int lmaxdep = maxdep[i];
for(int j = f[i];!ved[j];j = f[j]){
ved[j] = true;
cycle[p++] = j;
lmaxdep = Math.max(lmaxdep, maxdep[j]);
}
int tail = lmaxdep+p+1;
int fp = p;
// di: make() result for this tail/period, then summed with stride fp so that
// di[j] covers every index j, j+fp, j+2*fp, ...
long[] di = cache.computeIfAbsent((long)tail<<32|fp, (z) -> {
long[] res = make(ps, K, tail, fp);
for(int j = res.length-fp-1;j >= 0;j--){
res[j] += res[j+fp];
if(res[j] >= mod)res[j] -= mod;
}
return res;
});
// 0.1 0.2 0.3 0.4
// 0.4 0.6
// 0.1 0.6 0.3
// 0.1 0.2 0.3 0.4
// 0.4 0.6
// 0.6 0.3
// 0.3 0.4
if(p <= B){
// Short cycle: add directly. Every node v hanging off cycle[j] contributes
// di[tdep[v] .. tdep[v]+p) to cycle[j], cycle[j+1], ... (wrapping around).
for(int j = 0;j < p;j++){
for(int v : maps[cycle[j]]){
for(int k = tdep[v], l = j;k < tdep[v]+p;k++,l++){
if(l == p)l = 0;
rets[cycle[l]] += di[k];
}
}
}
}else{
// Long cycle: handled in blocks of B cycle positions. Sources outside the block
// are gathered into a histogram and convolved with di; the transform of di is
// carried from one convoluteSimply call to the next through inputed/saved.
inputed = null;
for(int b = 0;b < p;b+=B){
long[] ents = new long[tail+1];
for(int j = 0;j < b;j++){
for(int v : maps[cycle[j]]){
ents[tail-(b-j)-tdep[v]]++;
}
}
for(int j = b+B;j < p;j++){
for(int v : maps[cycle[j]]){
ents[tail-b-(p-j)-tdep[v]]++;
}
}
long[] ced = convoluteSimply(ents, di, mod, 3);
inputed = saved;
for(int k = b;k < p && k < b+B;k++){
rets[cycle[k]] += ced[tail+k-b];
}
}
inputed = null;
// remainder
// Sources and targets inside the same block are added directly.
for(int j = 0;j < p;j++){
for(int v : maps[cycle[j]]){
for(int k = tdep[v], l = j;k < tdep[v]+p && l < p && l < j/B*B+B;k++,l++){
rets[cycle[l]] += di[k];
}
for(int k = tdep[v]+p-j+j/B*B, l = j/B*B;k < tdep[v]+p && l < j;k++,l++){
rets[cycle[l]] += di[k];
}
}
}
}
}
}
// tr("phase 4");
// Apply the uniform starting weight 1/n and print.
long RN = invl(n, mod);
for(long ret : rets){
out.println(ret%mod*RN%mod);
}
}
int mod = 998244353;
/**
 * K-fold convolution power of ps on a "tail + period" window.
 * Index positions at or beyond tail+period are folded back by multiples of
 * period. The power is built by binary exponentiation with convoluteSimply.
 *
 * @param ps     one-turn distribution (not modified; a copy is folded when it is too long)
 * @param K      exponent
 * @param tail   number of leading positions that are never folded
 * @param period fold length for positions past the tail
 * @return array of length at least tail+period
 */
long[] make(long[] ps, int K, int tail, int period)
{
long[] ms = ps;
if(ps.length > tail+period){
// Fold the overlong input: k cycles through [tail, tail+period).
ms = Arrays.copyOf(ps, tail+period);
for(int j = tail+period, k = tail;j < ps.length;j++,k++){
if(k == tail+period)k -= period;
ms[k] += ps[j];
if(ms[k] >= mod)ms[k] -= mod;
}
}
// pps accumulates the result, starting from the identity [1]; ms holds ps^(2^i).
long[] pps = new long[1];
pps[0] = 1;
for(int i = 0;1<<i <= K;i++){
// K<<~i shifts bit i of K into the sign position, so this tests whether bit i is set.
if(K<<~i<0){
long[] res = convoluteSimply(pps, ms, mod, 3);
for(int j = res.length-1-period;j >= tail;j--){
res[j] += res[j+period];
res[j+period] = 0;
if(res[j] >= mod)res[j] -= mod;
}
pps = Arrays.copyOf(res, Math.min(tail+period, pps.length+ms.length));
}
// Square ms only while a higher bit of K is still to come.
if(1<<i+1 <= K){
long[] res = convoluteSimply(ms, ms, mod, 3);
for(int j = res.length-1-period;j >= tail;j--){
res[j] += res[j+period];
res[j+period] = 0;
if(res[j] >= mod)res[j] -= mod;
}
ms = Arrays.copyOf(res, Math.min(tail+period, ms.length+ms.length));
}
}
if(pps.length < tail+period)pps = Arrays.copyOf(pps, tail+period);
return pps;
}
/**
 * Tree part: for every node j of a rooted tree, sums di[dep[i]-dep[j]] over
 * nodes i below it (the "small" loop shows the plain definition). Subtrees
 * judged too expensive for the plain walk are handled with one convolution each.
 *
 * @param g    undirected adjacency lists of the tree
 * @param par  parent of every node, -1 for the root
 * @param di   weight by depth difference
 * @param H    threshold used to decide which subtrees go through the convolution path
 * @param root index of the root
 * @return per-node sums reduced modulo mod
 */
long[] solve(int[][] g, int[] par, long[] di, int H, int root)
{
int n = g.length;
long[] ret = new long[n];
int[][] pars = parents3(g, root);
int[] des = new int[n];
long[] ws = new long[n];
int[] ord = pars[1], dep = pars[2];
// marked: 0 = handled by the plain walk, 1 = top of a subtree handled by convolution,
// 2 = has a marked node below it, 3 = such an ancestor already counted in a histogram.
int[] marked = new int[n];
// Reverse BFS order: children before parents. des = subtree size, ws = sum of des over the subtree.
for(int i = n-1;i >= 0;i--){
int cur = ord[i];
des[cur]++;
ws[cur] += des[cur];
if(marked[cur] == 0 && ws[cur] > (long)H*des[cur]){
marked[cur] = 1;
}
if(i > 0){
des[par[cur]] += des[cur];
ws[par[cur]] += ws[cur];
if(marked[cur] >= 1)marked[par[cur]] = 2;
}
}
// tr(g, root);
// tr(marked);
// large
// marked node
for(int i = 0;i < n;i++){
if(marked[i] == 1){
// fdep[d] = number of source nodes at depth d: the subtree of i plus
// ancestors flagged 2 that no earlier histogram has taken.
int[] fdep = new int[n];
collect(i, par[i], g, dep, fdep);
for(int j = par[i];j != -1 && marked[j] == 2;j = par[j]){
fdep[dep[j]]++;
marked[j] = 3;
}
int lmaxdep = n;
for(int j = n-1;j >= 0;j--){
if(fdep[j] > 0){
lmaxdep = j;
break;
}
}
// Reverse the histogram so that the convolution index lmaxdep-d lines up with depth d.
long[] rfdep = new long[lmaxdep+1];
for(int j = 0;j <= lmaxdep;j++){
rfdep[lmaxdep-j] = fdep[j];
}
// tr("fdep", fdep, marked);
long[] ced = convoluteSimply(rfdep, Arrays.copyOf(di, lmaxdep+1), mod, 3);
// Hand the result to i and every ancestor of i.
for(int j = i;j != -1;j = par[j]){
ret[j] += ced[lmaxdep-dep[j]];
}
}
}
// small
// Plain walk: each unmarked node adds its weight to itself and its ancestors,
// stopping at the first node flagged 1.
for(int i = 0;i < n;i++){
if(marked[i] == 0){
for(int j = i;j != -1 && marked[j] != 1;j = par[j]){
ret[j] += di[dep[i]-dep[j]];
}
}
}
for(int i = 0;i < n;i++){
ret[i] %= mod;
}
return ret;
}
/**
 * Depth histogram of a subtree: increments fdep[dep[v]] for cur and for every
 * node reachable from cur without going back through par.
 */
void collect(int cur, int par, int[][] g, int[] dep, int[] fdep)
{
fdep[dep[cur]] += 1;
int[] neighbours = g[cur];
for(int idx = 0;idx < neighbours.length;idx++){
int child = neighbours[idx];
if(child == par)continue;
collect(child, cur, g, dep, fdep);
}
}
// library
/**
 * Breadth-first traversal of a tree given as undirected adjacency lists.
 *
 * @return { parent of each node (-1 for the root), nodes in BFS order, depth of each node }
 */
public static int[][] parents3(int[][] g, int root) {
final int n = g.length;
int[] parent = new int[n];
int[] depth = new int[n];
int[] order = new int[n];
Arrays.fill(parent, -1);
order[0] = root;
int head = 0;
int tail = 1;
while (head < tail) {
int cur = order[head++];
for (int neighbour : g[cur]) {
// The only edge not followed is the one back to the parent.
if (neighbour == parent[cur]) continue;
order[tail++] = neighbour;
parent[neighbour] = cur;
depth[neighbour] = depth[cur] + 1;
}
}
return new int[][] { parent, order, depth };
}
public static final int[] NTTPrimes = {1053818881, 1051721729, 1045430273, 1012924417, 1007681537, 1004535809, 998244353, 985661441, 976224257, 975175681};
public static final int[] NTTPrimitiveRoots = {7, 6, 3, 5, 3, 3, 3, 3, 3, 17};
// public static final int[] NTTPrimes = {1012924417, 1004535809, 998244353, 985661441, 975175681, 962592769, 950009857, 943718401, 935329793, 924844033};
// public static final int[] NTTPrimitiveRoots = {5, 3, 3, 3, 17, 7, 7, 7, 3, 5};
static long[] inputed;
static long[] saved;
/**
 * Convolution of a and b modulo the single prime P with primitive root g,
 * using nttmb. The result has the transform length m, not a.length+b.length-1.
 *
 * Side channel: when the static field inputed is non-null it is used as the
 * transform of b without checking that it matches b or m, so callers must set
 * and clear it themselves. The transform actually used for b is left in saved.
 */
public static long[] convoluteSimply(long[] a, long[] b, int P, int g)
{
// Transform length: a power of two, at least 2.
int m = Math.max(2, Integer.highestOneBit(Math.max(a.length, b.length)-1)<<2);
long[] fa = nttmb(a, m, false, P, g);
// Squaring reuses fa; otherwise prefer the caller-supplied transform.
long[] fb = a == b ? fa : inputed != null ? inputed : nttmb(b, m, false, P, g);
saved = fb;
for(int i = 0;i < m;i++){
fa[i] = fa[i]*fb[i]%P;
}
return nttmb(fa, m, true, P, g);
}
/**
 * Convolution of a and b computed modulo the first USE (= 2) entries of
 * NTTPrimes, then combined per coefficient with Garner's algorithm into a
 * single long. The result has the transform length m.
 */
public static long[] convolute(long[] a, long[] b)
{
int USE = 2;
int m = Math.max(2, Integer.highestOneBit(Math.max(a.length, b.length)-1)<<2);
long[][] fs = new long[USE][];
// One full convolution per prime.
for(int k = 0;k < USE;k++){
int P = NTTPrimes[k], g = NTTPrimitiveRoots[k];
long[] fa = nttmb(a, m, false, P, g);
long[] fb = a == b ? fa : nttmb(b, m, false, P, g);
for(int i = 0;i < m;i++){
fa[i] = fa[i]*fb[i]%P;
}
fs[k] = nttmb(fa, m, true, P, g);
}
int[] mods = Arrays.copyOf(NTTPrimes, USE);
long[] gammas = garnerPrepare(mods);
int[] buf = new int[USE];
for(int i = 0;i < fs[0].length;i++){
for(int j = 0;j < USE;j++)buf[j] = (int)fs[j][i];
long[] res = garnerBatch(buf, mods, gammas);
// Evaluate the mixed-radix digits returned by garnerBatch (no reduction here).
long ret = 0;
for(int j = res.length-1;j >= 0;j--)ret = ret * mods[j] + res[j];
fs[0][i] = ret;
}
return fs[0];
}
/**
 * Convolution of a and b computed modulo the first USE entries of NTTPrimes,
 * combined per coefficient with Garner's algorithm and reduced modulo mod
 * while the mixed-radix digits are evaluated. The result has the transform length m.
 */
public static long[] convolute(long[] a, long[] b, int USE, int mod)
{
int m = Math.max(2, Integer.highestOneBit(Math.max(a.length, b.length)-1)<<2);
long[][] fs = new long[USE][];
// One full convolution per prime.
for(int k = 0;k < USE;k++){
int P = NTTPrimes[k], g = NTTPrimitiveRoots[k];
long[] fa = nttmb(a, m, false, P, g);
long[] fb = a == b ? fa : nttmb(b, m, false, P, g);
for(int i = 0;i < m;i++){
fa[i] = fa[i]*fb[i]%P;
}
fs[k] = nttmb(fa, m, true, P, g);
}
int[] mods = Arrays.copyOf(NTTPrimes, USE);
long[] gammas = garnerPrepare(mods);
int[] buf = new int[USE];
for(int i = 0;i < fs[0].length;i++){
for(int j = 0;j < USE;j++)buf[j] = (int)fs[j][i];
long[] res = garnerBatch(buf, mods, gammas);
// Horner evaluation of the mixed-radix digits, reduced at every step.
long ret = 0;
for(int j = res.length-1;j >= 0;j--)ret = (ret * mods[j] + res[j]) % mod;
fs[0][i] = ret;
}
return fs[0];
}
// static int[] wws = new int[270000]; // outer faster
// Modifed Montgomery + Barrett
/**
 * Number-theoretic transform of src modulo the prime P with primitive root g.
 *
 * @param src     input; copied and zero-padded or truncated to n
 * @param n       transform length (the code assumes a power of two: h is taken from its trailing zeros)
 * @param inverse true for the inverse transform, which also divides by n
 * @return a new array of length n with values in [0, P)
 */
private static long[] nttmb(long[] src, int n, boolean inverse, int P, int g)
{
long[] dst = Arrays.copyOf(src, n);
int h = Integer.numberOfTrailingZeros(n);
// K, H, M are the constants handed to modh for reduction modulo P.
long K = Integer.highestOneBit(P)<<1;
int H = Long.numberOfTrailingZeros(K)*2;
long M = K*K/P;
int[] wws = new int[1<<h-1];
// Root of unity of order n (its inverse for the inverse transform).
long dw = inverse ? pow(g, P-1-(P-1)/n, P) : pow(g, (P-1)/n, P);
// Twiddles are stored pre-multiplied by 2^32 mod P.
long w = (1L<<32)%P;
for(int k = 0;k < 1<<h-1;k++){
wws[k] = (int)w;
w = modh(w*dw, M, H, P);
}
// NOTE(review): J looks like the Montgomery constant for P modulo 2^32 - confirm.
long J = invl(P, 1L<<32);
for(int i = 0;i < h;i++){
for(int j = 0;j < 1<<i;j++){
for(int k = 0, s = j<<h-i, t = s|1<<h-i-1;k < 1<<h-i-1;k++,s++,t++){
// Butterfly: sum stays at s, twiddled difference goes to t.
long u = (dst[s] - dst[t] + 2*P)*wws[k];
dst[s] += dst[t];
if(dst[s] >= 2*P)dst[s] -= 2*P;
// long Q = (u&(1L<<32)-1)*J&(1L<<32)-1;
long Q = (u<<32)*J>>>32;
dst[t] = (u>>>32)-(Q*P>>>32)+P;
}
}
// The next stage needs every second twiddle only.
if(i < h-1){
for(int k = 0;k < 1<<h-i-2;k++)wws[k] = wws[k*2];
}
}
// One conditional subtraction per element.
for(int i = 0;i < n;i++){
if(dst[i] >= P)dst[i] -= P;
}
// Bit-reversal permutation.
for(int i = 0;i < n;i++){
int rev = Integer.reverse(i)>>>-h;
if(i < rev){
long d = dst[i]; dst[i] = dst[rev]; dst[rev] = d;
}
}
if(inverse){
long in = invl(n, P);
for(int i = 0;i < n;i++)dst[i] = modh(dst[i]*in, M, H, P);
}
return dst;
}
// Modified Shoup + Barrett
/**
 * Number-theoretic transform of src modulo the prime P with primitive root g;
 * same interface as nttmb. Each twiddle w is kept together with the
 * precomputed quotient (w<<32)/P, which the butterfly uses to reduce the product.
 * Not called anywhere in the visible code.
 *
 * @param src     input; copied and zero-padded or truncated to n
 * @param n       transform length (the code assumes a power of two: h is taken from its trailing zeros)
 * @param inverse true for the inverse transform, which also divides by n
 * @return a new array of length n
 */
private static long[] nttsb(long[] src, int n, boolean inverse, int P, int g)
{
long[] dst = Arrays.copyOf(src, n);
int h = Integer.numberOfTrailingZeros(n);
// K, H, M are the constants handed to modh for reduction modulo P.
long K = Integer.highestOneBit(P)<<1;
int H = Long.numberOfTrailingZeros(K)*2;
long M = K*K/P;
// Root of unity of order n (its inverse for the inverse transform).
long dw = inverse ? pow(g, P-1-(P-1)/n, P) : pow(g, (P-1)/n, P);
long[] wws = new long[1<<h-1];
long[] ws = new long[1<<h-1];
long w = 1;
for(int k = 0;k < 1<<h-1;k++){
wws[k] = (w<<32)/P;
ws[k] = w;
w = modh(w*dw, M, H, P);
}
for(int i = 0;i < h;i++){
for(int j = 0;j < 1<<i;j++){
for(int k = 0, s = j<<h-i, t = s|1<<h-i-1;k < 1<<h-i-1;k++,s++,t++){
// Butterfly: sum stays at s, twiddled difference goes to t.
long ndsts = dst[s] + dst[t];
if(ndsts >= 2*P)ndsts -= 2*P;
long T = dst[s] - dst[t] + 2*P;
long Q = wws[k]*T>>>32;
dst[s] = ndsts;
dst[t] = ws[k]*T-Q*P&(1L<<32)-1;
}
}
// dw = dw * dw % P;
// The next stage needs every second twiddle only.
if(i < h-1){
for(int k = 0;k < 1<<h-i-2;k++){
wws[k] = wws[k*2];
ws[k] = ws[k*2];
}
}
}
// One conditional subtraction per element.
for(int i = 0;i < n;i++){
if(dst[i] >= P)dst[i] -= P;
}
// Bit-reversal permutation.
for(int i = 0;i < n;i++){
int rev = Integer.reverse(i)>>>-h;
if(i < rev){
long d = dst[i]; dst[i] = dst[rev]; dst[rev] = d;
}
}
if(inverse){
long in = invl(n, P);
for(int i = 0;i < n;i++){
dst[i] = modh(dst[i] * in, M, H, P);
}
}
return dst;
}
static final long mask = (1L<<31)-1;
/**
 * Reduces a modulo mod without a division: subtracts an estimated quotient
 * times mod, then applies one conditional subtraction. The estimate is built
 * from M applied separately to the low 31 bits and the remaining high bits of a.
 * NOTE(review): this appears to be Barrett reduction with M = K*K/mod and
 * h = 2*log2(K), as set up by the callers - confirm the valid input range.
 */
public static long modh(long a, long M, int h, int mod)
{
long r = a-((M*(a&mask)>>>31)+M*(a>>>31)>>>h-31)*mod;
return r < mod ? r : r-mod;
}
/**
 * Precomputation for garnerBatch: gamma[k] is the inverse, modulo m[k], of the
 * product m[0] * ... * m[k-1]. gamma[0] is left at 0 and never used.
 */
private static long[] garnerPrepare(int[] m)
{
final int count = m.length;
long[] gamma = new long[count];
for(int k = 1;k < count;k++){
long prefix = 1;
for(int i = 0;i < k;i++)prefix = prefix * m[i] % m[k];
gamma[k] = invl(prefix, m[k]);
}
return gamma;
}
/**
 * Garner's algorithm: turns the residues u[k] (modulo m[k]) into mixed-radix
 * digits v, so that the combined value is v[0] + v[1]*m[0] + v[2]*m[0]*m[1] + ...
 *
 * @param gamma output of garnerPrepare for the same moduli
 */
private static long[] garnerBatch(int[] u, int[] m, long[] gamma)
{
final int count = u.length;
assert count == m.length;
long[] digits = new long[count];
digits[0] = u[0];
for(int k = 1;k < count;k++){
// Value of the digits found so far, taken modulo m[k] (Horner, highest digit first).
long partial = digits[k-1];
for(int j = k-2;j >= 0;j--){
partial = (partial * m[j] + digits[j]) % m[k];
}
long digit = (u[k] - partial) * gamma[k] % m[k];
if(digit < 0)digit += m[k];
digits[k] = digit;
}
return digits;
}
/**
 * a^n modulo mod by left-to-right binary exponentiation.
 * a is not reduced first, so it should already be below mod.
 */
private static long pow(long a, long n, long mod) {
long result = 1;
int top = 63 - Long.numberOfLeadingZeros(n);
for (int bit = top; bit >= 0; bit--) {
result = result * result % mod;
if (((n >>> bit) & 1L) != 0) {
result = result * a % mod;
}
}
return result;
}
/**
 * Modular inverse of a modulo mod by the extended Euclidean algorithm.
 * The result is in [0, mod); a and mod are expected to be coprime.
 */
private static long invl(long a, long mod) {
long x = a;
long y = mod;
// coefX tracks the coefficient of the original a in x, coefY the one in y.
long coefX = 1;
long coefY = 0;
while (y > 0) {
long quotient = x / y;
long remainder = x % y;
x = y;
y = remainder;
long next = coefX - quotient * coefY;
coefX = coefY;
coefY = next;
}
return coefX < 0 ? coefX + mod : coefX;
}
/**
 * Builds undirected adjacency lists from a parent array (negative = no parent).
 * Each list is filled from its last slot backwards while the nodes are scanned
 * in increasing order.
 */
public static int[][] parentToG(int[] par)
{
final int n = par.length;
int[] degree = new int[n];
for(int child = 0;child < n;child++){
if(par[child] < 0)continue;
degree[child]++;
degree[par[child]]++;
}
int[][] g = new int[n][];
for(int v = 0;v < n;v++)g[v] = new int[degree[v]];
// degree[] now doubles as the "slots still free" counter of every list.
for(int child = 0;child < n;child++){
int parent = par[child];
if(parent < 0)continue;
g[parent][--degree[parent]] = child;
g[child][--degree[child]] = parent;
}
return g;
}
/**
 * Groups the indices of a by value: bucket[v] lists, in increasing order,
 * every i with a[i] == v. Values must lie in [0, sup].
 */
public static int[][] makeBuckets(int[] a, int sup)
{
int[] size = new int[sup+1];
for(int value : a)size[value]++;
int[][] bucket = new int[sup+1][];
for(int v = 0;v <= sup;v++)bucket[v] = new int[size[v]];
int[] cursor = new int[sup+1];
for(int i = 0;i < a.length;i++){
int v = a[i];
bucket[v][cursor[v]++] = i;
}
return bucket;
}
/** Result of split(): which nodes lie on a cycle, and a processing order. */
public static class SplitResult
{
// incycle[v] is true when v lies on a cycle.
public boolean[] incycle;
// Non-cycle nodes first (each before its successor), then the cycle nodes.
public int[] ord;
}
/**
 * Separates a functional graph (f[v] = successor of v) into cycle nodes and
 * tree nodes by repeatedly peeling off nodes with no remaining predecessor.
 * ord lists the peeled nodes in peeling order, followed by the cycle nodes in
 * increasing index order.
 */
public static SplitResult split(int[] f)
{
final int n = f.length;
int[] indeg = new int[n];
for(int target : f)indeg[target]++;
boolean[] onCycle = new boolean[n];
Arrays.fill(onCycle, true);
int[] order = new int[n];
int filled = 0;
for(int v = 0;v < n;v++){
if(indeg[v] == 0)order[filled++] = v;
}
int head = 0;
while(head < filled){
int v = order[head++];
onCycle[v] = false;
indeg[v] = -9999999; // sentinel: can never match the in-degree test below
int succ = f[v];
indeg[succ]--;
if(indeg[succ] == 0)order[filled++] = succ;
}
// What survives the peeling has exactly one predecessor left: its cycle neighbour.
for(int v = 0;v < n;v++){
if(indeg[v] == 1)order[filled++] = v;
}
assert filled == n;
SplitResult result = new SplitResult();
result.incycle = onCycle;
result.ord = order;
return result;
}
/**
 * Wires up input and output, runs solve() and flushes. Input comes from the
 * INPUT string when it is non-empty, otherwise from System.in; in the former
 * case the elapsed time is printed as well.
 */
void run() throws Exception
{
final boolean embedded = !INPUT.isEmpty();
is = embedded ? new ByteArrayInputStream(INPUT.getBytes()) : System.in;
out = new PrintWriter(System.out);
final long startedAt = System.currentTimeMillis();
solve();
out.flush();
if(embedded)tr(System.currentTimeMillis()-startedAt+"ms");
}
/** Program entry point: creates a solver instance and runs it. */
public static void main(String[] args) throws Exception { new G2().run(); }
private byte[] inbuf = new byte[1024];
public int lenbuf = 0, ptrbuf = 0;
/**
 * Returns the next input byte from the buffer, refilling it from is when it
 * is exhausted. Returns -1 at end of input; a further call after that throws
 * InputMismatchException, as does an IOException from the stream.
 */
private int readByte()
{
if(lenbuf == -1)throw new InputMismatchException();
if(ptrbuf >= lenbuf){
ptrbuf = 0;
try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
if(lenbuf <= 0)return -1;
}
return inbuf[ptrbuf++];
}
// Anything outside the printable ASCII range 33..126 separates tokens (this includes -1).
private boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }
// Skips separators; returns the first token byte, or -1 at end of input.
private int skip() { int b; while((b = readByte()) != -1 && isSpaceChar(b)); return b; }
// Next token parsed as a double.
private double nd() { return Double.parseDouble(ns()); }
// Next non-separator byte as a char.
private char nc() { return (char)skip(); }
// Next token as a String.
private String ns()
{
int b = skip();
StringBuilder sb = new StringBuilder();
while(!(isSpaceChar(b))){ // when nextLine, (isSpaceChar(b) && b != ' ')
sb.appendCodePoint(b);
b = readByte();
}
return sb.toString();
}
// Next token as a char array of at most n characters (shorter if the token ends early).
private char[] ns(int n)
{
char[] buf = new char[n];
int b = skip(), p = 0;
while(p < n && !(isSpaceChar(b))){
buf[p++] = (char)b;
b = readByte();
}
return n == p ? buf : Arrays.copyOf(buf, p);
}
// Reads n tokens of up to m characters each.
private char[][] nm(int n, int m)
{
char[][] map = new char[n][];
for(int i = 0;i < n;i++)map[i] = ns(m);
return map;
}
// Reads n ints.
private int[] na(int n)
{
int[] a = new int[n];
for(int i = 0;i < n;i++)a[i] = ni();
return a;
}
// Next int: skips everything up to a digit or '-', then reads digits until the first non-digit.
// No overflow check; returns 0 when the input is already exhausted.
private int ni()
{
int num = 0, b;
boolean minus = false;
while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
if(b == '-'){
minus = true;
b = readByte();
}
while(true){
if(b >= '0' && b <= '9'){
num = num * 10 + (b - '0');
}else{
return minus ? -num : num;
}
b = readByte();
}
}
// Next long; same parsing rules as ni().
private long nl()
{
long num = 0;
int b;
boolean minus = false;
while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
if(b == '-'){
minus = true;
b = readByte();
}
while(true){
if(b >= '0' && b <= '9'){
num = num * 10 + (b - '0');
}else{
return minus ? -num : num;
}
b = readByte();
}
}
// Debug print of the arguments to standard output.
private static void tr(Object... o) { System.out.println(Arrays.deepToString(o)); }
}
Warmup
Implementation
Strings
Sorting
Search
Graph Theory
Greedy
Dynamic Programming
Constructive Algorithms
Bit Manipulation
Recursion
Game Theory
NP Complete
Debugging
Leave a comment below