首页 > 代码库 > 261. Discrete Roots
261. Discrete Roots
给定 \(p, k, A\),其中 \(p, k\) 均为质数。求所有满足下式的 \(x\)(\(0 \le x \le p-1\)),先输出解的个数,再按升序输出全部解:
\[x^k \equiv A \pmod{p}\]
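不会……(思路:先求模 \(p\) 的原根 \(g\),用 BSGS 求出满足 \(g^q \equiv A \pmod{p}\) 的 \(q\);令 \(x = g^y\),原方程化为线性同余方程 \(k y \equiv q \pmod{p-1}\),用扩展欧几里得求出全部 \(y\),再映射回 \(x = g^y\)。\(A = 0\) 时唯一解为 \(x = 0\)。)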
#include <iostream>
#include <map>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
#include <cmath>
using namespace std;
typedef long long LL;

// SGU 261 "Discrete Roots": read primes p, k and a value A, then print the
// number of x in [0, p) with x^k == A (mod p), followed by those x in
// ascending order.
//
// Method: pick a primitive root g mod p and find q with g^q == A (mod p) by
// baby-step giant-step. Writing x = g^y, x is a root exactly when
// k*y == q (mod p - 1), a linear congruence solved by extended Euclid.

// Distinct prime factors of p - 1; filled by get_Primitive_Root.
vector<LL> f;
// Solutions of the linear congruence; filled by solve_Linear_Mod_Equation
// and later mapped in place to the actual roots by main.
vector<LL> as;

// Returns base^index mod `mod` by binary exponentiation.
// Requires index >= 0 (a negative index never terminates) and mod small
// enough that the product of two residues fits in a signed 64-bit integer.
LL fast_pow(LL base, LL index, LL mod) {
    LL ret = 1 % mod;
    base %= mod;
    for (; index; index >>= 1, base = base * base % mod)
        if (index & 1) ret = ret * base % mod;
    return ret;
}

// Returns true iff g is a primitive root modulo the prime p, i.e.
// g^((p-1)/r) != 1 for every distinct prime factor r of p - 1 (global f).
bool test_Primitive_Root(LL g, LL p) {
    for (size_t i = 0; i < f.size(); ++i)
        if (fast_pow(g, (p - 1) / f[i], p) == 1)
            return false;
    return true;
}

// Factors p - 1 into the global f (distinct primes, ascending) and returns
// the smallest primitive root of the prime p (1 when p == 2).
LL get_Primitive_Root(LL p) {
    f.clear();
    LL rest = p - 1;
    for (LL i = 2; i <= rest / i; ++i) {
        if (rest % i == 0) {
            f.push_back(i);
            while (rest % i == 0) rest /= i;
        }
    }
    if (rest != 1) f.push_back(rest);
    for (LL g = 1; ; ++g)
        if (test_Primitive_Root(g, p))
            return g;
}

// Baby-step giant-step: returns the smallest e >= 0 with x^e == n (mod m),
// or -1 if no such e exists. m must be prime (the inverse uses Fermat's
// little theorem) and x must be nonzero modulo m.
LL get_Discrete_Logarithm(LL x, LL n, LL m) {
    x %= m;
    n %= m;
    // The step s must satisfy s*s >= m - 1 so that every exponent in
    // [0, m - 2] can be written as i*s + j with 0 <= i, j < s. Rounding
    // sqrt(m) to nearest is NOT enough: m = 11 gives s = 3, so exponent 9
    // would never be found.
    LL s = (LL)ceil(sqrt((double)m));
    while (s * s < m) ++s;

    // Baby steps: x^j -> j for j in [0, s). insert() keeps the first
    // (smallest) j if a value repeats.
    map<LL, LL> rec;
    LL cur = 1;
    for (LL j = 0; j < s; ++j) {
        rec.insert(make_pair(cur, j));
        cur = cur * x % m;
    }

    // Here cur == x^s. Each giant step multiplies the target by x^(-s),
    // so the loop-invariant inverse is computed only once.
    const LL inv_step = fast_pow(cur, m - 2, m);
    LL target = n;
    for (LL i = 0; i < s; ++i) {
        map<LL, LL>::iterator it = rec.find(target);
        if (it != rec.end())
            return i * s + it->second;
        target = target * inv_step % m;
    }
    return -1;
}

// Extended Euclid: returns gcd(a, b) and sets x, y with a*x + b*y == gcd.
LL ext_Euclid(LL a, LL b, LL &x, LL &y) {
    if (b == 0) {
        x = 1, y = 0;
        return a;
    }
    LL ret = ext_Euclid(b, a % b, y, x);
    y -= x * (a / b);
    return ret;
}

// Fills the global `as` with all solutions y in [0, n) of a*y == b (mod n):
// y0, y0 + n/d, ..., y0 + (d-1)*n/d where d = gcd(a, n).
// Leaves `as` empty when d does not divide b. Expects a >= 0, b >= 0, n >= 1.
void solve_Linear_Mod_Equation(LL a, LL b, LL n) {
    as.clear();
    LL x, y;
    const LL d = ext_Euclid(a, n, x, y);
    if (b % d != 0) return;
    const LL step = n / d;
    x = (x % n + n) % n;  // normalize the Bezout coefficient into [0, n)
    const LL y0 = x * (b / d) % step;
    for (LL i = 0; i < d; ++i)
        as.push_back((y0 + i * step) % n);
}

int main() {
#ifndef ONLINE_JUDGE
    freopen("data.in", "r", stdin); freopen("data.out", "w", stdout);
#endif

    LL p, k, a;
    cin >> p >> k >> a;
    a = (a % p + p) % p;  // tolerate A outside [0, p)
    if (a == 0) {
        // p is prime, so x^k == 0 (mod p) forces x == 0.
        puts("1\n0");
        return 0;
    }
    const LL g = get_Primitive_Root(p);
    const LL q = get_Discrete_Logarithm(g, a, p);
    if (q < 0) {
        // Cannot happen for a primitive root g and a != 0, but never feed a
        // negative exponent to fast_pow (it would loop forever).
        puts("0");
        return 0;
    }
    solve_Linear_Mod_Equation(k, q, p - 1);
    for (size_t i = 0; i < as.size(); ++i)
        as[i] = fast_pow(g, as[i], p);
    sort(as.begin(), as.end());
    printf("%d\n", static_cast<int>(as.size()));
    for (size_t i = 0; i < as.size(); ++i)
        printf("%lld%c", as[i], i + 1 == as.size() ? '\n' : ' ');
    return 0;
}
声明:以上内容来自用户投稿及互联网公开渠道收集整理发布,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任,若内容有误或涉及侵权可进行投诉: 投诉/举报 工作人员会在5个工作日内联系你,一经查实,本站将立刻删除涉嫌侵权内容。