首页 > 代码库 > Spark2 生存分析Survival regression

Spark2 生存分析Survival regression

  在spark.ml中,实现了加速失效时间(AFT)模型,这是一个用于检查数据的参数生存回归模型。 它描述了生存时间对数的模型,因此它通常被称为生存分析的对数线性模型。 不同于为相同目的设计的比例风险模型,AFT模型更容易并行化,因为每个实例独立地贡献于目标函数。

  当在具有常量非零列的数据集上匹配AFTSurvivalRegressionModel而没有截距时,Spark MLlib为常量非零列输出零系数。 这种行为不同于R survival :: survreg。

 

导入包

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.sql.functions._

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.AFTSurvivalRegression
import org.apache.spark.ml.feature.VectorAssembler

 

 

建模

    val spark = SparkSession.builder().appName("Spark Survival regression").config("spark.some.config.option", "some-value").getOrCreate()

    // For implicit conversions like converting RDDs to DataFrames
    import spark.implicits._

    val dataList: List[(Double, Double, Double, Double)] = List(
      (2, 51, 1, 1),
      (2, 58, 1, 1),
      (2, 55, 2, 1),
      (2, 28, 22, 1),
      (1, 21, 30, 0),
      (1, 19, 28, 1),
      (2, 25, 32, 1),
      (2, 48, 11, 1),
      (2, 47, 14, 1),
      (2, 25, 36, 0),
      (2, 31, 31, 0),
      (1, 24, 33, 0),
      (1, 25, 33, 0),
      (2, 30, 37, 0),
      (2, 33, 35, 0),
      (1, 36, 25, 1),
      (1, 30, 31, 0),
      (1, 41, 22, 1),
      (2, 43, 26, 1),
      (2, 45, 24, 1),
      (2, 35, 35, 0),
      (1, 29, 34, 0),
      (1, 35, 30, 0),
      (1, 32, 35, 1),
      (2, 36, 40, 1),
      (1, 32, 39, 0))

    val data = http://www.mamicode.com/dataList.toDF("sex", "age", "label", "censor").orderBy("label")

    val colArray = Array("sex", "age")

    val assembler = new VectorAssembler().setInputCols(colArray).setOutputCol("features")

    val vecDF: DataFrame = assembler.transform(data)

    val aft = new AFTSurvivalRegression()

    val model = aft.fit(vecDF)

    // Print the coefficients, intercept and scale parameter for AFT survival regression
    println(s"Coefficients: ${model.coefficients} Intercept: " +
      s"${model.intercept} Scale: ${model.scale}")

    val Array(coeff1, coeff2) = model.coefficients.toArray

    val intercept: Double = model.intercept

    val scale: Double = model.scale

    val aftDF = model.transform(vecDF)
    
    // 风险率h(t)
    aftDF.selectExpr("sex", "age", "label", "censor",
      "features", "round(prediction,2) as prediction",
      s"round( exp( sex*$coeff1+age*$coeff2+$intercept ), 2) as h(t)").orderBy("label").show(100, false)

 

 

 

Spark2 生存分析Survival regression