import sys input = sys.stdin.readline sys.setrecursionlimit(10**8) n = int(input()) C = list(input()) del C[-1] out = [] for i in range(n) : C[i] = int(C[i]) if C[i] == 0 : out.append(i+1) count = 0 V = [[] for i in range(n + 1)] for i in range(n - 1) : v1, v2 = map(int, input().split()) if C[v1 - 1] + C[v2 - 1] == 2 : count = count + 2 else : V[v1].append(v2) V[v2].append(v1) sum = sum(C) if su..