Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lvapeab committed Apr 11, 2020
1 parent aa7174c commit efbdf2d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ script:
elif [[ "MODE" == "PEP8_DOC" ]]; then
PYTHONPATH=$PWD:$PYTHONPATH py.test --pep8 -m pep8 -n0 && py.test tests/docs;
elif [[ "MODE" == "API" ]]; then
PYTHONPATH=$PWD:$PYTHONPATH pip install git+git://www.github.com/keras-team/keras.git && python update_api.py && pip install -e .[tests] --progress-bar off && py.test tests/test_api.py;
PYTHONPATH=$PWD:$PYTHONPATH pip install git+git://www.github.com/MarcBS/keras.git && python update_api.py && pip install -e .[tests] --progress-bar off && py.test tests/test_api.py;
elif [[ "$RUN_ONLY_BACKEND_TESTS" == "1" ]]; then
PYTHONPATH=$PWD:$PYTHONPATH py.test tests/keras/backend/;
else
Expand Down
3 changes: 3 additions & 0 deletions keras/backend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def set_floatx(floatx):
# Arguments
floatx: String, 'float16', 'float32', or 'float64'.
# Raises:
ValueError: if `floatx` is unknown.
# Example
```python
>>> from keras import backend as K
Expand Down
32 changes: 8 additions & 24 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,25 +203,6 @@ def manual_variable_initialization(value):
_MANUAL_VAR_INIT = value


def set_image_data_format(data_format):
"""Sets the value of the data format convention.
# Arguments
data_format: string. `'channels_first'` or `'channels_last'`.
# Example
```python
>>> from keras import backend as K
>>> K.image_data_format()
'channels_first'
>>> K.set_image_data_format('channels_last')
>>> K.image_data_format()
'channels_last'
```
"""
tf_keras_backend.set_image_data_format(data_format)


def normalize_data_format(value):
"""Checks that the value correspond to a valid data format.
Expand Down Expand Up @@ -2730,10 +2711,11 @@ def tile(x, n):
n = tuple(n)

shape = int_shape(x)
if len(n) < len(shape): # Padding the axis
n = tuple([1 for _ in range(len(shape) - len(n))]) + n
elif len(n) != len(shape):
raise NotImplementedError
if not is_tensor(n):
if len(n) < len(shape): # Padding the axis
n = tuple([1 for _ in range(len(shape) - len(n))]) + n
elif len(n) != len(shape):
raise NotImplementedError

return tf.tile(x, n)

Expand Down Expand Up @@ -2761,7 +2743,9 @@ def batch_flatten(x):
# Returns
A tensor.
"""
x = tf.reshape(x, tf.stack([-1, prod(shape(x)[1:])]))
x = tf.reshape(
x, tf.stack([-1, prod(shape(x)[1:])],
name='stack_' + str(np.random.randint(1e4))))
return x


Expand Down

0 comments on commit efbdf2d

Please sign in to comment.