Using the new ViT from TF Hub 🎉

I just tried out the new Vision Transformer (ViT) from TensorFlow Hub by adding two lines of code to the Custom component. Below is the full code, but the only lines I added were:

import tensorflow_hub as hub

and

input_ = hub.KerasLayer("https://tfhub.dev/sayakpaul/vit_s16_classification/1")(input_)

NOTE: I had to run pip install tensorflow_hub in the same environment that PerceptiLabs is installed in.

import tensorflow_hub as hub

class LayerCustom_LayerCustom_1Keras(tf.keras.layers.Layer, PerceptiLabsVisualizer):
    def call(self, inputs, training=True):
        """ Runs the ViT model from TF Hub on the input tensor """
        input_ = inputs['input']
        # Load the ViT classification model from TF Hub and apply it to the input
        input_ = hub.KerasLayer("https://tfhub.dev/sayakpaul/vit_s16_classification/1")(input_)
        output = preview = input_
        self._outputs = {
            'output': output,
            'preview': output,
        }
        return self._outputs

    def get_config(self):
        """Any variables belonging to this layer that should be rendered in the frontend.
        Returns:
            A dictionary with tensor names for keys and picklable for values.
        """
        return {}

    @property
    def visualized_trainables(self):
        """ Returns two tf.Variables (weights, biases) to be visualized in the frontend """
        return tf.constant(0), tf.constant(0)


class LayerCustom_LayerCustom_1(Tf2xLayer):
    def __init__(self):
        super().__init__(
            keras_class=LayerCustom_LayerCustom_1Keras
        )
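If you want to sanity-check the hub layer on its own before dropping it into the Custom component, a minimal standalone sketch looks like this (not from the original post; the 224x224 input size and the [-1, 1] value range are my reading of the model card on TF Hub, so verify there):

import tensorflow as tf
import tensorflow_hub as hub

# Download and wrap the ViT classification model from TF Hub
vit = hub.KerasLayer("https://tfhub.dev/sayakpaul/vit_s16_classification/1")

# One dummy image batch: shape (1, 224, 224, 3), values in [-1, 1]
dummy = tf.random.uniform((1, 224, 224, 3), minval=-1.0, maxval=1.0)
logits = vit(dummy)
print(logits.shape)  # expected to be (1, 1000) for an ImageNet-1k classification head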

Just to give this a little more context, here’s Papers with Code for the Vision Transformer.

TL;DR: here is the abstract of the original paper, An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale:

While the Transformer architecture has become the de-facto standard for natural
language processing tasks, its applications to computer vision remain limited. In
vision, attention is either applied in conjunction with convolutional networks, or
used to replace certain components of convolutional networks while keeping their
overall structure in place. We show that this reliance on CNNs is not necessary
and a pure transformer applied directly to sequences of image patches can perform
very well on image classification tasks. When pre-trained on large amounts of
data and transferred to multiple mid-sized or small image recognition benchmarks
(ImageNet, CIFAR-100, VTAB, etc.), Vision Transformer (ViT) attains excellent
results compared to state-of-the-art convolutional networks while requiring substantially
fewer computational resources to train.

(Emphasis added)
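If the "16x16 words" phrasing is unclear: the model simply cuts each image into fixed-size patches and treats every flattened patch as a token in a sequence. A rough sketch of that patchifying step (illustration only, not code from the paper):

import tensorflow as tf

# Split a batch of 224x224 RGB images into non-overlapping 16x16 patches,
# giving a sequence of (224/16)^2 = 196 "tokens" per image, each a flattened
# 16*16*3 = 768-dimensional vector.
images = tf.random.uniform((1, 224, 224, 3))
patches = tf.image.extract_patches(
    images,
    sizes=[1, 16, 16, 1],
    strides=[1, 16, 16, 1],
    rates=[1, 1, 1, 1],
    padding="VALID",
)
tokens = tf.reshape(patches, (1, 14 * 14, 16 * 16 * 3))
print(tokens.shape)  # (1, 196, 768)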


While importing the model with hub.KerasLayer works, I ran into an issue when trying to train:

Error during training!
Traceback (most recent call last):
  File "perceptilabs/coreInterface.py", line 32, in perceptilabs.coreInterface.TrainingSessionInterface.run_stepwise
  File "perceptilabs/coreInterface.py", line 33, in perceptilabs.coreInterface.TrainingSessionInterface.run_stepwise
  File "perceptilabs/coreInterface.py", line 52, in _main_loop
  File "perceptilabs/trainer/base.py", line 174, in run_stepwise
  File "perceptilabs/trainer/base.py", line 282, in _loop_over_dataset
  File "/home/joakim/anaconda3/envs/perceptilabs/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)
  File "/home/joakim/anaconda3/envs/perceptilabs/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 950, in _call
    return self._stateless_fn(*args, **kwds)
  File "/home/joakim/anaconda3/envs/perceptilabs/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3024, in __call__
    filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
  File "/home/joakim/anaconda3/envs/perceptilabs/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1961, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/home/joakim/anaconda3/envs/perceptilabs/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 596, in call
    ctx=ctx)
  File "/home/joakim/anaconda3/envs/perceptilabs/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Read variable failure _AnonymousVar1073. It could mean the variable is uninitialized or the variable is on another device 
	 [[{{node jax2tf_arg_0/ReadVariableOp}}]]
	 [[training_model/layer_custom__layer_custom_1_keras/keras_layer/StatefulPartitionedCall]]
  (1) Failed precondition:  Read variable failure _AnonymousVar1073. It could mean the variable is uninitialized or the variable is on another device 
	 [[{{node jax2tf_arg_0/ReadVariableOp}}]]
	 [[training_model/layer_custom__layer_custom_1_keras/keras_layer/StatefulPartitionedCall]]
	 [[Identity_9/_6]]
0 successful operations.
0 derived errors ignored. [Op:__inference__work_on_batch_73978]
Function call stack:
_work_on_batch -> _work_on_batch

The same thing happens when I try to train on the converted PyTorch model discussed here.
Any ideas?

Hi @birdstream

This is not going to be the most helpful answer - unless I can track down the thing I am about to mention - but I notice ‘jax2tf’ in the errors… and I know I was trying something that relied on JAX a while ago and hit problems…

Maybe @robertl will know better. I will keep looking here though… I keep notes on just about everything, but alas, “jax” seems to crop up in the encoding of a lot of images in ipynb files.

UPDATE

I do, however, have a Jupyter notebook that handles the vit_s16_classification model OK, which could be useful for comparison. I just fixed it up (by pip installing wget etc.) and the picture was classified as an elephant - from the same environment that PL v0.13.1 is installed in (Python 3.8.10).

classification_ViT.ipynb (149.5 KB)

Hey @birdstream,

We looked into it (or rather @mukund_s took a look) and the issue seems to be that hub.KerasLayer("https://tfhub.dev/sayakpaul/vit_s16_classification/1") is created inside the call function instead of the build function.
The call function runs every time any inference happens, while the build function runs only once. What likely happens here is that recreating the hub layer on every call causes this crash.
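
To make the build-vs-call distinction concrete, here is a minimal sketch of the same pattern in a plain Keras layer (illustration only, outside PerceptiLabs; the class name is made up): the sub-layer is created once in build, and call only applies it.

import tensorflow as tf
import tensorflow_hub as hub

class HubWrapper(tf.keras.layers.Layer):
    def build(self, input_shape):
        # Runs once, the first time the layer sees an input shape:
        # create the hub layer (and its variables) here.
        self.vit = hub.KerasLayer(
            "https://tfhub.dev/sayakpaul/vit_s16_classification/1"
        )

    def call(self, inputs, training=False):
        # Runs on every forward pass: only apply the already-built sub-layer.
        return self.vit(inputs)

Keras calls build the first time the layer sees an input, so the hub download and variable creation happen exactly once instead of on every batch.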

Here’s a working script you can copy into the Custom Layer and train with:

class LayerCustom_LayerCustom_1Keras(tf.keras.layers.Layer, PerceptiLabsVisualizer):
    def call(self, inputs, training=True):
        """ Takes a tensor and one-hot encodes it """
        input_ = inputs['input']
        output = self.vit(input_)
        self._outputs = {            
            'output': output,
            'preview': output,
        }
        return self._outputs
    def build(self, input_shape):
        # Create the hub layer once, when the layer is built
        import tensorflow_hub as hub
        self.vit = hub.KerasLayer("https://tfhub.dev/sayakpaul/vit_s16_classification/1")
    def get_config(self):
        """Any variables belonging to this layer that should be rendered in the frontend.
        Returns:
            A dictionary with tensor names for keys and picklable for values.
        """
        return {}
    @property
    def visualized_trainables(self):
        """ Returns two tf.Variables (weights, biases) to be visualized in the frontend """
        return tf.constant(0), tf.constant(0)
class LayerCustom_LayerCustom_1(Tf2xLayer):
    def __init__(self):
        super().__init__(
            keras_class=LayerCustom_LayerCustom_1Keras
        )

Hope that works! 🙂


Yes it does! 🙂 Thanks a lot! (I’ll update my post about converting PyTorch to TensorFlow/Keras.)
And thank you @mukund_s
