上一篇文章介绍的线性回归拟合的是一条直线，这次来实现多项式回归。

/// <summary>
/// Least-squares polynomial regression. Fits
/// Y = a[0] + a[1] * X + a[2] * X^2 + ... + a[n] * X^n
/// to the samples supplied through <see cref="SetOriginData"/>.
/// </summary>
public class RegMultiLinear
{
    /// <summary>
    /// Polynomial coefficients, lowest power first: Param[i] multiplies X^i.
    /// Its length (degree + 1) is fixed by <see cref="SetOriginData"/>.
    /// </summary>
    public double[] Param = null;

    /// <summary>
    /// Coefficient of determination computed by the most recent successful
    /// <see cref="CalculateRegression"/> call. Stays at 0 until a fit succeeds.
    /// </summary>
    public double R2
    {
        get;
        private set;
    }

    private double[] OriginX = null;
    private double[] OriginY = null;

    /// <summary>
    /// Creates an empty regression object. The two-element coefficient array is
    /// only a placeholder; <see cref="SetOriginData"/> allocates the real one.
    /// </summary>
    public RegMultiLinear()
    {
        // Y = a[0] + a[1] * X + a[2] * X^2 + ......
        this.Param = new double[2];
    }

    /// <summary>
    /// Stores the sample arrays (by reference, not copied) and sizes the
    /// coefficient array for a polynomial of degree <paramref name="ex"/>.
    /// </summary>
    /// <param name="x">Sample X values. Invalid values must be removed by the caller; no validation is done here.</param>
    /// <param name="y">Sample Y values, expected to have the same length as <paramref name="x"/>.</param>
    /// <param name="ex">Polynomial degree; values below 2 are raised to 2.</param>
    public void SetOriginData(double[] x, double[] y, int ex)
    {
        this.OriginX = x;
        this.OriginY = y;

        // A polynomial regression is at least of degree 2.
        if (ex < 2)
        {
            ex = 2;
        }
        this.Param = new double[ex + 1];
    }

    /// <summary>
    /// Solves the normal equations of the least-squares fit and stores the
    /// coefficients in <see cref="Param"/> and the fit quality in <see cref="R2"/>.
    /// </summary>
    /// <returns>
    /// <c>true</c> when a finite solution was found. <c>false</c> when the data
    /// is missing, the X/Y lengths differ, there are fewer samples than
    /// coefficients, or the system is singular / produces non-finite values.
    /// On <c>false</c>, <see cref="Param"/> and <see cref="R2"/> are left unchanged.
    /// </returns>
    public bool CalculateRegression()
    {
        if (this.Param == null || this.Param.Length == 0)
        {
            return false;
        }

        // A polynomial with this many coefficients needs at least as many points.
        if (this.OriginX == null || this.OriginX.Length < this.Param.Length ||
                this.OriginY == null || this.OriginX.Length != this.OriginY.Length)
        {
            return false;
        }

        int n = this.Param.Length;
        int samples = this.OriginX.Length;

        // Entry (i, j) of the normal matrix is sum_k x_k^(i + j); it depends only
        // on i + j, so each power sum is computed once instead of once per cell.
        double[] powerSums = new double[2 * n - 1];
        for (int p = 0; p < powerSums.Length; p++)
        {
            for (int k = 0; k < samples; k++)
            {
                powerSums[p] += Math.Pow(this.OriginX[k], p);
            }
        }

        // Build the augmented matrix [A | b] of the normal equations.
        double[,] em = new double[n, n + 1];
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                em[i, j] = powerSums[i + j];
            }

            for (int k = 0; k < samples; k++)
            {
                em[i, n] += this.OriginY[k] * Math.Pow(this.OriginX[k], i);
            }
        }

        // Forward elimination with partial pivoting. The loop also visits the
        // last column so that its pivot is checked before back substitution.
        for (int k = 0; k < n; k++)
        {
            int pivotRow = k;
            double pivotMagnitude = Math.Abs(em[k, k]);
            for (int i = k + 1; i < n; i++)
            {
                double candidate = Math.Abs(em[i, k]);
                if (candidate > pivotMagnitude)
                {
                    pivotMagnitude = candidate;
                    pivotRow = i;
                }
            }

            // A zero (or non-finite) pivot means no unique solution can be computed.
            if (pivotMagnitude == 0.0 || double.IsNaN(pivotMagnitude) || double.IsInfinity(pivotMagnitude))
            {
                return false;
            }

            if (pivotRow != k)
            {
                // Columns left of k are never read again, so swapping from k on is enough.
                for (int j = k; j <= n; j++)
                {
                    double swap = em[k, j];
                    em[k, j] = em[pivotRow, j];
                    em[pivotRow, j] = swap;
                }
            }

            for (int i = k + 1; i < n; i++)
            {
                double t = -em[i, k] / em[k, k];
                for (int j = k + 1; j <= n; j++)
                {
                    em[i, j] += em[k, j] * t;
                }
            }
        }

        // Back substitution into a scratch array so Param is only touched on success.
        double[] solution = new double[n];
        for (int i = n - 1; i >= 0; i--)
        {
            double acc = em[i, n];
            for (int j = i + 1; j < n; j++)
            {
                acc -= em[i, j] * solution[j];
            }

            double value = acc / em[i, i];
            if (double.IsNaN(value) || double.IsInfinity(value))
            {
                return false;
            }
            solution[i] = value;
        }

        // Publish the coefficients and the matching R^2.
        Array.Copy(solution, this.Param, n);
        this.R2 = this.GetR2();
        return true;
    }

    /// <summary>
    /// Evaluates the fitted polynomial at <paramref name="x"/>.
    /// </summary>
    /// <param name="x">Point at which to evaluate.</param>
    /// <returns>
    /// The polynomial value, or <c>ParamTool.InvalidValue</c> when fewer than
    /// three coefficients are present (i.e. no degree >= 2 model was set up).
    /// </returns>
    public double CalculateValue(double x)
    {
        if (this.Param.Length < 3)
        {
            return ParamTool.InvalidValue;
        }

        double sum = 0.0;
        for (int index = 0; index < this.Param.Length; index++)
        {
            sum += this.Param[index] * Math.Pow(x, index);
        }
        return sum;
    }

    /// <summary>
    /// Returns the coefficient array itself (the same instance as <see cref="Param"/>, not a copy).
    /// </summary>
    public double[] GetResult()
    {
        return this.Param;
    }

    /// <summary>
    /// Computes R^2 = 1 - SSE / (SSR + SSE) for the current coefficients against
    /// the stored samples. Must be called after sample data has been supplied.
    /// </summary>
    /// <returns>
    /// The coefficient of determination; 1 when both sums of squares are exactly
    /// zero (constant Y reproduced exactly), which would otherwise be 0 / 0.
    /// </returns>
    public double GetR2()
    {
        double averY = this.OriginY.Average();

        double ssr = 0.0;
        double sse = 0.0;
        for (int i = 0; i < this.OriginX.Length; i++)
        {
            double yi = CalculateValue(this.OriginX[i]);
            // Regression (explained) sum of squares.
            ssr += Math.Pow(yi - averY, 2);
            // Residual sum of squares.
            sse += Math.Pow(yi - this.OriginY[i], 2);
        }

        double total = ssr + sse;
        if (total == 0.0)
        {
            return 1.0;
        }
        return 1 - (sse / total);
    }
}