首页 > 代码库 > 一元三次方程求解c++实现

一元三次方程求解c++实现

// Scalar type used for every polynomial coefficient and root in this file.
using Number = double;
class CubicRealPolynomial
{
public:
    static Number computeDiscriminant(Number a, Number b, Number c, Number d);
    static std::vector<Number> computeRealRoots(Number a, Number b, Number c, Number d);
private:
    static std::vector<Number> computeRealRootsInner(Number a, Number b, Number c, Number d);

};

--------------

    Number CubicRealPolynomial::computeDiscriminant(Number a, Number b, Number c, Number d)
    {
        Number a2 = a * a;
        Number b2 = b * b;
        Number c2 = c * c;
        Number d2 = d * d;

        Number discriminant = 18.0 * a * b * c * d + b2 * c2 - 27.0 * a2 * d2 - 4.0 * (a * c2 * c + b2 * b * d);
        return discriminant;
    }

    std::vector<Number> CubicRealPolynomial::computeRealRoots(Number a, Number b, Number c, Number d)
    {
        std::vector<Number> roots;
        Number ratio, root1, root2, root3;
        if (a == 0.0)
        {
            // Quadratic function: b * x^2 + c * x + d = 0.
            QuadraticRealPolynomial::computeRealRoots(b, c, d, roots);
            return roots;
        } else if (b == 0.0)
        {
            if (c == 0.0)
            {
                if (d == 0.0)
                {
                    // 3rd order monomial: a * x^3 = 0.
                    roots.push_back(0.0);
                    roots.push_back(0.0);
                    roots.push_back(0.0);
                    return roots;
                }

                // a * x^3 + d = 0
                ratio = -d / a;
                Number root = (ratio < 0.0) ? -std::pow(-ratio, 1.0 / 3.0) : std::pow(ratio, 1.0 / 3.0);
                roots.push_back(root);
                roots.push_back(root);
                roots.push_back(root);
                return roots;
            } else if (d == 0.0)
            {
                // x * (a * x^2 + c) = 0.
                QuadraticRealPolynomial::computeRealRoots(a, 0, c, roots);

                // Return the roots in ascending order.
                if (roots.size() == 0)
                {
                    roots.push_back(0.0);
                    return roots;
                }

                root1 = roots[0];
                root2 = 0.0;
                root3 = roots[1];

                roots.clear();
                roots.push_back(root1);
                roots.push_back(root2);
                roots.push_back(root3);
                return roots;
            }

            // Deflated cubic polynomial: a * x^3 + c * x + d= 0.
            return computeRealRootsInner(a, 0, c, d);
        } else if (c == 0.0)
        {
            if (d == 0.0)
            {
                // x^2 * (a * x + b) = 0.
                ratio = -b / a;
                if (ratio < 0.0)
                {
                    root1 = ratio;
                    root2 = 0.0;
                    root3 = 0.0;

                    roots.clear();
                    roots.push_back(root1);
                    roots.push_back(root2);
                    roots.push_back(root3);
                    return roots;
                }
                root1 = 0.0;
                root2 = 0.0;
                root3 = ratio;

                roots.clear();
                roots.push_back(root1);
                roots.push_back(root2);
                roots.push_back(root3);
                return roots;
            }
            // a * x^3 + b * x^2 + d = 0.
            return computeRealRootsInner(a, b, 0, d);
        }
        else if (d == 0.0)
        {
            // x * (a * x^2 + b * x + c) = 0
            QuadraticRealPolynomial::computeRealRoots(a, b, c, roots);

            // Return the roots in ascending order.
            if (roots.size() == 0)
            {
                roots.clear();
                roots.push_back(0.0);
                return roots;
            } else if (roots[1] <= 0.0)
            {
                root1 = roots[0];
                root2 = roots[1];
                root3 = 0.0;

                roots.clear();
                roots.push_back(root1);
                roots.push_back(root2);
                roots.push_back(root3);
                return roots;
            }
            else if (roots[0] >= 0.0)
            {
                root1 = 0.0;
                root2 = roots[0];
                root3 = roots[1];

                roots.clear();
                roots.push_back(root1);
                roots.push_back(root2);
                roots.push_back(root3);
                return roots;
            }
            root1 = roots[0];
            root2 =  0.0;
            root3 = roots[1];

            roots.clear();
            roots.push_back(root1);
            roots.push_back(root2);
            roots.push_back(root3);
            return roots;
        }

        return computeRealRootsInner(a, b, c, d);
    }

    /// General-case real-root solver for a * x^3 + b * x^2 + c * x + d = 0.
    ///
    /// Works on the coefficients A = a, B = b / 3, C = c / 3, D = d, i.e. the cubic
    /// A x^3 + 3B x^2 + 3C x + D. The local `discriminant` below equals
    /// computeDiscriminant(a, b, c, d) / 27 in exact arithmetic, so it has the same
    /// sign:
    ///   - negative: exactly one real root is returned;
    ///   - zero or positive: three real roots (repeats possible) are returned in
    ///     ascending order.
    ///
    /// Preconditions: a must be non-zero (the code divides by A). In this file the
    /// only callers are in computeRealRoots, after the a == 0 and d == 0 cases have
    /// been handled, so d is non-zero there as well.
    ///
    /// NOTE(review): the structure (delta1..delta3, separate root estimates from the
    /// A side and the D side, the "Large"/"Small" naming) appears to follow Jim Blinn,
    /// "How to Solve a Cubic Equation, Part 5: Back to Numerics" (IEEE CG&A, 2007).
    /// Confirm against the paper before changing any formula; the expression order
    /// is deliberate to limit floating-point cancellation.
    std::vector<Number> CubicRealPolynomial::computeRealRootsInner(Number a, Number b, Number c, Number d)
    {
        std::vector<Number> roots;

        // Scaled coefficients of A x^3 + 3B x^2 + 3C x + D.
        Number A = a;
        Number B = b / 3.0;
        Number C = c / 3.0;
        Number D = d;

        Number AC = A * C;
        Number BD = B * D;
        Number B2  = B * B;
        Number C2  = C * C;
        // The three 2x2 determinants det[[A,B],[B,C]], det[[A,B],[C,D]] and
        // det[[B,C],[C,D]].
        Number delta1 = A * C - B2;
        Number delta2 = A * D - B * C;
        Number delta3 = B * D - C2;

        // Classical cubic discriminant divided by 27 (same sign).
        Number discriminant = 4.0 * delta1 * delta3 - delta2 * delta2;
        Number temp;
        Number temp1;

        if (discriminant < 0.0) {
            // One real root. Solve a depressed cubic taken either from the A side
            // (root recovered as (t - B) / A) or from the D side (root recovered as
            // -D / (t + C)). The comparison below picks the side, and the identical
            // comparison further down picks the matching back-substitution.
            Number ABar;
            Number CBar;
            Number DBar;

            if (B2 * BD >= AC * C2) {
                ABar = A;
                CBar = delta1;
                DBar = -2.0 * B * delta1 + A * delta2;
            } else {
                ABar = D;
                CBar = delta3;
                DBar = -D * delta2 + 2.0 * C * delta3;
            }

            Number s = (DBar < 0.0) ? -1.0 : 1.0; // +1 when DBar == 0: deliberately not a signum function.
            // temp0 has the opposite sign of s, i.e. the same sign as -DBar, so
            // -DBar + temp0 adds two same-signed terms and cannot cancel (when
            // DBar == 0 the sum is just temp0).
            Number temp0 = -s * std::abs(ABar) * std::sqrt(-discriminant);
            temp1 = -DBar + temp0;

            // Real cube root p of temp1 / 2. std::pow yields NaN for a negative base
            // with a fractional exponent, hence the explicit sign handling.
            Number x = temp1 / 2.0;
            Number p = x < 0.0 ? -std::pow(-x, 1.0 / 3.0) : std::pow(x, 1.0 / 3.0);
            // Second Cardano term: q = -CBar / p (so p * q = -CBar). When
            // temp1 == temp0 (the -DBar term contributed nothing, e.g. DBar == 0),
            // q is taken as -p instead.
            Number q = (temp1 == temp0) ? -p : -CBar / p;

            // temp is the real root t of the depressed cubic (presumably
            // t^3 + 3 * CBar * t + DBar = 0, with p^3 + q^3 = -DBar). When CBar > 0,
            // p and q have opposite signs and p + q may cancel, so the algebraically
            // equivalent -DBar / (p^2 + q^2 + CBar) is used instead.
            temp = (CBar <= 0.0) ? (p + q) : (-DBar / (p * p + q * q + CBar));

            // Map t back to a root of the original cubic using the side chosen above.
            if (B2 * BD >= AC * C2)
            {
                roots.push_back((temp - B) / A);
                return roots;
            }

            roots.push_back(-D / (temp + C));
            return roots;
        }

        // Three real roots (discriminant >= 0): the trigonometric method is applied
        // once on the A side (giving root1) and once on the D side (giving root3).
        Number CBarA = delta1;
        Number DBarA = -2.0 * B * delta1 + A * delta2;

        Number CBarD = delta3;
        Number DBarD = -D * delta2 + 2.0 * C * delta3;

        Number squareRootOfDiscriminant = std::sqrt(discriminant);
        Number halfSquareRootOf3 = std::sqrt(3.0) / 2.0;

        // temp1 = temp * cos(theta) and temp3 = temp * cos(theta + 2*pi/3), written
        // out via cos(theta + 2*pi/3) = -cos(theta)/2 - (sqrt(3)/2) * sin(theta).
        // NOTE(review): std::sqrt(-CBarA) is NaN if rounding leaves CBarA slightly
        // positive (e.g. near a triple root); nothing here guards against that.
        Number theta = std::abs(std::atan2(A * squareRootOfDiscriminant, -DBarA) / 3.0);
        temp = 2.0 * std::sqrt(-CBarA);
        Number cosine = std::cos(theta);
        temp1 = temp * cosine;
        Number temp3 = temp * (-cosine / 2.0 - halfSquareRootOf3 * std::sin(theta));

        // Pick one candidate by comparing temp1 + temp3 with 2B; root1 = numeratorLarge / A.
        Number numeratorLarge = (temp1 + temp3 > 2.0 * B) ? temp1 - B : temp3 - B;
        Number denominatorLarge = A;

        Number root1 = numeratorLarge / denominatorLarge;

        // Same construction on the D side; that root is recovered as -D / (t + C).
        theta = std::abs(std::atan2(D * squareRootOfDiscriminant, -DBarD) / 3.0);
        temp = 2.0 * std::sqrt(-CBarD);
        cosine = std::cos(theta);
        temp1 = temp * cosine;
        temp3 = temp * (-cosine / 2.0 - halfSquareRootOf3 * std::sin(theta));

        Number numeratorSmall = -D;
        Number denominatorSmall = (temp1 + temp3 < 2.0 * C) ? temp1 + C : temp3 + C;

        Number root3 = numeratorSmall / denominatorSmall;

        // E x^2 + F x + G = (denominatorLarge x - numeratorLarge)
        //                 * (denominatorSmall x - numeratorSmall),
        // the quadratic whose roots are root1 and root3. The remaining root root2
        // is derived from E, F, G and the cubic's B and C.
        Number E = denominatorLarge * denominatorSmall;
        Number F = -numeratorLarge * denominatorSmall - denominatorLarge * numeratorSmall;
        Number G = numeratorLarge * numeratorSmall;

        Number root2 = (C * F - B * G) / (-B * F + C * E);

        // Emit the three roots in ascending order via a fixed comparison tree.
        if (root1 <= root2) {
            if (root1 <= root3) {
                if (root2 <= root3)
                {
                    roots.push_back(root1);
                    roots.push_back(root2);
                    roots.push_back(root3);
                    return roots;
                }
                roots.push_back(root1);
                roots.push_back(root3);
                roots.push_back(root2);
                return roots;
            }
            roots.push_back(root3);
            roots.push_back(root1);
            roots.push_back(root2);
            return roots;
        }
        if (root1 <= root3)
        {
            roots.push_back(root2);
            roots.push_back(root1);
            roots.push_back(root3);
            return roots;
        }
        if (root2 <= root3)
        {
            roots.push_back(root2);
            roots.push_back(root3);
            roots.push_back(root1);
            return roots;
        }
        roots.push_back(root3);
        roots.push_back(root2);
        roots.push_back(root1);
        return roots;
    }



一元三次方程求解c++实现