-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtopicModel.scala
73 lines (49 loc) · 2.22 KB
/
topicModel.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLImplicits
import org.apache.spark.sql.SQLContext
import org.apache.spark.ml.feature.RegexTokenizer
import org.apache.spark.ml.feature.{Tokenizer, CountVectorizer}
import org.apache.spark.ml.feature.StopWordsRemover
import org.apache.spark.sql.functions._
import org.apache.spark.ml.clustering.LDA
import org.apache.spark.sql.SparkSession
object TopicModeling {

  /**
   * Fits an LDA topic model over a text corpus and writes the top terms of
   * each topic to the output path.
   *
   * Pipeline: lower-case each line -> regex-tokenize (tokens of length >= 4)
   * -> drop rows with no tokens -> remove stop words -> count-vectorize
   * (binary, vocab capped at 10000, rare terms pruned) -> LDA with k = 5
   * topics -> map per-topic term indices back to vocabulary strings.
   *
   * @param args args(0) = input path, args(1) = output path
   */
  def main(args: Array[String]): Unit = {
    // Bug fix: the original printed the usage message but then fell through
    // and crashed with ArrayIndexOutOfBoundsException on args(0).
    // Print to stderr and exit with a nonzero status instead.
    if (args.length != 2) {
      System.err.println("Usage: inputPath outputPath")
      sys.exit(1)
    }

    val sc = new SparkContext(new SparkConf().setAppName("Spark TopicModeling"))
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    try {
      // One document per input line, lower-cased. The original also did
      // split("\\n") / mkString("\n") round-trips, but textFile already
      // yields newline-free lines, so those steps were no-ops and are dropped.
      val corpus = sc.textFile(args(0)).map(_.toLowerCase())
      val corpusDF = corpus.zipWithIndex
        .map(_.swap)
        .toDF("id", "corpus")
        // Replace the zip index with a Spark-generated monotonic id,
        // matching the original behavior.
        .withColumn("id", monotonically_increasing_id())

      // Split on runs of non-word characters; tokens shorter than 4
      // characters are filtered away by the tokenizer itself.
      val tokenizer = new RegexTokenizer()
        .setPattern("[\\W_]+")
        .setMinTokenLength(4)
        .setInputCol("corpus")
        .setOutputCol("tokens")
      val tokenized = tokenizer.transform(corpusDF)

      // Drop rows that produced no tokens at all before stop-word removal,
      // so downstream stages never see empty token arrays.
      val nonEmpty = udf((tokens: Seq[String]) => tokens.nonEmpty)
      val tokenizedNonEmpty = tokenized.filter(nonEmpty(col("tokens")))

      val remover = new StopWordsRemover()
        .setInputCol("tokens")
        .setOutputCol("filtered")
      val filtered = remover.transform(tokenizedNonEmpty)

      // Binary term presence (setBinary(true)); vocabulary capped at 10000
      // terms; terms must appear >= 10 times in a doc (MinTF) and in >= 10
      // docs (MinDF) to be kept.
      val cv = new CountVectorizer()
        .setInputCol("filtered")
        .setOutputCol("features")
        .setVocabSize(10000)
        .setMinTF(10)
        .setMinDF(10)
        .setBinary(true)
      val cvModel = cv.fit(filtered)
      val prepped = cvModel.transform(filtered)

      val lda = new LDA().setK(5).setMaxIter(60)
      val model = lda.fit(prepped)

      // Translate each topic's term indices back into vocabulary words.
      val vocab = cvModel.vocabulary
      val indicesToTerms = udf((termIndices: Seq[Int]) => termIndices.map(vocab(_)))
      val topics = model
        .describeTopics(maxTermsPerTopic = 6)
        .withColumn("terms", indicesToTerms(col("termIndices")))

      topics.select("terms").rdd.saveAsTextFile(args(1))
    } finally {
      // Resource fix: the original never stopped Spark, leaking the
      // session/context on every run. stop() on the session also stops
      // the underlying SparkContext.
      spark.stop()
    }
  }
}