Try new rule algorithm #16

Open · wants to merge 1 commit into main
2 changes: 2 additions & 0 deletions .gitignore
@@ -11,3 +11,5 @@ project/plugins/project/
.cache
.bsp/
.idea
.metals
.bloop
15 changes: 15 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,15 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
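// Attach-mode debug config: connects to a JVM that has opened a JDWP
// debug port on localhost:5050.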
"type": "java",
"name": "Debug (Attach to spark)",
"request": "attach",
"hostName": "localhost",
"port": 5050
}
]
}
5 changes: 5 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,5 @@
{
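// Keep VS Code's file watcher out of sbt build output directories.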
"files.watcherExclude": {
"**/target": true
}
}
19 changes: 12 additions & 7 deletions build.sbt
@@ -1,16 +1,21 @@
version := "1.0"
scalaVersion := "2.12.15"
scalaVersion := "2.12.20"
organization := "com.recsys"
name := "recsys-spark"

unmanagedBase := baseDirectory.value / "lib"

val sparkVersion = "3.2.1"
val sparkVersion = "3.5.4"

libraryDependencies ++= Seq(
"org.scala-lang" % "scala-library" % scalaVersion.value,
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
"org.apache.spark" %% "spark-mllib" % sparkVersion % "provided",
"org.apache.spark" %% "spark-core" % sparkVersion % "provided",
"com.github.nscala-time" %% "nscala-time" % "2.32.0"
"org.apache.spark" %% "spark-sql" % sparkVersion,
"org.apache.spark" %% "spark-mllib" % sparkVersion,
"org.apache.spark" %% "spark-core" % sparkVersion,
"com.github.nscala-time" %% "nscala-time" % "2.32.0",
"com.github.fommil.netlib" % "all" % "1.1.2" pomOnly()
)

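// Assembly merge strategy: drop META-INF metadata and keep the first copy
// of any other conflicting file, so `sbt assembly` can produce a single
// runnable fat jar (Spark is bundled now that "provided" was removed above).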
assembly / assemblyMergeStrategy := {
case PathList("META-INF", xs @ _*) => MergeStrategy.discard
case x => MergeStrategy.first
}
Binary file modified lib/sparkml-som_2.12-0.2.1.jar
Binary file not shown.
2 changes: 1 addition & 1 deletion project/build.properties
@@ -1 +1 @@
sbt.version=1.6.2
sbt.version=1.10.7
8 changes: 8 additions & 0 deletions project/metals.sbt
@@ -0,0 +1,8 @@
// format: off
// DO NOT EDIT! This file is auto-generated.

// This file enables sbt-bloop to create bloop config files.

addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "2.0.6")

// format: on
3 changes: 2 additions & 1 deletion project/plugins.sbt
@@ -1 +1,2 @@
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "1.2.0")
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.3.0")
addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "1.5.0")
8 changes: 8 additions & 0 deletions project/project/metals.sbt
@@ -0,0 +1,8 @@
// format: off
// DO NOT EDIT! This file is auto-generated.

// This file enables sbt-bloop to create bloop config files.

addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "2.0.6")

// format: on
8 changes: 8 additions & 0 deletions project/project/project/metals.sbt
@@ -0,0 +1,8 @@
// format: off
// DO NOT EDIT! This file is auto-generated.

// This file enables sbt-bloop to create bloop config files.

addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "2.0.6")

// format: on
103 changes: 99 additions & 4 deletions src/main/scala/Main.scala
@@ -261,6 +261,85 @@ object Main {
})
}

def sequentialTopKCrossValidation(): Seq[(Double, Double, Double)] = {
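// 5-fold cross-validation for the sequential top-K recommender.
// The arguments below are read off the setter names and presumably mean:
// top-5 recommendations over 1682 items, a 3x3 grid, 5 time periods, and
// (minSupport, minConfidence) pairs for the Apriori and sequential stages.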
val spark = SparkSession.getActiveSession.orNull

val recSys = new SequentialTopKRecommender(5, 1682).setGridSize(
3, 3
).setPeriods(5).setMinParamsApriori(
0.005, 0.80
).setMinParamsSequential(
0.001, 0.60
)

val predictions_accumulator1 = new ListBufferAccumulator[(Double, Double, Double)]
spark.sparkContext.register(predictions_accumulator1, "predictions1")
val predictions_accumulator2 = new ListBufferAccumulator[(Double, Double, Double)]
spark.sparkContext.register(predictions_accumulator2, "predictions2")
val predictions_accumulator3 = new ListBufferAccumulator[(Double, Double, Double)]
spark.sparkContext.register(predictions_accumulator3, "predictions3")
val predictions_accumulator4 = new ListBufferAccumulator[(Double, Double, Double)]
spark.sparkContext.register(predictions_accumulator4, "predictions4")
val predictions_accumulator5 = new ListBufferAccumulator[(Double, Double, Double)]
spark.sparkContext.register(predictions_accumulator5, "predictions5")
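// One accumulator per fold; each collects a per-user triple of ranking
// metrics produced by RankingMetrics.getRankingMetrics.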

Seq(1, 2, 3, 4, 5).map(index => {
println("Fold " + index)
val train = dataset("data/train-fold" + index + ".csv")
val test = dataset("data/test-fold" + index + ".csv")

val accumulator = index match {
case 1 => predictions_accumulator1
case 2 => predictions_accumulator2
case 3 => predictions_accumulator3
case 4 => predictions_accumulator4
case 5 => predictions_accumulator5
}

println("train")
recSys.fit(train)
println("termina train")

println("datos test empiezan")
val testData = test.groupBy("user_id").agg(
collect_list(col("item_id")).as("items"),
collect_list(col("rating")).as("ratings")
).collect()
println("datos test terminan")

testData.foreach(row => {
val userId = row.getInt(0)
val items = row.getList(1).toArray()
val ratings = row.getList(2).toArray()

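// An item is treated as relevant when the user rated it 4.0 or higher.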
val relevant = items.zip(ratings).filter(
_._2.asInstanceOf[Double] >= 4.0
).map(_._1.asInstanceOf[Int]).toSet

val selected = recSys.transform(
train.filter(col("user_id") === userId)
)

accumulator.add(
new RankingMetrics(5, selected.map(_._1).toSet, relevant).getRankingMetrics // k = 5
)
})

// Average the per-user metric triples into one summary triple for this fold.
val metricPerUser = accumulator.value
val sumMetrics = metricPerUser.reduce((a, b) => {
(a._1 + b._1, a._2 + b._2, a._3 + b._3)
})

val finalMetrics = (
sumMetrics._1 / metricPerUser.length,
sumMetrics._2 / metricPerUser.length,
sumMetrics._3 / metricPerUser.length
)

finalMetrics
})
}

def hybridCrossValidation(firstRecSys: BaseRecommender, secondRecSys: BaseRecommender, weightFirstRecSys: Double, weightSecondRecSys: Double, topK: Int, numberOfItems: Int): Seq[(Double, Double, Double)] = {
val spark = SparkSession.getActiveSession.orNull

@@ -337,7 +416,20 @@
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().master(
"local[*]"
).config(
)
.config(
"spark.driver.memory", "6gb"
)
.config(
"spark.executor.memory", "6gb"
)
.config(
"spark.driver.extraJavaOptions", "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED"
)
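// The jdwp agent on the executors dials out (server=n) to a debugger
// already listening on localhost:5005, without pausing startup (suspend=n).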
.config(
"spark.executor.extraJavaOptions", "--add-opens=java.base/sun.nio.ch=ALL-UNNAMED -agentlib:jdwp=transport=dt_socket,server=n,address=localhost:5005,suspend=n"
)
.config(
"spark.sql.autoBroadcastJoinThreshold", "-1"
).config(
"spark.jars", "lib/sparkml-som_2.12-0.2.1.jar"
@@ -347,9 +439,9 @@
spark.sparkContext.setLogLevel("ERROR")

println("Try here your recommender")
// User based approach
// val results = userBasedTopKCrossValidation(new PearsonSimilarity, 25, 5, 1682)
// println(results)

// Item based approach
// val results = itemBasedTopKCrossValidation(new CosineSimilarity, 25, 5, 1682)
Expand Down Expand Up @@ -379,6 +471,9 @@ object Main {
// val results = hybridCrossValidation(recSys1, recSys2, 0.6, 0.4, 5, 1682)
// println(results)

val results = sequentialTopKCrossValidation()
println(results)

spark.stop()
}
}