This repository has been archived by the owner on Apr 10, 2024. It is now read-only.

Fixing notebook and adding jupyter #254

Draft
wants to merge 57 commits into base: master
Commits (57)
d812c82
correcting a bug where the read method was ignoring the "mode" argument
Shulk97 Jun 16, 2020
50d2d9d
correcting bug that was adding a dot in filename for local paths in s…
Shulk97 Jun 16, 2020
9e2c472
adding a condition in objective 'direction' to allow the argument vec…
Shulk97 Jun 16, 2020
5275c85
updating tutorial notebook
Shulk97 Jun 16, 2020
dbda67d
updating notebook of modelzoo
Shulk97 Jun 16, 2020
0db3c3d
updating notebook of Semantic Dictionary
Shulk97 Jun 16, 2020
21793cb
adding a jupyter version of the notebook Semantic Dictionary
Shulk97 Jun 16, 2020
0ee1e2d
updating notebook of Activation grid
Shulk97 Jun 16, 2020
d620123
adding a jupyter version of the notebook Activation grid
Shulk97 Jun 16, 2020
c758cc2
updating notebook of Spatial Attribution
Shulk97 Jun 16, 2020
88607b5
adding a Jupyter version of the notebook Spatial Attribution
Shulk97 Jun 16, 2020
ea7ad4a
updating notebook of Channel Attribution
Shulk97 Jun 16, 2020
94d9354
adding a Jupyter version of the notebook Channel Attribution
Shulk97 Jun 16, 2020
266fb0b
adding a disclaimer at the beginning of the Jupyter version of Activa…
Shulk97 Jun 16, 2020
c4bb3a9
adding an example of a custom model defined like ModelZoo models
Shulk97 Jun 16, 2020
862070d
adding a notebook containing examples of available objectives. It com…
Shulk97 Jun 16, 2020
27fb867
adding a module showing how to import a Keras Model
Shulk97 Jun 16, 2020
7f00d77
adding a notebook for spritemaps generation
Shulk97 Jun 16, 2020
1fba450
adding a license to the created python files
Shulk97 Jun 16, 2020
9c34f4f
changing the path of uploaded file to a relative path
Shulk97 Jun 16, 2020
d74dd93
refactoring of imports + reloading of every notebook in order to con…
Shulk97 Jun 16, 2020
298e003
Add more ops to whitelist
gabgoh Jul 11, 2020
aedd7c1
Add cache to the docstring of load
csvoss Sep 3, 2020
d7652b0
added adv fine-tuned InceptionV1
stefsietz Sep 21, 2020
bccb9f7
rl_util attribution bug fix
jacobhilton Sep 25, 2020
34ed55f
default_score_fn
jacobhilton Sep 26, 2020
6fd0fb5
Bump lodash from 4.17.15 to 4.17.19 in /lucid/scratch/js
dependabot[bot] Jul 16, 2020
1de5835
Bump node-fetch from 2.1.1 to 2.6.1 in /lucid/scratch/js
dependabot[bot] Sep 10, 2020
130cb44
Bump serve from 9.4.0 to 10.1.2 in /lucid/scratch/js
dependabot[bot] Sep 28, 2020
6f06ff1
Bump acorn from 5.5.3 to 5.7.4 in /lucid/scratch/js
dependabot[bot] Sep 28, 2020
ab2ed71
rl_util notebook
jacobhilton Nov 12, 2020
1b4714b
rl_util.conv2d dtype bug fix
jacobhilton Nov 13, 2020
1e9a98c
rl_util notebook lucid commit hash
jacobhilton Nov 14, 2020
6beda59
Fix total_variation citation
ProGamerGov Nov 24, 2020
76d0ce1
Bump ini from 1.3.5 to 1.3.8 in /lucid/scratch/js
dependabot[bot] Jan 11, 2021
4cd0df0
Update show() to fix Firefox rendering issue
csvoss Jan 26, 2021
cf4870a
Delete package-lock.json
mihaimaruseac Mar 4, 2021
068f57f
added clip model
gabgoh Mar 4, 2021
e68dbb6
Update CLIP model URL to modelzoo bucket
ludwigschubert Mar 12, 2021
60668b2
Correct import paths for Clip model
ludwigschubert Mar 12, 2021
f175a88
Require Python 3.7 as numpy dependency does so
ludwigschubert Mar 12, 2021
96ff5ae
Require Python 3.7 on CI, too
ludwigschubert Mar 12, 2021
33a6323
Pin TF version on CI to last 1.x release
ludwigschubert Mar 12, 2021
8b4ba19
WIP, TBS
ludwigschubert Mar 12, 2021
f26307b
Attempt to catch ModuleNotFoundError in channel_reducer to get tests …
ludwigschubert Mar 12, 2021
0ade290
WIP testing newer gfile module import
ludwigschubert Mar 12, 2021
d07b276
Correct way to call modern GFile
ludwigschubert Mar 12, 2021
0fcc7ab
Fix a longstanding bug in url_scope that added an extra . after local…
ludwigschubert Mar 12, 2021
5f4b260
Add "slow" marker to pytest config options to avoid warning, disbale …
ludwigschubert Mar 12, 2021
4d972ba
tf.spectral.irfft2d -> tf.signal.irfft2d
ludwigschubert Mar 12, 2021
1803e85
Pin numpy to version 1.19
ludwigschubert Mar 12, 2021
9e0ebcb
[*.py] Rename "Arguments:" to "Args:"
SamuelMarks Dec 5, 2020
17d52b5
Update xy2rgb.ipynb
Tylersuard Dec 14, 2020
3ab6a96
Fixed spelling of factorization
ProGamerGov Sep 16, 2020
3263e11
ignore lock files
1wheel Sep 28, 2020
58b201e
@colah's graph_analysis: nicer json parsed graph structure
colah Dec 4, 2019
a4c29a9
Adding a message for channel/spatial attribution indicating how to ge…
Shulk97 Mar 14, 2021
4 changes: 4 additions & 0 deletions .gitignore
@@ -25,3 +25,7 @@ lib
share

tests/fixtures/generated_outputs/


lucid/scratch/js/package-lock.json
lucid/scratch/js/yarn.lock
2 changes: 1 addition & 1 deletion .travis.yml
@@ -1,6 +1,6 @@
language: python
python:
- "3.6"
- "3.7"
install:
- pip install -U pip wheel
- pip install python-coveralls
8 changes: 4 additions & 4 deletions lucid/misc/channel_reducer.py
@@ -23,9 +23,9 @@
import sklearn.decomposition

try:
from sklearn.decomposition.base import BaseEstimator
except AttributeError:
from sklearn.base import BaseEstimator
from sklearn.decomposition.base import BaseEstimator
except ModuleNotFoundError:
from sklearn.base import BaseEstimator


class ChannelReducer(object):
@@ -44,7 +44,7 @@ def __init__(self, n_components=3, reduction_alg="NMF", **kwargs):
Inputs:
n_components: Number of dimensions to reduce innermost dimension to.
reduction_alg: A string or sklearn.decomposition class. Defaults to
"NMF" (non-negative matrix facotrization). Other options include:
"NMF" (non-negative matrix factorization). Other options include:
"PCA", "FastICA", and "MiniBatchDictionaryLearning". The name of any of
the sklearn.decomposition classes will work, though.
kwargs: Additional kwargs to be passed on to the reducer.
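For context, a minimal usage sketch of ChannelReducer with the NMF default; the activation shape is made up for illustration and fit_transform is assumed to wrap the underlying sklearn reducer as in the rest of this module:

import numpy as np
from lucid.misc.channel_reducer import ChannelReducer

# Hypothetical activations: height x width x channels, non-negative as NMF requires
acts = np.random.rand(14, 14, 512).astype(np.float32)

# Reduce the 512 channels down to 3 NMF factors
reducer = ChannelReducer(n_components=3, reduction_alg="NMF")
reduced_acts = reducer.fit_transform(acts)  # expected shape: (14, 14, 3)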
100 changes: 100 additions & 0 deletions lucid/misc/custom_model.py
@@ -0,0 +1,100 @@
# Copyright 2018 The Lucid Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from lucid.modelzoo.vision_base import Model, _layers_from_list_of_dicts


class CustomModel(Model):
"""Example of custom Lucid Model class. This example is based on Mobilenet
from Keras Applications
"""

model_path = "lucid_protobuf_file.pb"
dataset = "ImageNet"
image_shape = [224, 224, 3]
image_value_range = (-1, 1)
input_name = "input"
# Labels as an index-to-class-name dictionary:
# Of course, if you really use a dataset with 1000 classes you
# should consider loading them from a file.
_labels = {
0: 'tench, Tinca tinca',
1: 'goldfish, Carassius auratus',
2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
3: 'tiger shark, Galeocerdo cuvieri',
4: 'hammerhead, hammerhead shark',
5: 'electric ray, crampfish, numbfish, torpedo',
6: 'stingray',
7: 'cock',
8: 'hen',
9: 'ostrich, Struthio camelus',
# ...
999: 'toilet tissue, toilet paper, bathroom tissue',
}

@property
def labels(self):
return self._labels

def label_index(self, label):
return list(self._labels.values()).index(label)

CustomModel.layers = _layers_from_list_of_dicts(
CustomModel(),
[
{"name": "conv1_relu/Relu6", "depth": 32, "tags": ["conv"]},
{"name": "conv_pw_1_relu/Relu6", "depth": 64, "tags": ["conv"]},
{"name": "conv_pw_2_relu/Relu6", "depth": 128, "tags": ["conv"]},
{"name": "conv_pw_3_relu/Relu6", "depth": 128, "tags": ["conv"]},
{"name": "conv_pw_4_relu/Relu6", "depth": 256, "tags": ["conv"]},
{"name": "conv_pw_5_relu/Relu6", "depth": 256, "tags": ["conv"]},
{"name": "conv_pw_6_relu/Relu6", "depth": 512, "tags": ["conv"]},
{"name": "conv_pw_7_relu/Relu6", "depth": 512, "tags": ["conv"]},
{"name": "conv_pw_8_relu/Relu6", "depth": 512, "tags": ["conv"]},
{"name": "conv_pw_9_relu/Relu6", "depth": 512, "tags": ["conv"]},
{"name": "conv_pw_10_relu/Relu6", "depth": 512, "tags": ["conv"]},
{"name": "conv_pw_11_relu/Relu6", "depth": 512, "tags": ["conv"]},
{"name": "conv_pw_12_relu/Relu6", "depth": 1024, "tags": ["conv"]},
{"name": "conv_pw_13_relu/Relu6", "depth": 1024, "tags": ["conv"]},
{"name": "dense/BiasAdd", "depth": 256, "tags": ["dense"]},
{"name": "dense_1/BiasAdd", "depth": 256, "tags": ["dense"]},
{"name": "dense_2/BiasAdd", "depth": 1000, "tags": ["dense"]},
{"name": "softmax/Softmax", "depth": 1000, "tags": ["dense"]},
],
)

output_shapes = {
"conv1_relu/Relu6": (112, 112, 32),
"conv_pw_1_relu/Relu6": (112, 112, 64),
"conv_pw_2_relu/Relu6": (56, 56, 128),
"conv_pw_3_relu/Relu6": (56, 56, 128),
"conv_pw_4_relu/Relu6": (28, 28, 256),
"conv_pw_5_relu/Relu6": (28, 28, 256),
"conv_pw_6_relu/Relu6": (14, 14, 512),
"conv_pw_7_relu/Relu6": (14, 14, 512),
"conv_pw_8_relu/Relu6": (14, 14, 512),
"conv_pw_9_relu/Relu6": (14, 14, 512),
"conv_pw_10_relu/Relu6": (14, 14, 512),
"conv_pw_11_relu/Relu6": (14, 14, 512),
"conv_pw_12_relu/Relu6": (7, 7, 1024),
"conv_pw_13_relu/Relu6": (7, 7, 1024),
"dense/BiasAdd": (256,),
"dense_1/BiasAdd": (256,),
"dense_2/BiasAdd": (1000,),
"softmax/Softmax": (1000,),
}

for layer in CustomModel.layers:
layer.shape = output_shapes[layer.name]
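A minimal sketch of how a custom model like this might be driven through Lucid's rendering API; the layer and channel are chosen arbitrarily, and "lucid_protobuf_file.pb" above is a placeholder path that must point at a real frozen graph:

from lucid.optvis import objectives, render

model = CustomModel()
model.load_graphdef()  # loads the frozen graph referenced by model_path

# Visualize a single channel of one of the declared layers
obj = objectives.channel("conv_pw_2_relu/Relu6", 10)
images = render.render_vis(model, obj)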
2 changes: 1 addition & 1 deletion lucid/misc/graph_analysis/filter_overlay.py
@@ -25,7 +25,7 @@
"""


standard_include_ops = ["Placeholder", "Relu", "Relu6", "Add", "Split", "Softmax", "Concat", "ConcatV2", "Conv2D", "MaxPool", "AvgPool", "MatMul"] # Conv2D
standard_include_ops = ["Placeholder", "Relu", "Relu6", "Add", "Split", "Softmax", "Concat", "ConcatV2", "Conv2D", "MaxPool", "AvgPool", "MatMul", "EwZXy"] # Conv2D

def ops_whitelist(graph, include_ops=standard_include_ops):
keep_nodes = [node.name for node in graph.nodes if node.op in include_ops]
53 changes: 35 additions & 18 deletions lucid/misc/graph_analysis/parse_overlay.py
@@ -49,19 +49,35 @@ def parse_structure(node):
structure = node.sub_structure

if structure is None:
return node.name
return {
"type" : "Node",
"name": node.name
}
elif structure.structure_type == "Sequence":
return {"Sequence" : [parse_structure(n) for n in structure.structure["sequence"]]}
return {
"type" : "Sequence",
"children": [parse_structure(n) for n in structure.structure["sequence"]]
}
elif structure.structure_type == "HeadBranch":
return {"Sequence" : [
{"Branch" : [parse_structure(n) for n in structure.structure["branches"]] },
parse_structure(structure.structure["head"])
]}
return {
"type" : "Sequence",
"children": [{
"type": "Branch",
"children": [parse_structure(n) for n in structure.structure["branches"]]
},
parse_structure(structure.structure["head"])]
}
elif structure.structure_type == "TailBranch":
return {"Sequence" : [
return {
"type" : "Sequence",
"children": [
parse_structure(structure.structure["tail"]),
{"Branch" : [parse_structure(n) for n in structure.structure["branches"]] },
]}
{
"type": "Branch",
"subtype": "AuxilliaryHeadBranch",
"children": [parse_structure(n) for n in structure.structure["branches"]]
}]
}
else:
data = {}
for k in structure.structure:
@@ -70,26 +86,27 @@
else:
data[k] = parse_structure(structure.structure[k])

return {structure.structure_type : data}
data["type"] = structure.structure_type
return data


def flatten_sequences(structure):
"""Flatten nested sequences into a single sequence."""
if isinstance(structure, str) or structure is None:
if isinstance(structure, str) or (isinstance(structure, dict) and structure["type"] == "Node") or structure is None:
return structure
else:
structure = structure.copy()
for k in structure:
structure[k] = [flatten_sequences(sub) for sub in structure[k]]
if "children" in structure:
structure["children"] = [flatten_sequences(sub) for sub in structure["children"]]

if "Sequence" in structure:
if structure["type"] == "Sequence":
new_seq = []
for sub in structure["Sequence"]:
if isinstance(sub, dict) and "Sequence" in sub:
new_seq += sub["Sequence"]
for sub in structure["children"]:
if isinstance(sub, dict) and sub["type"] == "Sequence":
new_seq += sub["children"]
else:
new_seq.append(sub)
structure["Sequence"] = new_seq
structure["children"] = new_seq
return structure


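To make the change concrete, a hypothetical before/after of the structure returned by parse_structure (node names invented for illustration):

# Previous format: the structure type doubled as the dictionary key
# {"Sequence": ["conv1", {"Branch": ["branch_a", "branch_b"]}, "head"]}
#
# New format: an explicit "type" field, with children under a uniform "children" key
# {"type": "Sequence",
#  "children": [
#      {"type": "Node", "name": "conv1"},
#      {"type": "Branch", "children": [{"type": "Node", "name": "branch_a"},
#                                      {"type": "Node", "name": "branch_b"}]},
#      {"type": "Node", "name": "head"}]}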
2 changes: 2 additions & 0 deletions lucid/misc/io/loading.py
@@ -183,6 +183,8 @@ def load(url_or_handle, allow_unsafe_formats=False, cache=None, **kwargs):
Args:
url_or_handle: a (reachable) URL, or an already open file handle
allow_unsafe_formats: set to True to allow saving unsafe formats (eg. pickles)
cache: whether to attempt caching the resource. Defaults to True only if
the given URL specifies a remote resource.

Raises:
RuntimeError: If file extension or URL is not supported.
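A short sketch of the cache behavior documented above; the URL is illustrative:

from lucid.misc.io import load

# Remote resources are cached under the lucid cache dir by default
labels = load("gs://modelzoo/labels/ImageNet_standard.txt")

# Caching can be disabled explicitly, e.g. for resources that change often
labels_fresh = load("gs://modelzoo/labels/ImageNet_standard.txt", cache=False)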
25 changes: 16 additions & 9 deletions lucid/misc/io/reading.py
@@ -25,7 +25,7 @@
import logging
from urllib.parse import urlparse
from urllib import request
from tensorflow import gfile
from tensorflow.io.gfile import GFile
import tensorflow as tf
from tempfile import gettempdir
import gc
@@ -121,7 +121,7 @@ def read_handle(url, cache=None, mode="rb"):


def _handle_gfile(url, mode="rb"):
return gfile.Open(url, mode)
return GFile(url, mode)


def _handle_web_url(url, mode="r"):
@@ -136,7 +136,7 @@ def _is_remote(scheme):


RESERVED_PATH_CHARS = re.compile("[^a-zA-Z0-9]")
LUCID_CACHE_DIR_NAME = 'lucid_cache'
LUCID_CACHE_DIR_NAME = "lucid_cache"
MAX_FILENAME_LENGTH = 200
_LUCID_CACHE_DIR = None # filled on first use

@@ -146,16 +146,22 @@ def local_cache_path(remote_url):
"""Returns the path that remote_url would be cached at locally."""
local_name = RESERVED_PATH_CHARS.sub("_", remote_url)
if len(local_name) > MAX_FILENAME_LENGTH:
filename_hash = hashlib.sha256(local_name.encode('utf-8')).hexdigest()
truncated_name = local_name[:(MAX_FILENAME_LENGTH-(len(filename_hash)) - 1)] + '-' + filename_hash
log.debug(f'truncated long cache filename to {truncated_name} (original {len(local_name)} char name: {local_name}')
filename_hash = hashlib.sha256(local_name.encode("utf-8")).hexdigest()
truncated_name = (
local_name[: (MAX_FILENAME_LENGTH - (len(filename_hash)) - 1)]
+ "-"
+ filename_hash
)
log.debug(
f"truncated long cache filename to {truncated_name} (original {len(local_name)} char name: {local_name}"
)
local_name = truncated_name
if _LUCID_CACHE_DIR is None:
_LUCID_CACHE_DIR = os.path.join(gettempdir(), LUCID_CACHE_DIR_NAME)
if not os.path.exists(_LUCID_CACHE_DIR):
# folder might exist if another thread/process creates it concurrently, this would be ok
os.makedirs(_LUCID_CACHE_DIR, exist_ok=True)
log.info(f'created lucid cache dir at {_LUCID_CACHE_DIR}')
log.info(f"created lucid cache dir at {_LUCID_CACHE_DIR}")
return os.path.join(_LUCID_CACHE_DIR, local_name)


@@ -177,7 +183,7 @@ def _read_and_cache(url, mode="rb"):
with lock:
if os.path.exists(local_path):
log.debug("Found cached file '%s'.", local_path)
return _handle_gfile(local_path)
return _handle_gfile(local_path, mode)
log.debug("Caching URL '%s' locally at '%s'.", url, local_path)
try:
with write_handle(local_path, "wb") as output_handle, read_handle(
@@ -199,7 +205,8 @@


from functools import partial
_READ_BUFFER_SIZE = 1048576 # setting a larger value here to help read bigger chunks of files over the network (eg from GCS)

_READ_BUFFER_SIZE = 1048576 # setting a larger value here to help read bigger chunks of files over the network (eg from GCS)


def _file_chunk_iterator(file_handle):
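A sketch of how these pieces fit together when reading a remote file; the URL is illustrative and the printed cache path is not verified output:

from lucid.misc.io.reading import read_handle, local_cache_path

url = "gs://modelzoo/vision/other_models/Clip_ResNet50_4x.pb"

# Reserved characters are replaced and overly long names are hash-truncated
print(local_cache_path(url))

# read_handle caches remote URLs locally and now forwards the requested mode
with read_handle(url, cache=True, mode="rb") as handle:
    graphdef_bytes = handle.read()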
14 changes: 10 additions & 4 deletions lucid/misc/io/scoping.py
@@ -37,14 +37,20 @@ def _normalize_url(url: str) -> str:
# os.path.normpath mangles url schemes: gs://etc -> gs:/etc
# urlparse.urljoin doesn't normalize paths
url_scheme, sep, url_path = url.partition("://")
normalized_path = os.path.normpath(url_path)
return url_scheme + sep + normalized_path
# 2021-03-12 @ludwig this method is often called with paths that are not URLs.
# thus, url_path may be empty
# in this case we can't call `os.path.normpath(url_path)`
# as it "normalizes" an empty input to "." (current directory)
normalized_path = os.path.normpath(url_path) if url_path else ""
joined = url_scheme + sep + normalized_path
return joined


def scope_url(url, io_scopes=None):
io_scopes = io_scopes or current_io_scopes()
if "//" in url or url.startswith("/"):
if "//" in url or url.startswith("/") or url.startswith("./"):
return url
paths = io_scopes + [url]
joined = os.path.join(*paths)
return _normalize_url(joined)
normalized = _normalize_url(joined)
return normalized
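A few illustrative cases of scope_url after this fix; the scope and file names are made up, and io_scope is assumed to be the context manager this module exposes for pushing an io scope:

from lucid.misc.io.scoping import io_scope, scope_url

# Absolute URLs and explicitly relative paths pass through unchanged
scope_url("gs://bucket/model.pb")   # -> "gs://bucket/model.pb"
scope_url("./local/model.pb")       # -> "./local/model.pb"

# Bare relative paths are joined onto the current io scope
with io_scope("gs://bucket/experiment-1"):
    path = scope_url("activations.npy")  # -> "gs://bucket/experiment-1/activations.npy"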
2 changes: 1 addition & 1 deletion lucid/misc/io/showing.py
@@ -61,7 +61,7 @@ def _image_url(array, fmt='png', mode="data", quality=90, domain=None):

def _image_html(array, w=None, domain=None, fmt='png'):
url = _image_url(array, domain=domain, fmt=fmt)
style = "image-rendering: pixelated;"
style = "image-rendering: pixelated; image-rendering: crisp-edges;"
if w is not None:
style += "width: {w}px;".format(w=w)
return """<img src="{url}" style="{style}">""".format(**locals())
2 changes: 1 addition & 1 deletion lucid/misc/iter_nd_utils.py
@@ -24,7 +24,7 @@
def recursive_enumerate_nd(it, stop_iter=None, prefix=()):
"""Recursively enumerate nested iterables with tuples n-dimenional indices.

Arguments:
Args:
it: object to be enumerated
stop_iter: User-defined function which can conditionally block further
iteration. Defaults to allowing iteration.
4 changes: 2 additions & 2 deletions lucid/misc/stimuli.py
@@ -50,7 +50,7 @@ def img_f(x,y):
return (negative if interior, positive if exterior)
img = sampler(img_f)

Arguments:
Args:
size: Size of image to be rendered in pixels.
alias_factor: Number of samples to use in aliasing.
color_a: Color of exterior. A 3-tuple of floats between 0 and 1. Defaults
@@ -148,7 +148,7 @@ def rounded_corner(orientation, r, angular_width=90, size=224, **kwds):
This function is a flexible generator of "rounded corner" stimuli. It returns
an image, represented as a numpy array of shape [size, size, 3].

Arguments:
Args:
orientation: The orientation of the curve, in degrees.
r: radius of the curve
angular_width: when r=0 and we have a sharp corner, this controls the angle
23 changes: 23 additions & 0 deletions lucid/modelzoo/other_models/Clip.py
@@ -0,0 +1,23 @@
# Copyright 2018 The Lucid Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from lucid.modelzoo.vision_base import Model

class Clip_ResNet50_4x(Model):
image_value_range = (0, 255)
input_name = 'input_image'
model_name = "Clip_ResNet50_4x"
image_shape = [288, 288, 3]
model_path = "gs://modelzoo/vision/other_models/Clip_ResNet50_4x.pb"