diff --git a/notebooks/bakllava-inference-from-python.ipynb b/notebooks/bakllava-inference-from-python.ipynb new file mode 100644 index 0000000..f449a51 --- /dev/null +++ b/notebooks/bakllava-inference-from-python.ipynb @@ -0,0 +1,256 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "435cb4a0-ab3e-4e73-9817-feb946f112f1", + "metadata": {}, + "source": [ + "### Configure:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c9c3ce6a-1e62-4507-86a8-233f77f525d4", + "metadata": {}, + "outputs": [], + "source": [ + "low_gpu_memory_optimization = True" + ] + }, + { + "cell_type": "markdown", + "id": "c3cf0eaf-bac6-4f3b-a561-88f50bf9a6d8", + "metadata": {}, + "source": [ + "### Install dependencies:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65d3fb1d-be7e-4828-a583-53e5ab56a1d5", + "metadata": {}, + "outputs": [], + "source": [ + "!git clone https://github.com/SkunkworksAI/BakLLaVA.git\n", + "%cd BakLLaVA\n", + "!pip install -e .\n", + "!pip uninstall transformers -y\n", + "!pip install transformers==4.34.0" + ] + }, + { + "cell_type": "markdown", + "id": "f1670688-a179-456c-936b-a6957d9c2b46", + "metadata": {}, + "source": [ + "### Run:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f87075a7-c43b-4e63-8d9c-a999eb631d0d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-10-20 03:29:22,155] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + } + ], + "source": [ + "from transformers import AutoConfig, AutoTokenizer\n", + "from llava.model.language_model.llava_mistral import LlavaMistralForCausalLM\n", + "from huggingface_hub import notebook_login\n", + "from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria\n", + "\n", + "from PIL import Image\n", + "import requests\n", + "from io import BytesIO\n", + "\n", + "from llava.conversation import conv_templates, SeparatorStyle\n", + "from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "708e45d8-8c4a-4888-bb25-4ebe9a0af993", + "metadata": {}, + "outputs": [], + "source": [ + "notebook_login()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8393e364-3cfc-4ba8-bd11-0bd183b2a835", + "metadata": {}, + "outputs": [], + "source": [ + "model_path = \"SkunkworksAI/BakLLaVA-1\"\n", + "\n", + "cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n", + "tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n", + "if low_gpu_memory_optimization:\n", + " model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, load_in_8bit=True, config=cfg_pretrained)\n", + "else:\n", + " model = LlavaMistralForCausalLM.from_pretrained(model_path, config=cfg_pretrained)\n", + " model.to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2ab75125-9a4b-4df1-b314-4b4e96ad8c94", + "metadata": {}, + "outputs": [], + "source": [ + "mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n", + "mm_use_im_patch_token = getattr(model.config, \"mm_use_im_patch_token\", True)\n", + "if mm_use_im_patch_token:\n", + " tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n", + "if mm_use_im_start_end:\n", + " tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n", + "model.resize_token_embeddings(len(tokenizer))\n", + "\n", + "vision_tower = model.get_vision_tower()\n", + "if not vision_tower.is_loaded:\n", + " vision_tower.load_model()\n", + "vision_tower.to(device='cuda', dtype=torch.float16)\n", + "image_processor = vision_tower.image_processor\n", + "\n", + "if hasattr(model.config, \"max_sequence_length\"):\n", + " context_len = model.config.max_sequence_length\n", + "else:\n", + " context_len = 2048" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "64ecf511-ac4f-4bea-b0dd-d8f5611a9d58", + "metadata": {}, + "outputs": [], + "source": [ + "def load_image(image_file):\n", + " if image_file.startswith('http') or image_file.startswith('https'):\n", + " response = requests.get(image_file)\n", + " image = Image.open(BytesIO(response.content)).convert('RGB')\n", + " else:\n", + " image = Image.open(image_file).convert('RGB')\n", + " return image\n", + "\n", + "\n", + "# image = load_image(\"https://t4.ftcdn.net/jpg/00/97/58/97/360_F_97589769_t45CqXyzjz0KXwoBZT9PRaWGHRk5hQqQ.jpg\")\n", + "image = load_image(\"https://cdn.discordapp.com/attachments/1096822099345145969/1164641565550067852/heart_1.png?ex=6543f3fb&is=65317efb&hm=448cb26e19c141871e776af98077c4c1e97a8f29b96916ab671e5010c00e3625&\")\n", + "\n", + "if low_gpu_memory_optimization:\n", + " image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n", + "else:\n", + " image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "dd3526fa-edb4-41e4-9494-f43f5fec2c45", + "metadata": {}, + "outputs": [], + "source": [ + "query = \"Describe this image\"\n", + "\n", + "if model.config.mm_use_im_start_end:\n", + " query = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + query\n", + "else:\n", + " query = DEFAULT_IMAGE_TOKEN + '\\n' + query\n", + "\n", + "conv = conv_templates[\"llava_v1\"].copy()\n", + "\n", + "conv.append_message(conv.roles[0], query)\n", + "conv.append_message(conv.roles[1], None)\n", + "prompt = conv.get_prompt()\n", + "input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n", + "\n", + "stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n", + "keywords = [stop_str]\n", + "stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1383e630-8651-4eb7-898a-7312094f7b9e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", + "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The image features a detailed illustration of a human heart, showcasing its various parts and blood vessels. The heart is depicted in full color, with its interior and exterior structures clearly visible.\n", + "\n", + "The heart is surrounded by a network of blood vessels, including arteries and veins. There are at least six visible arteries, some of which are located near the top and bottom of the heart, while others are situated on its right side. Additionally, there are five visible veins, with some located near the top and bottom of the heart, and others on the left side. The arrangement of these blood vessels highlights the complex circulatory system that sustains the heart itself.\n" + ] + } + ], + "source": [ + "with torch.inference_mode():\n", + " output_ids = model.generate(\n", + " input_ids=input_ids,\n", + " images=image_tensor,\n", + " do_sample=True,\n", + " temperature=0.2,\n", + " max_new_tokens=1024,\n", + " use_cache=True,\n", + " stopping_criteria=[stopping_criteria])\n", + "\n", + "input_token_len = input_ids.shape[1]\n", + "n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n", + "if n_diff_input_output > 0:\n", + " print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n", + "outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n", + "outputs = outputs.strip()\n", + "if outputs.endswith(stop_str):\n", + " outputs = outputs[:-len(stop_str)]\n", + "outputs = outputs.strip()\n", + "print(outputs)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}