Clarify script comments a bit
LetiP committed Mar 12, 2024
1 parent c7e53ae commit 00a66bf
Showing 3 changed files with 37 additions and 26 deletions.
22 changes: 13 additions & 9 deletions mm-shap_albef_dataset.py
@@ -129,21 +129,24 @@ def pre_caption(caption, max_words=30):

def custom_masker(mask, x):
"""
- Shap relevant function. Defines the masking function so the shap computation
- can 'know' how the model prediction looks like when some tokens are masked.
+ Shap relevant function.
+ It gets a mask from the shap library with truth values indicating which image and text tokens to mask (False) and which to keep (True).
+ It masks the text tokens accordingly. The image is not masked here: the masker only records which image tokens to mask; the actual image masking happens in get_model_prediction().
"""
masked_X = x.clone()
mask = torch.tensor(mask).unsqueeze(0)
masked_X[~mask] = 0 # note the ~: positions where mask is False are zeroed out
# never mask out CLS and SEP tokens (makes no sense for the model to work without them)
masked_X[0, 0] = 101 # start token ALBEF
- # masked_X[0, text_length_tok-1] = 4624 # sep token ALBEF (no TOKEN!!!)
+ # masked_X[0, nb_text_tokens-1] = 4624 # sep token ALBEF (no TOKEN!!!)
return masked_X


def get_model_prediction(x):
"""
- Shap relevant function. Predict the model output for all combinations of masked tokens.
+ Shap relevant function.
+ 1. Mask the image pixels according to the patches that the custom masker marked for masking.
+ 2. Predict the model output for all combinations of masked image patches and text tokens. The predictions are then passed on to the shap library.
"""
with torch.no_grad():
# split up the input_ids and the image_token_ids from x (containing both appended)
@@ -161,7 +164,7 @@ def get_model_prediction(x):

# call the model for each "new image" generated with masked features
for i in range(input_ids.shape[0]):
- # here the actual masking of ALBEF is happening. The custom masker only specified which patches to mask, but no actual masking has happened
+ # here the actual masking of the image happens. The custom masker only specified which patches to mask; no actual masking has happened yet
masked_text_inputs = text_input.copy()
masked_text_inputs['input_ids'] = input_ids[i].unsqueeze(0)
masked_image = copy.deepcopy(image)
@@ -277,9 +280,10 @@ def load_models():
image = image.cpu()
text_input = text_input.to(image.device)

- text_length_tok = text_input.input_ids.shape[1]
- p = int(math.ceil(np.sqrt(text_length_tok)))
- patch_size = 384 // p # 384 image size albef
+ nb_text_tokens = text_input.input_ids.shape[1] # number of text tokens
+ # calculate the number of patches needed to cover the image
+ p = int(math.ceil(np.sqrt(nb_text_tokens)))
+ patch_size = 384 // p # 384 is the image size for ALBEF
image_token_ids = torch.tensor(range(1, p**2+1)).unsqueeze(0) # take one less because CLS and SEP tokens do not count

# make a combination between tokens and pixel_values (transform to patches first)
@@ -290,7 +294,7 @@ def load_models():
explainer = shap.Explainer(
get_model_prediction, custom_masker, silent=True)
shap_values = explainer(X)
- mm_score = compute_mm_score(text_length_tok, shap_values)
+ mm_score = compute_mm_score(nb_text_tokens, shap_values)

if k == 0:
which = 'caption'
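These two functions follow shap's custom-masker protocol: shap.Explainer receives a prediction function plus a masker with signature (mask, x), and the explained input X is a single row concatenating the text token ids with placeholder image-patch ids, so one shap run attributes the prediction to both modalities. Below is a minimal, self-contained sketch of that pattern; the toy scoring function stands in for the real ALBEF call, and its names and logic are illustrative assumptions, not the repository's code.

import numpy as np
import shap

def custom_masker(mask, x):
    # shap hands us a boolean mask; False means "mask this position"
    masked_x = x.copy().reshape(-1)
    masked_x[~mask.reshape(-1)] = 0
    return masked_x.reshape(1, -1)

def get_model_prediction(x):
    # toy stand-in for the real model: score = number of unmasked positions
    x = np.atleast_2d(x)
    return np.count_nonzero(x, axis=1).astype(float)

nb_text_tokens, nb_image_tokens = 4, 4
# one row: text token ids followed by placeholder image-patch ids
X = np.arange(1, nb_text_tokens + nb_image_tokens + 1).reshape(1, -1)
explainer = shap.Explainer(get_model_prediction, custom_masker, silent=True)
shap_values = explainer(X)
print(shap_values.values.shape)  # one shap value per text/image position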
19 changes: 11 additions & 8 deletions mm-shap_clip_dataset.py
@@ -54,21 +54,24 @@

def custom_masker(mask, x):
"""
- Shap relevant function. Defines the masking function so the shap computation
- can 'know' how the model prediction looks like when some tokens are masked.
+ Shap relevant function.
+ It gets a mask from the shap library with truth values indicating which image and text tokens to mask (False) and which to keep (True).
+ It masks the text tokens accordingly. The image is not masked here: the masker only records which image tokens to mask; the actual image masking happens in get_model_prediction().
"""
masked_X = x.clone()
mask = torch.tensor(mask).unsqueeze(0)
masked_X[~mask] = 0 # note the ~: positions where mask is False are zeroed out
# never mask out CLS and SEP tokens (makes no sense for the model to work without them)
masked_X[0, 0] = 49406
- masked_X[0, text_length_tok-1] = 49407
+ masked_X[0, nb_text_tokens-1] = 49407
return masked_X


def get_model_prediction(x):
"""
- Shap relevant function. Predict the model output for all combinations of masked tokens.
+ Shap relevant function.
+ 1. Mask the image pixels according to the patches that the custom masker marked for masking.
+ 2. Predict the model output for all combinations of masked image patches and text tokens. The predictions are then passed on to the shap library.
"""
with torch.no_grad():
# split up the input_ids and the image_token_ids from x (containing both appended)
@@ -82,7 +85,7 @@ def get_model_prediction(x):

# call the model for each "new image" generated with masked features
for i in range(input_ids.shape[0]):
- # here the actual masking of CLIP is happening. The custom masker only specified which patches to mask, but no actual masking has happened
+ # here the actual masking of the image happens. The custom masker only specified which patches to mask; no actual masking has happened yet
masked_inputs = copy.deepcopy(inputs) # initialize the thing
masked_inputs['input_ids'] = input_ids[i].unsqueeze(0)

@@ -168,8 +171,8 @@ def load_models():
continue
model_prediction = model(**inputs).logits_per_image[0,0].item()

- text_length_tok = inputs.input_ids.shape[1]
- p = int(math.ceil(np.sqrt(text_length_tok)))
+ nb_text_tokens = inputs.input_ids.shape[1]
+ p = int(math.ceil(np.sqrt(nb_text_tokens)))
patch_size = 224 // p
image_token_ids = torch.tensor(
range(1, p**2+1)).unsqueeze(0) # (inputs.pixel_values.shape[-1] // patch_size)**2 +1
@@ -181,7 +184,7 @@ def load_models():
explainer = shap.Explainer(
get_model_prediction, custom_masker, silent=True)
shap_values = explainer(X)
- mm_score = compute_mm_score(text_length_tok, shap_values)
+ mm_score = compute_mm_score(nb_text_tokens, shap_values)

if k == 0:
which = 'caption'
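The patch arithmetic above is what keeps the two modalities balanced: p is chosen so that the p*p image patches roughly match the number of text tokens, giving both modalities a comparable number of shap players. A worked example under that reading (the token count is an assumed value; 224 is CLIP's input resolution):

import math

nb_text_tokens = 12  # e.g. a 10-word caption plus CLS and SEP (assumed value)
p = int(math.ceil(math.sqrt(nb_text_tokens)))  # ceil(sqrt(12)) = 4 patches per side
patch_size = 224 // p  # 224 // 4 = 56 pixels per patch side
print(p * p, "image patches of", patch_size, "x", patch_size, "pixels")
# -> 16 image patches, comparable to the 12 text tokens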
22 changes: 13 additions & 9 deletions mm-shap_lxmert_dataset.py
@@ -57,21 +57,24 @@

def custom_masker(mask, x):
"""
- Shap relevant function. Defines the masking function so the shap computation
- can 'know' how the model prediction looks like when some tokens are masked.
+ Shap relevant function.
+ It gets a mask from the shap library with truth values indicating which image and text tokens to mask (False) and which to keep (True).
+ It masks the text tokens accordingly. The image is not masked here: the masker only records which image tokens to mask; the actual image masking happens in get_model_prediction().
"""
masked_X = x.clone()
mask = torch.tensor(mask).unsqueeze(0)
masked_X[~mask] = 0
# never mask out CLS and SEP tokens (makes no sense for the model to work without them)
masked_X[0, 0] = 101
- masked_X[0, text_length_tok-1] = 102
+ masked_X[0, nb_text_tokens-1] = 102
return masked_X


def get_model_prediction(x):
"""
- Shap relevant function. Predict the model output for all combinations of masked tokens.
+ Shap relevant function.
+ 1. Mask the image pixels according to the patches that the custom masker marked for masking.
+ 2. Predict the model output for all combinations of masked image patches and text tokens. The predictions are then passed on to the shap library.
"""
# split up the input_ids and the image_token_ids from x (containing both appended)
input_ids = torch.tensor(x[:, :inputs.input_ids.shape[1]]).cuda()
@@ -81,7 +84,8 @@ def get_model_prediction(x):
result = np.zeros(input_ids.shape[0])

# call the model for each "new image" generated with masked features
- for i in range(input_ids.shape[0]):
+ for i in range(input_ids.shape[0]):
+ # here the actual masking of the image happens. The custom masker only specified which patches to mask; no actual masking has happened yet
masked_images = copy.deepcopy(images).cuda() # do I need deepcopy?

# patchify the image
@@ -262,10 +266,10 @@ def load_models(task):
output_lxmert['cross_relationship_score']).cpu().detach()[:, 1].item()

# determine text length in number of tokens (after tokenization)
- text_length_tok = np.count_nonzero(inputs.input_ids)
- p = int(math.ceil(np.sqrt(text_length_tok)))
+ nb_text_tokens = np.count_nonzero(inputs.input_ids) # number of text tokens
+ p = int(math.ceil(np.sqrt(nb_text_tokens)))
patch_size_row = images.shape[2] // p # we have 36 image token ids
- patch_size_col = images.shape[3] // p # we have text_length_tok-1 image token ids
+ patch_size_col = images.shape[3] // p # we have nb_text_tokens-1 image token ids

# features.shape[1] = 36 as we have 36 image regions
image_token_ids = torch.tensor(range(1, p**2+1)).unsqueeze(0)
@@ -276,7 +280,7 @@ def load_models(task):
explainer = shap.Explainer(
get_model_prediction, custom_masker, silent=True)
shap_values = explainer(X)
- mm_score = compute_mm_score(text_length_tok, shap_values)
+ mm_score = compute_mm_score(nb_text_tokens, shap_values)

if k == 0:
which = 'caption'
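For concreteness, here is a hedged sketch of the two pieces the new comments describe: the deferred image masking (zeroing the patches whose ids the custom masker set to 0), and compute_mm_score read as the T-SHAP ratio from the MM-SHAP paper, i.e. the share of the total absolute shap value carried by the text tokens. The helper name mask_image_patches, the zero-filling, and the square patch grid are assumptions for illustration, not the repository's exact code.

import numpy as np
import torch

def mask_image_patches(image, patch_ids, patch_size):
    # zero out every patch whose id the custom masker set to 0 (assumed zero-masking)
    masked = image.clone()
    patch_ids = patch_ids.reshape(-1)
    p = round(patch_ids.numel() ** 0.5)  # patches per side of the square grid
    for idx in range(patch_ids.numel()):
        if patch_ids[idx] == 0:  # 0 = this patch is masked
            row, col = divmod(idx, p)
            masked[..., row * patch_size:(row + 1) * patch_size,
                        col * patch_size:(col + 1) * patch_size] = 0
    return masked

def compute_mm_score(nb_text_tokens, shap_values):
    # T-SHAP reading (assumption from the MM-SHAP paper): text share of total absolute shap value
    vals = np.abs(shap_values.values[0])
    return vals[:nb_text_tokens].sum() / vals.sum()

# toy usage: 3-channel 224x224 image, 4x4 patch grid with patches 5 and 11 masked
image = torch.ones(1, 3, 224, 224)
patch_ids = torch.arange(1, 17)
patch_ids[[4, 10]] = 0
masked = mask_image_patches(image, patch_ids, patch_size=56)

In the scripts, the analogue of mask_image_patches runs inside the loop in get_model_prediction, once per masking combination that shap proposes.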
