Skip to content

Commit

Permalink
Update train_mnist_model.py
Browse files Browse the repository at this point in the history
Changed fetch_mldata() to fetch_openml().
  • Loading branch information
Rajtilak Bhattacharjee committed May 29, 2020
1 parent 249a90f commit 057cbad
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions projects/deploy_mnist/train_mnist_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from sklearn.datasets import fetch_mldata
#from sklearn.datasets import fetch_mldata
from sklearn.datasets import fetch_openml
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score
from sklearn.externals import joblib

np.random.seed(42)
mnist = fetch_mldata("MNIST original")
#mnist = fetch_mldata("MNIST original")
mnist = fetch_openml('mnist_784', version=1, cache=True)
mnist.target = mnist.target.astype(np.int8)
X, y = mnist["data"], mnist["target"]

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
Expand Down

0 comments on commit 057cbad

Please sign in to comment.