Skip to content

Commit

Permalink
fixes to sklearn-based NBSVM
Browse files: browse the repository at this point in the history
  • Loading branch information
amaiya committed Dec 1, 2023
1 parent 29e1ceb commit fa5054b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ktrain/text/shallownlp/classifier.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -289,7 +289,7 @@ def _fit_binary(self, X, y):
X_scaled = X * r

lsvc = LinearSVC(
- C=self.C, fit_intercept=self.fit_intercept, max_iter=10000
+ C=self.C, fit_intercept=self.fit_intercept, max_iter=10000, dual=True
).fit(X_scaled, y)

mean_mag = np.abs(lsvc.coef_).mean()
Expand Down
7 changes: 4 additions & 3 deletions tests/test_zzz_shallownlp.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,7 +15,6 @@


class TestShallowNLP(TestCase):

# @skip('temporarily disabled')
def test_classifier(self):
categories = [
Expand All @@ -42,7 +41,8 @@ def test_classifier(self):
classes = train_b.target_names

clf = snlp.Classifier()
- clf.fit(x_train, y_train, ctype="nbsvm")
+ clf.create_model("nbsvm", x_train, vec__ngram_range=(1, 3), vec__binary=True)
+ clf.fit(x_train, y_train)
self.assertGreaterEqual(clf.evaluate(x_test, y_test), 0.93)
test_doc = "god christ jesus mother mary church sunday lord heaven amen"
self.assertEqual(clf.predict(test_doc), 3)
Expand All @@ -55,7 +55,8 @@ def test_classifier_chinese(self):
)
print("label names: %s" % (label_names))
clf = snlp.Classifier()
- clf.fit(x_train, y_train, ctype="nbsvm")
+ clf.create_model("nbsvm", x_train, vec__ngram_range=(1, 3), vec__binary=True)
+ clf.fit(x_train, y_train)
self.assertGreaterEqual(clf.evaluate(x_train, y_train), 0.98)
neg_text = "我讨厌和鄙视这家酒店。"
pos_text = "我喜欢这家酒店。"
Expand Down

0 comments on commit fa5054b

Please sign in to comment.