Adding Bytestream model (tensorflow#10731)
Co-authored-by: Arun Kandoor <[email protected]>
karunreddy30 and Arun Kandoor authored Aug 2, 2022
1 parent 2659c4e commit 50e8670
Showing 7 changed files with 417 additions and 2 deletions.
20 changes: 19 additions & 1 deletion research/seq_flow_lite/input_fn_reader.py
@@ -22,6 +22,7 @@

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_text as tftext

from layers import projection_layers # import seq_flow_lite module
from utils import misc_utils # import seq_flow_lite module
@@ -61,7 +62,24 @@ def _post_processor(features, batch_size):
  label = tf.reshape(label, [batch_size, num_classes])
  prxlayer = projection_layers.ProjectionLayer(model_config, mode)
  projection, seq_length = prxlayer(text)
  return {"projection": projection, "seq_length": seq_length, "label": label}
  gbst_max_token_len = max_seq_len
  if "gbst_max_token_len" in model_config:
    gbst_max_token_len = model_config["gbst_max_token_len"]
  byte_int = tftext.ByteSplitter().split(text).to_tensor(
      default_value=0, shape=[batch_size, gbst_max_token_len])
  token_ids = tf.cast(byte_int, tf.int32)
  token_len = tf.strings.length(text)
  mask = tf.cast(
      tf.sequence_mask(token_len, maxlen=gbst_max_token_len), tf.int32)
  mask *= 3
  token_ids += mask
  return {
      "projection": projection,
      "seq_length": seq_length,
      "token_ids": token_ids,
      "token_len": token_len,
      "label": label
  }

def _input_fn(params):
"""Method to be used for reading the data."""
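The new post-processing step feeds raw bytes to the model alongside the existing projection features. The standalone sketch below (not part of this commit; the two sample strings and the max token length of 8 are made up) shows what the added lines compute: each string is split into its bytes, padded to a fixed width, and every real byte is shifted up by 3, presumably so ids 0-2 stay free for padding and special tokens.

import tensorflow as tf
import tensorflow_text as tftext

text = tf.constant(["hi", "abc"])          # toy batch of two strings
gbst_max_token_len = 8                     # arbitrary value for illustration

byte_int = tftext.ByteSplitter().split(text).to_tensor(
    default_value=0, shape=[2, gbst_max_token_len])
token_ids = tf.cast(byte_int, tf.int32)
token_len = tf.strings.length(text)        # byte length per string
# Shift real bytes by 3 so ids 0-2 remain free for padding/special tokens.
mask = tf.cast(tf.sequence_mask(token_len, maxlen=gbst_max_token_len), tf.int32)
token_ids += mask * 3
print(token_ids.numpy())  # [[107 108 0 ...], [100 101 102 0 ...]]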
14 changes: 14 additions & 0 deletions research/seq_flow_lite/layers/BUILD
@@ -83,8 +83,11 @@ py_strict_library(
    srcs_version = "PY3",
    deps = [
        # package tensorflow
        ":embedding_layers",
        "//layers:base_layers",  # sequence projection
        "//layers:conv_layers",
        "//layers:dense_layers",  # sequence projection
        "//layers:normalization_layers",
        "//layers:quantization_layers",  # sequence projection
    ],
)
@@ -102,3 +105,14 @@ py_strict_library(
"//tf_ops:tf_custom_ops_py", # sequence projection
],
)

py_strict_library(
    name = "embedding_layers",
    srcs = ["embedding_layers.py"],
    srcs_version = "PY3",
    deps = [
        # package tensorflow
        "//layers:base_layers",
        "//layers:quantization_layers",
    ],
)
75 changes: 75 additions & 0 deletions research/seq_flow_lite/layers/embedding_layers.py
@@ -0,0 +1,75 @@
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for embedding."""
import tensorflow as tf

from layers import base_layers
from layers import quantization_layers


class EmbeddingLayer(base_layers.BaseLayer):
  """Embedding layer."""

  def __init__(self,
               shape,
               num_bits=8,
               initializer=None,
               trainable=True,
               **kwargs):
    self.shape = shape
    self.quantizer = quantization_layers.ActivationQuantization(
        num_bits=num_bits, **kwargs)
    super(EmbeddingLayer, self).__init__(**kwargs)
    if initializer is None:
      initializer = tf.keras.initializers.GlorotUniform()
    self.initializer = initializer
    self.trainable = trainable

  def build(self, input_shapes):
    self.embedding_table = self.add_weight(
        name="embedding_table",
        shape=self.shape,
        initializer=self.initializer,
        trainable=self.trainable,
        dtype=tf.float32)
    if self.trainable:
      self.add_reg_loss(self.embedding_table)

  def call(self, indices):
    assert indices.dtype in [tf.int64, tf.int32]
    outputs = tf.nn.embedding_lookup(self.embedding_table, indices)
    return self.quantizer(outputs)


class EmbeddingFullyConnected(EmbeddingLayer):
  """Uses embedding table as weights in a fully connected op."""

  def __init__(self, **kwargs):
    shape = kwargs.pop("shape", None)
    initializer = kwargs.pop("initializer", None)
    self.qoutput = quantization_layers.ActivationQuantization(**kwargs)
    super(EmbeddingFullyConnected, self).__init__(
        shape=shape, initializer=initializer, **kwargs)

  def fully_connected(self, inputs, bias=None, weights_scale_factor=None):
    # This method can only be called after a call to "call" method in this class
    self._assert_rank_and_type(inputs, 2)
    weights = self.embedding_table
    if weights_scale_factor is not None:
      weights = weights * weights_scale_factor
    outputs = tf.matmul(inputs, weights, transpose_b=True)
    if bias is not None:
      outputs = tf.nn.bias_add(outputs, bias)
    return self.qoutput(outputs)
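EmbeddingFullyConnected reuses the embedding table as the weight matrix of a fully connected op, i.e. the classic weight-tying pattern: the same table that embeds ids on the way in can, transposed, project hidden states back onto the vocabulary. A minimal plain-TensorFlow sketch of that pattern (illustrative shapes only; it deliberately skips the quantization and BaseLayer machinery above):

import tensorflow as tf

vocab_size, dim = 259, 64                      # e.g. 256 bytes + 3 reserved ids
table = tf.Variable(tf.keras.initializers.GlorotUniform()((vocab_size, dim)))

token_ids = tf.constant([[5, 42, 7]])          # [batch, seq_len]
embedded = tf.nn.embedding_lookup(table, token_ids)   # [batch, seq_len, dim]

hidden = tf.reduce_mean(embedded, axis=1)      # stand-in for an encoder
logits = tf.matmul(hidden, table, transpose_b=True)   # [batch, vocab_size]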
147 changes: 147 additions & 0 deletions research/seq_flow_lite/layers/misc_layers.py
@@ -14,10 +14,13 @@
# ==============================================================================
# Lint as: python3
"""Layers for embedding."""
import math
import tensorflow as tf

from layers import base_layers # import seq_flow_lite module
from layers import conv_layers
from layers import dense_layers # import seq_flow_lite module
from layers import embedding_layers
from layers import quantization_layers # import seq_flow_lite module


@@ -92,3 +95,147 @@ def call(self, keys, queries, sequence_length):
    # seq_dim = tf.shape(result)[1]
    # result = tf.reshape(result, [1, seq_dim, seq_dim])
    return result


class GBSTLayerV2(base_layers.BaseLayer):
  """Tokenization layer."""

  def __init__(self,
               feature_size,
               max_seq_len,
               downsample_rate=2,
               max_subword_block_width=4,
               conv_kernel_size=5,
               block_mixing_mode=None,
               add_block_pos_embed=False,
               **kwargs):
    super(GBSTLayerV2, self).__init__(**kwargs)
    self.feature_size = feature_size
    self.max_seq_len = max_seq_len
    self.downsample_rate = downsample_rate
    self.subword_blocks_width = [1, 2, 3, 4]
    self.max_subword_block_width = len(self.subword_blocks_width)
    self.block_mixing_mode = block_mixing_mode

    self.add_block_pos_embed = add_block_pos_embed
    if self.add_block_pos_embed:
      self.block_pos_embedding = embedding_layers.EmbeddingLayer(
          shape=[self.max_subword_block_width, self.feature_size], **kwargs)
    self.conv_kernel_size = conv_kernel_size
    self.conv_layer = conv_layers.EncoderQConvolution(
        filters=feature_size,
        ksize=conv_kernel_size,
        rank=3,
        padding="VALID",
        activation=None,
        **kwargs)
    padding = [conv_kernel_size - 1, 0]
    self.zero_pad = tf.keras.layers.ZeroPadding1D(padding=padding)
    self.block_attn = dense_layers.BaseQDense(
        units=1,
        rank=3,
        activation=None,
        normalize=False,
        quantize_output=False,
        **kwargs)
    self.scores_concat = quantization_layers.ConcatQuantization(
        axis=3, **kwargs)
    self.attn_concat = quantization_layers.ConcatQuantization(axis=0, **kwargs)
    self.qact = quantization_layers.ActivationQuantization(**kwargs)
    self.qact_dot = quantization_layers.ActivationQuantization(**kwargs)
    self.qoutput = quantization_layers.ActivationQuantization(**kwargs)

  def call(self, inputs, seq_length):
    """Performs downsampling on the character-scale input representation.

    Based in principle on https://arxiv.org/pdf/2106.12672.pdf.

    Args:
      inputs: float Tensor of shape [batch_size, seq_length, embedding_size].
      seq_length: sequence length of shape [batch_size].

    Returns:
      <float>[batch_size, seq_length / downsample_rate, embedding_size].
        Downsampled sequences.
    """
    self._assert_rank_and_type(inputs, 3)
    bsz = self.get_batch_dimension(inputs)
    max_seq_len = self.max_seq_len

    if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
      num_steps = tf.shape(inputs)[1]

    inputs = self.zero_pad(inputs)
    inputs = self.conv_layer(inputs)

    all_block_scores = []
    all_sequences = []
    for subword_len in self.subword_blocks_width:
      if self.add_block_pos_embed:
        block_pos_indices = tf.range(subword_len, dtype=tf.int32)
        block_pos_indices = tf.reshape(block_pos_indices, [1, -1])
        block_pos_embeds = self.block_pos_embedding(block_pos_indices)
        tile_len = math.ceil(max_seq_len / float(subword_len))
        retiled_block_pos_embeds = tf.repeat(block_pos_embeds, tile_len, axis=1)
        inputs += retiled_block_pos_embeds
      # For this block size, form candidate block embeddings and scores.
      # candidates shape: [batch, seq_len/subword_len, dim]
      # block_scores shape: [batch, seq_len/subword_len, 1]
      candidates = tf.nn.avg_pool(
          inputs, [subword_len], strides=[subword_len], padding="SAME")
      candidates = self.conv_layer.quantize_using_output_range(candidates)

      block_scores = self.block_attn(candidates)
      # Upsample it back to the original sequence length.
      retiled_seq = tf.repeat(candidates, subword_len, axis=1)
      retiled_block_scores = tf.repeat(block_scores, subword_len, axis=1)

      # Make sure everything is the right length and add new dimension to concat
      # candidate blocks on.
      if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
        retiled_block_scores = retiled_block_scores[:, :num_steps, :]
        retiled_seq = retiled_seq[:, :num_steps, :]
      else:
        retiled_block_scores = retiled_block_scores[:, :max_seq_len, :]
        retiled_seq = retiled_seq[:, :max_seq_len, :]
      retiled_seq = tf.expand_dims(retiled_seq, axis=-1)
      retiled_block_scores = tf.expand_dims(retiled_block_scores, axis=-1)
      all_sequences.append(retiled_seq)
      all_block_scores.append(retiled_block_scores)

    block_net = self.scores_concat(all_block_scores)
    if self.block_mixing_mode == "score_attention":
      if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
        block_attn_steps = []
        self.attn_concat(None)
        for i in range(num_steps):
          block_i = tf.reshape(block_net[:, i:i + 1, :, :], [1, -1])
          block_attn_steps.append(tf.matmul(block_i, block_i, transpose_b=True))
        block_attn = self.attn_concat(block_attn_steps)
        block_attn = tf.reshape(block_attn, [bsz, -1, 1, 1])
      else:
        block_attn = self.attn_concat(
            [tf.matmul(block_net, block_net, transpose_b=True)])

      block_attn = tf.nn.softmax(block_attn, axis=1)
      block_attn = self.qrange_sigmoid(block_attn, tf_only=True)
      block_net_scaled = self.qact(block_attn * block_net)
    else:
      block_net_scaled = block_net

    candidate_embeds = self.conv_layer.quantize_using_output_range(
        tf.concat(all_sequences, axis=3))
    dot_product = self.qact_dot(block_net_scaled * candidate_embeds)
    output = self.qoutput(tf.reduce_mean(dot_product, axis=-1, keepdims=True))
    output = tf.reshape(output, [bsz, -1, self.feature_size])

    # Removing pad entries for inference mode.
    if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
      output = output[:, :num_steps, :]
    # Downsample by mean pooling.
    if self.downsample_rate > 1:
      output = tf.nn.avg_pool(
          output, (self.downsample_rate,),
          strides=(self.downsample_rate,),
          padding="VALID")
    return output
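GBSTLayerV2 follows the Charformer GBST idea (arXiv:2106.12672): pool candidate blocks of widths 1-4, score each candidate, upsample scores and candidates back to the character positions, mix them, and mean-pool to downsample the sequence. Below is a rough, unquantized sketch of that idea in plain TensorFlow; it uses a softmax over block widths as in the paper and arbitrary toy shapes, so it is an approximation of, not a drop-in for, the layer above.

import tensorflow as tf

batch, seq_len, dim = 2, 16, 8
x = tf.random.normal([batch, seq_len, dim])    # character-scale embeddings
score_layer = tf.keras.layers.Dense(1)         # learned block scorer

candidates, scores = [], []
for width in [1, 2, 3, 4]:
  # Candidate block embeddings at this width, upsampled back to seq_len.
  pooled = tf.nn.avg_pool(x, [width], strides=[width], padding="SAME")
  candidates.append(tf.repeat(pooled, width, axis=1)[:, :seq_len, :])
  scores.append(tf.repeat(score_layer(pooled), width, axis=1)[:, :seq_len, :])

# Softmax over block widths at every position, then a weighted sum.
weights = tf.nn.softmax(tf.concat(scores, axis=-1), axis=-1)
mixed = tf.add_n([w[..., None] * c
                  for w, c in zip(tf.unstack(weights, axis=-1), candidates)])

# Downsample by mean pooling, as the layer does when downsample_rate > 1.
downsampled = tf.nn.avg_pool(mixed, [2], strides=[2], padding="VALID")
print(downsampled.shape)  # (2, 8, 8)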
16 changes: 16 additions & 0 deletions research/seq_flow_lite/models/BUILD
@@ -38,3 +38,19 @@ py_library(
"//tf_ops:tf_custom_ops_py", # sequence projection
],
)

py_library(
    name = "byteqrnn",
    srcs = ["byteqrnn.py"],
    srcs_version = "PY3",
    deps = [
        # package tensorflow
        "//layers:base_layers",
        "//layers:dense_layers",
        "//layers:embedding_layers",
        "//layers:misc_layers",
        "//layers:qrnn_layers",
        # "//tf_ops:tf_custom_ops",
        "//tf_ops:tf_custom_ops_py",
    ],
)