poj1822-食物链-(带权并查集)

发布于 2019-08-06  935 次阅读


问题传送门:

题目:

有三类动物A,B,C,三类动物的食物链构成了环形。A吃B, B吃C,C吃A。 
现有N个动物,以1-N编号。每个动物都是A,B,C中的一种,但并不知道底是哪一种。
有两种说法对 X 和 Y 两个动物的关系进行描述: 
"1 X Y",表示 X 和 Y 是同类。 
"2 X Y",表示 X 吃 Y。 
 现在有 K 句描述,这 K 句话有的是真的,有的是假的。 
  1. 当前的话与前面的某些真的话冲突,就是假话; 
  2. 当前的话中 X 或 Y 比 N 大,就是假话; 
  3. 当前的话表示 X 吃 X,就是假话。 
 N(1 <= N <= 50,000)(0 <= K <= 100,000),输出假话的总数。  

分析:

首先不能用cin 输入,因为K太大了。
这是一道很经典的带权并查集,刚开始不知道这样做,想着用并查集
把每个物种放在一起。但是写不出来。~~~
带权并查集:不单单是储存了集合,而且还记录了集合中节点之间的关系。
更多的是记录每个节点与根节点之间的关系。
因为有路经压缩这个过程,把每个节点与根节点直接连接。
利用这个过程,计算出每个节点与根节点枝间的关系。
进而通过根节点,计算出集合中任意两个节点的关系。
对于这个题来说,任意两个节点之间只有三种关系:
同类,X吃Y,X被Y吃。可以用数字rel[i]
0     1      2     分别代表这三种关系。
然后用数组 rel[i] 来表示 i 到父节点的关系。
初始化时rel[i]=0; 表示自己和自己同类。
然后发现,当:
1 X Y   d = 1, rel[x] = 0;
2 X Y   d = 2,  rel[x] = 1;
(Y看作父节点)
~~rel[x] + 1 = d~~
判断真假话:
对于题目中的 2 与 3 条件直接判断,是假话就 res++;
重要的是对于条件 1:
当我们输入X, Y时,先分别求根节点
  1. 若两者根节点相同,说明已经赋予X,Y关系了,只需要判断题中给的关系 是否符合已有的关系即可:(根据向量关系)

$$则x到y的关系是:(3 + rel[x] -rel[y] ) % 3$$

    2.若两者根节点不相同,说明我们还没有赋予x,y关系。把X和Y所在的子树合并。
      这里有个细节问题,这里如果 r[ra] = rb,需要更新 ra 到 rb 的关系。
      即更新 rel[ra].
      根据上面的公式,可以得出:

$$rel[a] + rel[ra] - rel[b] = d-1$$

$$rel[ra] = (d - 1 + rel[b] - rel[a] + 3) % 3$$

在路经压缩中更新关系:
要把每一个正确的值 与 根结点连接,用rel保存关系。
路经压缩是个递归的过程,在这个过程中,可以:
int t = r[x];
r[x] = Find(r[x]);
rel[x] = (rel[x] + rel[t]) % 3;
递归更新关系。

代码:

#include <iostream>
#include <cstdio>

using namespace std;

const int MAXN = 5e4+10;
int r[MAXN];
int rel[MAXN];

void init(int n, int &res)
{
    res = 0;
    for (int i = 1; i <= n; i++)
        r[i] = i, rel[i] = 0;
}
int Find(int x)
{
    if (x == r[x]) return r[x];    
    else {
        int t = r[x];
        r[x] = Find(r[x]);
        rel[x] = (rel[x] + rel[t]) % 3;
        return r[x];
    }
}
bool Join(int d, int a, int b)
{
    int ra = Find(a);
    int rb = Find(b);
    if (ra == rb) {
        if ((rel[a]-rel[b]+3)%3+1 != d) 
            return false;
        else                            
            return true;
    } else {
        r[ra] = rb;
        rel[ra]=(d-rel[a]+rel[b]+3-1)%3;
        return true;
    }
}
int main()
{
    int n, k, d, x, y, res;
    cin >> n >> k;
    init(n, res);
    for (int i = 0; i < k; i++) {
        scanf("%d %d %d",&d, &x, &y);
        if (d==2 && x==y)     res++; 
        else if (x>n || y>n)  res++;
        else 
            if(!Join(d, x, y))res++;      
    }
    cout << res << endl;
    return 0;
}

Simple And Clear