-
Notifications
You must be signed in to change notification settings - Fork 1
/
nmf.py
114 lines (97 loc) · 3.83 KB
/
nmf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
from .utils.cp_compat import get_array_module
from .utils.data import MinibatchData, NoneIterator, AsyncMinibatchData
from .utils import assertion, normalize
from .nmf_methods import batch_mu, serizel, kasai, grads
# Method names handled by the batch path of `solve` (minibatch is None).
BATCH_METHODS = ['mu']

# Method names handled by the minibatch path of `solve`, grouped by origin.
MINIBATCH_METHODS = (
    ['asg-mu', 'gsg-mu', 'asag-mu', 'gsag-mu']  # Romain Serizel et al
    + ['svrmu', 'svrmu-acc']  # H. Kasai et al
)

# Tiny positive constant. It is not referenced in the visible part of this
# module; presumably used for numerical safety elsewhere -- TODO confirm.
_JITTER = 1e-15
def solve(y, D, x=None, tol=1.0e-3, minibatch=None, maxiter=1000, method='mu',
          likelihood='l2', mask=None, random_seed=None, **kwargs):
    """
    Non-negative matrix factorization.

    argmin_{x, D} {|y - xD|^2 - alpha |x|}
    s.t. |D_j|^2 <= 1 and D > 0 and x > 0

    with
    y: [n_samples, n_channels]
    x: [n_samples, n_features]
    D: [n_features, n_channels]

    (Objective as stated by the original author; `alpha` is not a named
    argument of this function.)

    Parameters
    ----------
    y: array-like.
        Shape: [n_samples, n_channels]
    D: array-like.
        Initial dictionary, shape [n_features, n_channels].
        It is passed through `normalize.l2_strict(D, axis=-1)` before
        being handed to the solver.
    x: array-like, optional
        An initial estimate of x, shape [n_samples, n_features].
        If None, an array of ones with the dtype of `y` is used.
    tol: a float.
        Criterion to stop iteration
    minibatch: None or an integer
        If None, a batch method is used. Otherwise a minibatch method is
        used and this value is forwarded to the minibatch iterators and to
        the solver (presumably the minibatch size).
    maxiter: an integer
        Number of iteration
    method: string
        'mu' for the batch mode (minibatch is None).
        One of ['asg-mu', 'gsg-mu', 'asag-mu', 'gsag-mu',
                'svrmu', 'svrmu-acc'] for the minibatch mode.
    likelihood: string
        One of ['l2', 'kl']. With 'kl', `y` is also checked to be
        non-negative.
    mask: an array-like of Boolean (or integer, float)
        Same shape as `y`.
        The missing point should be zero. One for otherwise.
    random_seed: None or an integer
        Seed given to `xp.random.RandomState`. Only used in the minibatch
        mode.
    **kwargs:
        Forwarded unchanged to the selected solver.

    Returns
    -------
    Whatever the selected solver's `solve` function returns
    (`batch_mu.solve`, `serizel.solve` or `kasai.solve`).

    Raises
    ------
    NotImplementedError
        If `method` is not available for the selected (batch / minibatch)
        mode.
    """
    xp = get_array_module(D)
    if x is None:
        x = xp.ones((y.shape[0], D.shape[0]), dtype=y.dtype)

    # Input validation (dtype, shape, dimensionality and sign).
    assertion.assert_dtypes(y=y, D=D, x=x)
    assertion.assert_dtypes(y=y, D=D, x=x, mask=mask, dtypes='f')
    assertion.assert_shapes('x', x, 'D', D, axes=1)
    assertion.assert_shapes('y', y, 'D', D, axes=[-1])
    assertion.assert_shapes('y', y, 'mask', mask)
    assertion.assert_ndim('y', y, 2)
    assertion.assert_ndim('D', D, 2)
    assertion.assert_ndim('x', x, 2)
    assertion.assert_nonnegative(D)
    assertion.assert_nonnegative(x)
    if likelihood in ['kl']:
        assertion.assert_nonnegative(y)

    D = normalize.l2_strict(D, axis=-1, xp=xp)

    # batch methods
    if minibatch is None:
        # Check all the class are numpy or cupy
        xp = get_array_module(y, D, x)
        if method == 'mu':
            return batch_mu.solve(y, D, x, tol, maxiter, likelihood, mask, xp,
                                  **kwargs)
        raise NotImplementedError('Batch-NMF with {} algorithm is not yet '
                                  'implemented.'.format(method))

    # minibatch methods
    # Resolve the solver first so that an unsupported method is rejected
    # before any minibatch iterator is constructed.
    if method in ('asg-mu', 'gsg-mu', 'asag-mu', 'gsag-mu'):
        solver = serizel
    elif method in ('svrmu', 'svrmu-acc'):
        solver = kasai
    else:
        raise NotImplementedError('NMF with {} algorithm is not yet '
                                  'implemented.'.format(method))

    if xp is np:
        # check all the array type is np
        get_array_module(y, D, x, mask)
        y = MinibatchData(y, minibatch)
        x = MinibatchData(x, minibatch)
        if mask is None:
            mask = NoneIterator()
        else:
            mask = MinibatchData(mask, minibatch)
    else:
        # D is not a numpy array here. Arrays that are not numpy either are
        # wrapped by MinibatchData; numpy arrays go through
        # AsyncMinibatchData (presumably for asynchronous transfer from the
        # host -- TODO confirm).
        def get_dataset(a, needs_update=True):
            if a is None:
                return NoneIterator()
            if get_array_module(a) is not np:
                return MinibatchData(a, minibatch)
            return AsyncMinibatchData(a, minibatch,
                                      needs_update=needs_update)

        # Only x is flagged with needs_update=True; y and mask are not.
        x = get_dataset(x, needs_update=True)
        y = get_dataset(y, needs_update=False)
        mask = get_dataset(mask, needs_update=False)

    rng = xp.random.RandomState(random_seed)
    return solver.solve(y, D, x, tol, minibatch, maxiter, method,
                        likelihood, mask, rng, xp, **kwargs)