
Commit

Update SVM files
suleyman-kaya committed Jul 29, 2024
1 parent aedcbf8 commit 66d1c58
Showing 2 changed files with 111 additions and 128 deletions.
121 changes: 41 additions & 80 deletions ml/svm.v
@@ -26,39 +26,61 @@ pub mut:
config SVMConfig
}

pub struct SVM {
pub mut:
model &SVMModel = unsafe { nil }
kernel KernelFunction @[required]
config SVMConfig
}

type KernelFunction = fn ([]f64, []f64) f64

fn vector_dot(x []f64, y []f64) f64 {
mut sum := 0.0
for i := 0; i < x.len; i++ {
sum += x[i] * y[i]
}
return sum
}

fn vector_subtract(x []f64, y []f64) []f64 {
mut result := []f64{len: x.len}
for i := 0; i < x.len; i++ {
result[i] = x[i] - y[i]
}
return result
}

pub fn linear_kernel(x []f64, y []f64) f64 {
return dot_product(x, y)
return vector_dot(x, y)
}

pub fn polynomial_kernel(degree int) KernelFunction {
return fn [degree] (x []f64, y []f64) f64 {
return math.pow(dot_product(x, y) + 1.0, f64(degree))
return math.pow(vector_dot(x, y) + 1.0, f64(degree))
}
}

pub fn rbf_kernel(gamma f64) KernelFunction {
return fn [gamma] (x []f64, y []f64) f64 {
diff := vector_subtract(x, y)
return math.exp(-gamma * dot_product(diff, diff))
return math.exp(-gamma * vector_dot(diff, diff))
}
}
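
For reference, the three kernels implement the usual definitions: linear k(x, y) = x·y, polynomial k(x, y) = (x·y + 1)^degree, and RBF k(x, y) = exp(-gamma * ||x - y||^2). A minimal usage sketch, not part of this diff — the vsl.ml import path, the example function name, and the commented values are assumptions based on the tests further down:

import vsl.ml

fn example_kernels() {
	x := [1.0, 2.0, 3.0]
	y := [4.0, 5.0, 6.0]

	lin := ml.linear_kernel(x, y) // dot product -> 32.0

	poly := ml.polynomial_kernel(3) // (x·y + 1)^3 -> 33^3 = 35937.0
	p := poly(x, y)

	rbf := ml.rbf_kernel(0.5) // exp(-0.5 * ||x - y||^2) -> exp(-13.5)
	r := rbf(x, y)

	println('${lin} ${p} ${r}')
}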

fn dot_product(a []f64, b []f64) f64 {
mut sum := 0.0
for i in 0 .. a.len {
sum += a[i] * b[i]
pub fn SVM.new(kernel KernelFunction, config SVMConfig) &SVM {
return &SVM{
kernel: kernel
config: config
}
return sum
}

fn vector_subtract(a []f64, b []f64) []f64 {
mut result := []f64{len: a.len}
for i in 0 .. a.len {
result[i] = a[i] - b[i]
}
return result
pub fn (mut s SVM) train(data []DataPoint) {
s.model = train_svm(data, s.kernel, s.config)
}

pub fn (s &SVM) predict(x []f64) int {
return predict(s.model, x)
}
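
The new SVM struct is a thin wrapper that stores a kernel and config, delegating train to train_svm and predict to the module-level predict. A minimal sketch of the intended call flow, mirroring test_svm_train_and_predict below (module and import boilerplate omitted and assumed):

mut svm := SVM.new(linear_kernel, SVMConfig{})
data := [
	DataPoint{[2.0, 3.0], 1},
	DataPoint{[1.0, 1.0], -1},
	DataPoint{[3.0, 4.0], 1},
	DataPoint{[0.0, 0.0], -1},
]
svm.train(data) // fits an SVMModel and stores it in svm.model
assert svm.predict([2.0, 3.0]) == 1
assert svm.predict([0.0, 0.0]) == -1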

pub fn train_svm(data []DataPoint, kernel KernelFunction, config SVMConfig) &SVMModel {
@@ -71,7 +93,7 @@ pub fn train_svm(data []DataPoint, kernel KernelFunction, config SVMConfig) &SVM
}

mut passes := 0
for passes < model.config.max_iterations {
for {
mut num_changed_alphas := 0
for i in 0 .. data.len {
ei := predict_raw(model, data[i].x) - f64(data[i].y)
@@ -138,6 +160,10 @@ pub fn train_svm(data []DataPoint, kernel KernelFunction, config SVMConfig) &SVM
} else {
passes = 0
}

if passes >= model.config.max_iterations {
break
}
}

for i in 0 .. data.len {
@@ -160,68 +186,3 @@ fn predict_raw(model &SVMModel, x []f64) f64 {
pub fn predict(model &SVMModel, x []f64) int {
return if predict_raw(model, x) >= 0 { 1 } else { -1 }
}
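
For context, predict_raw (whose body is collapsed above) is expected to evaluate the standard kernel SVM decision function over the trained alphas and bias, with predict taking its sign — the textbook formulation, not spelled out in this hunk:

f(x) = sum_i alpha_i * y_i * K(x_i, x) + b
predict(x) = 1 if f(x) >= 0, else -1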

pub struct MulticlassSVM {
pub mut:
models [][]&SVMModel
classes []int
}

pub fn train_multiclass_svm(data []DataPoint, kernel KernelFunction, config SVMConfig) &MulticlassSVM {
mut classes := []int{}
for point in data {
if point.y !in classes {
classes << point.y
}
}
classes.sort()

mut models := [][]&SVMModel{len: classes.len, init: []&SVMModel{}}

for i in 0 .. classes.len {
models[i] = []&SVMModel{len: classes.len, init: unsafe { nil }} // initialized with unsafe { nil }
for j in i + 1 .. classes.len {
mut binary_data := []DataPoint{}
for point in data {
if point.y == classes[i] || point.y == classes[j] {
binary_data << DataPoint{
x: point.x
y: if point.y == classes[i] { 1 } else { -1 }
}
}
}
models[i][j] = train_svm(binary_data, kernel, config)
}
}

return &MulticlassSVM{
models: models
classes: classes
}
}

pub fn predict_multiclass(model &MulticlassSVM, x []f64) int {
mut class_votes := map[int]int{}

for i in 0 .. model.classes.len {
for j in i + 1 .. model.classes.len {
prediction := predict(model.models[i][j], x)
if prediction == 1 {
class_votes[model.classes[i]]++
} else {
class_votes[model.classes[j]]++
}
}
}

mut max_votes := 0
mut predicted_class := 0
for class, votes in class_votes {
if votes > max_votes {
max_votes = votes
predicted_class = class
}
}

return predicted_class
}
118 changes: 70 additions & 48 deletions ml/svm_test.v
@@ -2,13 +2,34 @@ module ml

import math

fn test_vector_dot() {
x := [1.0, 2.0, 3.0]
y := [4.0, 5.0, 6.0]
result := vector_dot(x, y)
assert math.abs(result - 32.0) < 1e-6
}

fn test_vector_subtract() {
x := [1.0, 2.0, 3.0]
y := [4.0, 5.0, 6.0]
result := vector_subtract(x, y)
assert result == [-3.0, -3.0, -3.0]
}

fn test_linear_kernel() {
x := [1.0, 2.0, 3.0]
y := [4.0, 5.0, 6.0]
result := linear_kernel(x, y)
assert math.abs(result - 32.0) < 1e-6
}

fn test_polynomial_kernel() {
x := [1.0, 2.0, 3.0]
y := [4.0, 5.0, 6.0]
kernel := polynomial_kernel(3)
result := kernel(x, y)
expected := math.pow((1 * 4 + 2 * 5 + 3 * 6 + 1), 3) // (32 + 1)^3
assert result == expected
expected := math.pow(32.0 + 1.0, 3)
assert math.abs(result - expected) < 1e-6
}

fn test_rbf_kernel() {
@@ -17,89 +38,90 @@ fn test_rbf_kernel() {
gamma := 0.5
kernel := rbf_kernel(gamma)
result := kernel(x, y)
expected := math.exp(-gamma * ((1 - 4) * (1 - 4) + (2 - 5) * (2 - 5) + (3 - 6) * (3 - 6))) // exp(-0.5 * 27)
expected := math.exp(-gamma * 27.0)
assert math.abs(result - expected) < 1e-6
}

fn test_train_svm() {
kernel := linear_kernel
fn test_svm_new() {
config := SVMConfig{}
svm := SVM.new(linear_kernel, config)
assert svm.kernel == linear_kernel
assert svm.config == config
}

fn test_svm_train_and_predict() {
mut svm := SVM.new(linear_kernel, SVMConfig{})
data := [
DataPoint{[2.0, 3.0], 1},
DataPoint{[1.0, 1.0], -1},
DataPoint{[3.0, 4.0], 1},
DataPoint{[0.0, 0.0], -1},
]
config := SVMConfig{}
model := train_svm(data, kernel, config)
svm.train(data)

for point in data {
assert predict(model, point.x) == point.y
prediction := svm.predict(point.x)
assert prediction == point.y
}
}

fn test_predict_svm() {
kernel := linear_kernel
fn test_train_svm() {
data := [
DataPoint{[2.0, 3.0], 1},
DataPoint{[1.0, 1.0], -1},
DataPoint{[3.0, 4.0], 1},
DataPoint{[0.0, 0.0], -1},
]
config := SVMConfig{}
model := train_svm(data, kernel, config)
model := train_svm(data, linear_kernel, config)

assert predict(model, [2.0, 3.0]) == 1
assert predict(model, [1.0, 1.0]) == -1
assert predict(model, [3.0, 4.0]) == 1
assert predict(model, [0.0, 0.0]) == -1
for point in data {
prediction := predict(model, point.x)
assert prediction == point.y
}
}

fn test_train_multiclass_svm() {
kernel := linear_kernel
fn test_predict_raw() {
data := [
DataPoint{[2.0, 3.0], 1},
DataPoint{[1.0, 1.0], 2},
DataPoint{[3.0, 4.0], 1},
DataPoint{[0.0, 0.0], 2},
DataPoint{[3.0, 3.0], 3},
DataPoint{[1.0, 1.0], -1},
]
config := SVMConfig{}
model := train_multiclass_svm(data, kernel, config)
model := train_svm(data, linear_kernel, config)

for point in data {
assert predict_multiclass(model, point.x) == point.y
}
result := predict_raw(model, [2.0, 3.0])
assert result > 0

result2 := predict_raw(model, [1.0, 1.0])
assert result2 < 0
}

fn test_predict_multiclass_svm() {
kernel := linear_kernel
fn test_predict() {
data := [
DataPoint{[2.0, 3.0], 1},
DataPoint{[1.0, 1.0], 2},
DataPoint{[1.0, 1.0], -1},
DataPoint{[3.0, 4.0], 1},
DataPoint{[0.0, 0.0], 2},
DataPoint{[3.0, 3.0], 3},
DataPoint{[0.0, 0.0], -1},
]
config := SVMConfig{}
model := train_multiclass_svm(data, kernel, config)
model := train_svm(data, linear_kernel, config)

assert predict_multiclass(model, [2.0, 3.0]) == 1
assert predict_multiclass(model, [1.0, 1.0]) == 2
assert predict_multiclass(model, [3.0, 4.0]) == 1
assert predict_multiclass(model, [0.0, 0.0]) == 2
assert predict_multiclass(model, [3.0, 3.0]) == 3
for point in data {
prediction := predict(model, point.x)
assert prediction == point.y
}
}

fn test_kernels() {
kernels := [
linear_kernel,
polynomial_kernel(3),
rbf_kernel(0.5),
]
for kernel in kernels {
test_train_svm()
test_predict_svm()
test_train_multiclass_svm()
test_predict_multiclass_svm()
}
fn main() {
test_vector_dot()
test_vector_subtract()
test_linear_kernel()
test_polynomial_kernel()
test_rbf_kernel()
test_svm_new()
test_svm_train_and_predict()
test_train_svm()
test_predict_raw()
test_predict()
println('All tests passed successfully!')
}
