1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
// Link-cut tree over nodes 1..n that maintains the XOR of node values along
// tree paths, supporting link, cut, path query and point update.
// Node 0 is the null child: pushup() reads tr[0].sum, which is never written
// and therefore stays 0.
#include <cstdio>
#include <utility>
using namespace std;

const int N = 100010;  // capacity of tr/stk: node ids must lie in [1, N - 1]

int n, m;

// One node of the auxiliary splay forest.
struct Node {
    int s[2], p, v;  // children (0 = left, 1 = right), parent or path-parent, node value
    int sum, rev;    // XOR of v over this splay subtree; pending-reversal flag
} tr[N];

int stk[N];  // scratch stack used by splay() to push lazy tags top-down

// Reverse the splay subtree rooted at x: swap its children now and record a
// pending reversal for them. The null sentinel 0 is left untouched.
void pushrev(int x) {
    if (!x) return;
    swap(tr[x].s[0], tr[x].s[1]);
    tr[x].rev ^= 1;
}

// Recompute x's subtree XOR from its children and its own value.
void pushup(int x) {
    tr[x].sum = tr[tr[x].s[0]].sum ^ tr[x].v ^ tr[tr[x].s[1]].sum;
}

// Propagate x's pending reversal to its children.
void pushdown(int x) {
    if (tr[x].rev) {
        pushrev(tr[x].s[0]);
        pushrev(tr[x].s[1]);
        tr[x].rev = 0;
    }
}

// True if x is the root of its splay tree, i.e. its parent pointer (if any)
// is a path-parent link rather than a real child link.
bool isroot(int x) {
    return tr[tr[x].p].s[0] != x && tr[tr[x].p].s[1] != x;
}

// Rotate x above its parent y, preserving the in-order (depth) sequence.
void rotate(int x) {
    int y = tr[x].p, z = tr[y].p;
    int k = tr[y].s[1] == x;  // which side of y x hangs on
    // If y is a splay root, z is only a path-parent and must not adopt x as a child.
    if (!isroot(y)) tr[z].s[tr[z].s[1] == y] = x;
    tr[x].p = z;
    tr[y].s[k] = tr[x].s[k ^ 1];
    tr[tr[x].s[k ^ 1]].p = y;
    tr[x].s[k ^ 1] = y;
    tr[y].p = x;
    pushup(y);
    pushup(x);
}

// Bring x to the root of its splay tree.
void splay(int x) {
    // Push pending reversals down from the splay root to x before rotating.
    int top = 0, r = x;
    stk[++top] = r;
    while (!isroot(r)) stk[++top] = r = tr[r].p;
    while (top) pushdown(stk[top--]);
    while (!isroot(x)) {
        int y = tr[x].p, z = tr[y].p;
        if (!isroot(y)) {
            // zig-zag: rotate x twice; zig-zig: rotate y first
            if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x);
            else rotate(y);
        }
        rotate(x);
    }
}

// Make the path from x's tree root to x a single splay tree, with x splayed to its root.
void access(int x) {
    int z = x;
    for (int y = 0; x; y = x, x = tr[x].p) {
        splay(x);
        tr[x].s[1] = y;  // the lower path segment replaces x's previous right subtree
        pushup(x);
    }
    splay(z);
}

// Re-root x's represented tree at x.
void make_root(int x) {
    access(x);
    pushrev(x);  // reversing the root-to-x path puts x first in depth order
}

// Return the root of x's represented tree; that root is left splayed.
int find_root(int x) {
    access(x);
    while (tr[x].s[0]) pushdown(x), x = tr[x].s[0];
    splay(x);
    return x;
}

// Expose the x..y path as one splay tree rooted at y. When x and y are
// connected, tr[y].sum is then the XOR of values along that path.
void split(int x, int y) {
    make_root(x);
    access(y);
}

// Add edge x-y unless x and y are already connected.
void link(int x, int y) {
    make_root(x);
    if (find_root(y) != x) tr[x].p = y;
}

// Remove edge x-y if it exists; otherwise do nothing.
void cut(int x, int y) {
    make_root(x);
    // find_root(y) leaves x splayed at the root of the x..y path, so x-y is an
    // edge exactly when y is x's child with no left subtree (nothing between them).
    if (find_root(y) == x && tr[y].p == x && !tr[y].s[0]) {
        tr[x].s[1] = tr[y].p = 0;
        pushup(x);
    }
}
int main() { scanf("%d%d", &n, &m); for(int i = 1;i <= n;i ++) scanf("%d", &tr[i].v); while(m --) { int t, x, y; scanf("%d%d%d", &t, &x, &y); if(t == 0) { split(x, y); printf("%d\n", tr[y].sum); } else if(t == 1) link(x, y); else if(t == 2) cut(x, y); else { splay(x); tr[x].v = y; pushup(x); } } return 0; }
|