diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..e0c97d5f92335698e93bdd8dfa799c744fb3c2c5 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,35 +1,3 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+version https://git-lfs.github.com/spec/v1
+oid sha256:8628eb71bc7c80d0709a04c69c25570f006ef9ca0dbb27bf2b1a3be74605edda
+size 1619
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..06974d20a76bf416d33438e200c8e11958d4002e
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c71d239df91726fc519c6eb72d318ec65820627232b2f796219e87dcf35d0ab4
+size 11357
diff --git a/README.md b/README.md
index 340a31bb5f9b74dc06138cb27c713a0410017851..474256b9bf0b69331ed3dfdf63ad60ef10246b8f 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,256 @@
----
-license: cc-by-sa-4.0
----
+
More Text, Less Point: Towards 3D Data-Efficient Point-Language Understanding
+
+ Yuan Tang* Xu Han* Xianzhi Liโ Qiao Yu Jinfeng Xu Yixue Hao Long Hu Min Chen
+
+ Huazhong University of Science and Technology South China University of Technology
+
+
+
+
+ AAAI 2025
+
+
+
+
+
+
+## ๐ Contents
+
+- [๐ Overview](#-overview)
+- [๐ฆ Training and Evaluation](#-Training-and-Evaluation)
+- [๐ Citation](#-citation)
+- [๐ License](#-license)
+- [๐ Related Work](#-related-work)
+- [๐ Acknowledgements](#-acknowledgements)
+
+## ๐ Overview
+
+
+
+
+- We introduce a new task of 3D data-efficient point-language understanding, aiming to enable LLMs to achieve robust 3D understanding with minimal 3D data.
+- We propose GreenPLM to tackle this 3D data-limited task from a novel perspective, enhancing point-LLM alignment with more free-text data.
+- we introduce a 6M T3D dataset, design a 3-stage training strategy, and present a 0M-Pooling module for token pooling.
+- We introduce the Accuracy-to-3D-Data Ratio (A3DR) to measure the efficiency of 3D data usage and establish an evaluation benchmark based on open-source LLMs.
+- GreenPLM outperforms previous models using only 12\% of 3D data and even surpasses GPT4Point (660K 3D data) using only text, demonstrating superior 3D data efficiency.
+
+
+
+## ๐ฆ Training-and-Evaluation
+
+### Download project
+The **code, weights, and dataset** of the project have already been uploaded to [Hugging Face](https://huggingface.co/YuanTang96/GreenPLM). Simply download them once to get started with the project.
+
+### Install Environment
+Enter the project directory and execute the following command:
+```bash
+conda create -n greenplm python=3.10 -y
+conda activate greenplm
+bash envInstall.sh
+ ```
+
+### Project Directory Introduction
+- `./greenplm/release` contains the paper's weights, training scripts, and testing scripts.
+- `./pretrained_weight` stores the pre-trained weights required for the training and testing phases of the project.
+- `./lava-vicuna_2024_4_Phi-3-mini-4k-instruct` is the weight directory for Phi-3.
+- `./dataset/T3D` is the 6M dataset proposed in this project.
+- `./dataset/T3D/stage_1/brief_1M_caption.json` is the dataset for Stage I.
+- `./dataset/T3D/stage_2/stage_2_data_210k.json` is the dataset for Stage II.
+
+### Dataset Preparation
+
+`./dataset/Objaverse/8192_npy.zip` contains the point cloud data from Objaverse that is required for this project. To unzip the dataset:
+
+```bash
+unzip ./dataset/Objaverse/8192_npy.zip -d ./dataset/Objaverse/
+```
+
+### Inference
+
+#### Paper Weights
+##### GreenPLM-0
+The model trained only on text data, i.e., (Stage I & Stage II).
+
+```bash
+bash ./release/paper/scripts/test/release_stage_2.sh
+```
+The output JSON results are saved in `./release/paper/result_json/stage_2`.
+
+##### GreenPLM
+The model trained on a small amount of 3D data, i.e., (Stage I & Stage II & Stage III).
+
+```bash
+bash ./release/paper/scripts/test/release_stage_3.sh
+```
+The output JSON results are saved in `./release/paper/result_json/stage_3`.
+
+
+
+#### Weights Using All T3D Dataset
+
+ We also provide weights trained using the entire T3D dataset, meaning we use 5M data points from T3D in Stage II, instead of just 210k as in our paper. (click to expand)
+
+##### GreenPLM-0
+The model trained only on text data, i.e., (Stage I & Stage II).
+
+```bash
+bash ./release/5M_data_seting/scripts/test/release_5M_stage_2.sh
+```
+The output JSON results are saved in `./release/5M_data_seting/result_json/stage_2`.
+
+##### GreenPLM
+The model trained on a small amount of 3D data, i.e., (Stage I & Stage II & Stage III).
+
+```bash
+bash ./release/5M_data_seting/scripts/test/release_5M_stage_3.sh
+```
+The output JSON results are saved in `./release/5M_data_seting/result_json/stage_3`.
+
+
+
+
+### Evaluation
+#### Using LLM
+
+ - You can get the **DASHSCOPE_API_KEY** from [aliyun](https://bailian.console.aliyun.com/?apiKey=1#/api-key). The evaluation may require 9 CNY (~ 1.3 USD).
+ - If you have enough GPU resources, you can also build your own Qwen2-72B-Instruct service, following the [Qwen2](https://github.com/QwenLM/Qwen2?tab=readme-ov-file). Then evaluate the results for free!
+
+ 1. Evaluate the open vocabulary classification on objaverse
+ ```bash
+ export PYTHONPATH=$PWD
+ export DASHSCOPE_API_KEY=sk-xxx
+ python ./pointllm/eval/evaluator_opensource_llm_QwenAPI.py \
+ --results_path /path/to/evaluation/PointLLM_brief_description_val_200_GT_Objaverse_classification_prompt0.json \
+ --eval_type open-free-form-classification \
+ --model_type qwen2-72b-instruct \
+ --parallel --num_workers 4
+ ```
+
+ ```bash
+ export PYTHONPATH=$PWD
+ export DASHSCOPE_API_KEY=sk-xxx
+ python ./pointllm/eval/evaluator_opensource_llm_QwenAPI.py \
+ --results_path /path/to/evaluation/PointLLM_brief_description_val_200_GT_Objaverse_classification_prompt1.json \
+ --eval_type open-free-form-classification \
+ --model_type qwen2-72b-instruct \
+ --parallel --num_workers 4
+ ```
+
+ 2. Evaluate the close-set zero-shot classification on ModelNet40
+
+ ```bash
+ export PYTHONPATH=$PWD
+ export DASHSCOPE_API_KEY=sk-xxx
+ python ./pointllm/eval/evaluator_opensource_llm_QwenAPI.py \
+ --results_path /path/to/evaluation/ModelNet_classification_prompt0.json \
+ --eval_type modelnet-close-set-classification \
+ --model_type qwen2-72b-instruct \
+ --parallel --num_workers 4
+ ```
+
+ ```bash
+ export PYTHONPATH=$PWD
+ export DASHSCOPE_API_KEY=sk-xxx
+ python ./pointllm/eval/evaluator_opensource_llm_QwenAPI.py \
+ --results_path /path/to/evaluation/ModelNet_classification_prompt1.json \
+ --eval_type modelnet-close-set-classification \
+ --model_type qwen2-72b-instruct \
+ --parallel --num_workers 4
+ ```
+
+ 3. Evaluate the object captioning on objaverse
+
+ ```bash
+ export PYTHONPATH=$PWD
+ export DASHSCOPE_API_KEY=sk-xxx
+ python ./pointllm/eval/evaluator_opensource_llm_QwenAPI.py \
+ --results_path /path/to/evaluation/PointLLM_brief_description_val_200_GT_Objaverse_captioning_prompt2.json \
+ --eval_type object-captioning \
+ --model_type qwen2-72b-instruct \
+ --parallel --num_workers 4
+ ```
+
+#### Traditional Metric Evaluation
+For the object captioning task, run the following command to evaluate model outputs with traditional metrics Sentence-BERT and SimCSE.
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python pointllm/eval/traditional_evaluator.py --results_path /path/to/evaluation/PointLLM_brief_description_val_200_GT_Objaverse_captioning_prompt2.json
+```
+
+
+## Training
+
+**Stage I**
+```bash
+bash ./release/paper/scripts/train/1.sh
+```
+
+**Stage II**: GreenPLM-0
+```bash
+bash ./release/paper/scripts/train/2.sh
+```
+
+**Stage III**: GreenPLM
+```bash
+bash ./release/paper/scripts/train/3.sh
+```
+
+
+ We also provide training scripts using the entire T3D dataset, meaning we use 5M data from T3D in Stage II, instead of just 210k as in our paper. (click to expand)
+
+**Stage II**: GreenPLM-0
+```bash
+bash ./release/5M_data_seting/scripts/train/2.sh
+```
+
+**Stage III**: GreenPLM
+```bash
+bash ./release/5M_data_seting/scripts/train/3.sh
+```
+
+
+
+**Note**: You can modify the `--output_dir` argument in the scripts to set the output directory for the trained weights.
+
+
+
+
+
+
+
+
+
+
+
+
+
+## ๐ Citation
+If you find our work helpful, please consider citing:
+```bibtex
+@inproceedings{tang2025more,
+ title={More text, less point: Towards 3d data-efficient point-language understanding},
+ author={Tang, Yuan and Han, Xu and Li, Xianzhi and Yu, Qiao and Xu, Jinfeng and Hao, Yixue and Hu, Long and Chen, Min},
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+ volume={39},
+ number={7},
+ pages={7284--7292},
+ year={2025}
+}
+```
+
+## ๐ License
+
+
+This work is under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
+
+## ๐ Related Work
+Together, Let's make LLM for 3D great!
+- [Point-Bind & Point-LLM](https://arxiv.org/abs/2309.00615): aligns point clouds with Image-Bind to reason multi-modality input without 3D-instruction data training.
+- [3D-LLM](https://arxiv.org/abs/2307.12981): employs 2D foundation models to encode multi-view images of 3D point clouds.
+- [PointLLM](https://arxiv.org/abs/2308.16911): employs 3D point clouds with LLaVA.
+- [ShapeLLM](http://arxiv.org/abs/2402.17766): combines a powerful point cloud encoder with LLM for embodied scenes.
+- [MiniGPT-3D](https://arxiv.org/pdf/2405.01413) : takes the first step toward efficient 3D-LLM, requiring only a single RTX 3090 GPU and one day of training time.
+
+
+## ๐ Acknowledgements
+We would like to thank the authors of [PointLLM](https://github.com/OpenRobotLab/PointLLM), [Uni3D](https://github.com/baaivision/Uni3D), [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct), and [LLaVA-pp](https://github.com/mbzuai-oryx/LLaVA-pp) for their great works and repos.
\ No newline at end of file
diff --git a/cog.yaml b/cog.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d092e4083ea48243250d973191d493f24514d6ba
--- /dev/null
+++ b/cog.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e090913eb14a1290a253f8b3ecbd1f6bc24fb1c373afefbbf7896446961953c1
+size 981
diff --git a/config.json b/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..bd6abbba6dd31417e7c2e48300a8ea5c97ca4503
--- /dev/null
+++ b/config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63700c458c2b4e1b461cef0f51ff7604c8d024e0edbe36c8b8419c23bb6bac2d
+size 71
diff --git a/dataset/Objaverse/8192_npy.zip b/dataset/Objaverse/8192_npy.zip
new file mode 100644
index 0000000000000000000000000000000000000000..3d32929f268cc71f74d33f085549c7cb4c248d0e
--- /dev/null
+++ b/dataset/Objaverse/8192_npy.zip
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf2bb1d3c7ff97d8d5bc62e5b623231adffc23e542a0e35ed8dd191fbd0dc542
+size 6352316008
diff --git a/dataset/Objaverse/PointLLM_brief_description_val_200_GT.json b/dataset/Objaverse/PointLLM_brief_description_val_200_GT.json
new file mode 100644
index 0000000000000000000000000000000000000000..38acf9e7d0f2eea1757575c23c9f6c88b8017805
--- /dev/null
+++ b/dataset/Objaverse/PointLLM_brief_description_val_200_GT.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc0dabd9767a574f8acf22a6f1791689c567e49102b275e92c8e0bd083327955
+size 65315
diff --git a/dataset/Objaverse/PointLLM_complex_50k_brief_40k_all_90k.json b/dataset/Objaverse/PointLLM_complex_50k_brief_40k_all_90k.json
new file mode 100644
index 0000000000000000000000000000000000000000..5dc893e7ee7a3f23ecf432681bad893a5a734264
--- /dev/null
+++ b/dataset/Objaverse/PointLLM_complex_50k_brief_40k_all_90k.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c8543396bfd46dbb3bce4ea67907fb3b68e5cbf2aa0cb1cc03b4c413ccdb8f48
+size 43683846
diff --git a/dataset/T3D/stage_1/brief_1M_caption.json b/dataset/T3D/stage_1/brief_1M_caption.json
new file mode 100644
index 0000000000000000000000000000000000000000..8e41a3a7c3d3041b36125b4f13f3efd14ccfcdbb
--- /dev/null
+++ b/dataset/T3D/stage_1/brief_1M_caption.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0cf3a03d17a06d14488d6a6234de1fbc5c0970811b53ac88ed4b02a607459713
+size 646557180
diff --git a/dataset/T3D/stage_2/stage_2_data_210k.json b/dataset/T3D/stage_2/stage_2_data_210k.json
new file mode 100644
index 0000000000000000000000000000000000000000..37f8fec7c49e2439efe6a8d2ff1bc4beb58348cc
--- /dev/null
+++ b/dataset/T3D/stage_2/stage_2_data_210k.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3565382fa139d4e8aca746f33aa74592c99f990a1a1d4dd141ab3571afe72a1
+size 168421688
diff --git a/dataset/T3D/stage_2/stage_2_data_5M.json b/dataset/T3D/stage_2/stage_2_data_5M.json
new file mode 100644
index 0000000000000000000000000000000000000000..43e4fabc036b7d8e841da42db4afb2d364621938
--- /dev/null
+++ b/dataset/T3D/stage_2/stage_2_data_5M.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:268913bd44e0be2d58177713b6740e4cd2b2054e410d4479643fadff5f1c2c04
+size 3966817060
diff --git a/dataset/modelnet40_data/modelnet40_test_8192pts_fps.dat b/dataset/modelnet40_data/modelnet40_test_8192pts_fps.dat
new file mode 100644
index 0000000000000000000000000000000000000000..00172b35e7f99b9decbbf8438c080dc64449f33d
--- /dev/null
+++ b/dataset/modelnet40_data/modelnet40_test_8192pts_fps.dat
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d85b2287a683bbb53842e712db1f77f5e772c371868381726d2541e33bd5cf87
+size 485526630
diff --git a/envInstall.sh b/envInstall.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b399e9244e33174208650d967a18147beb2a3ce5
--- /dev/null
+++ b/envInstall.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be91e45733f958b25c51e976a9260f114b5b6614c8948ddd9cb70589b8eca071
+size 212
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/.gitattributes b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..489d0cf0f2b10ab1ab15b57e9cb384b3b7a0abdf
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/.gitattributes
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11ad7efa24975ee4b0c3c3a38ed18737f0658a5f75a0a96787b576a78a023361
+size 1519
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/CODE_OF_CONDUCT.md b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..f9ba8cf65f3e3104dd061c178066ec8247811f33
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/CODE_OF_CONDUCT.md
@@ -0,0 +1,9 @@
+# Microsoft Open Source Code of Conduct
+
+This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
+
+Resources:
+
+- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
+- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
+- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/LICENSE b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..65d96207b60c56bb98330c832a914c8399547f75
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/LICENSE
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa8235e5b48faca34e3ca98cf4f694ef08bd216d28b58071a1f85b1d50cb814d
+size 1084
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/NOTICE.md b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/NOTICE.md
new file mode 100644
index 0000000000000000000000000000000000000000..ee58e836b8cb628406447bae6b6b75a0fa553143
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/NOTICE.md
@@ -0,0 +1,38 @@
+NOTICES AND INFORMATION
+Do Not Translate or Localize
+
+This software incorporates material from third parties.
+
+**Component.** https://github.com/Dao-AILab/flash-attention
+
+**Open Source License/Copyright Notice.**
+
+BSD 3-Clause License
+
+Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/README.md b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..28cb5bb2f5ad193abd2591864f1ad00ce24d3cfb
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/README.md
@@ -0,0 +1,256 @@
+---
+license: mit
+license_link: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/LICENSE
+
+language:
+- en
+pipeline_tag: text-generation
+tags:
+- nlp
+- code
+inference:
+ parameters:
+ temperature: 0.7
+widget:
+ - messages:
+ - role: user
+ content: Can you provide ways to eat combinations of bananas and dragonfruits?
+---
+
+## Model Summary
+
+The Phi-3-Mini-4K-Instruct is a 3.8B parameters, lightweight, state-of-the-art open model trained with the Phi-3 datasets that includes both synthetic data and the filtered publicly available websites data with a focus on high-quality and reasoning dense properties.
+The model belongs to the Phi-3 family with the Mini version in two variants [4K](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) and [128K](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) which is the context length (in tokens) that it can support.
+
+The model has underwent a post-training process that incorporates both supervised fine-tuning and direct preference optimization for the instruction following and safety measures.
+When assessed against benchmarks testing common sense, language understanding, math, code, long context and logical reasoning, Phi-3 Mini-4K-Instruct showcased a robust and state-of-the-art performance among models with less than 13 billion parameters.
+
+Resources and Technical Documentation:
+
++ [Phi-3 Microsoft Blog](https://aka.ms/phi3blog-april)
++ [Phi-3 Technical Report](https://aka.ms/phi3-tech-report)
++ [Phi-3 on Azure AI Studio](https://aka.ms/phi3-azure-ai)
++ Phi-3 GGUF: [4K](https://aka.ms/Phi3-mini-4k-instruct-gguf)
++ Phi-3 ONNX: [4K](https://aka.ms/Phi3-mini-4k-instruct-onnx)
+
+## Intended Uses
+
+**Primary use cases**
+
+The model is intended for commercial and research use in English. The model provides uses for applications which require:
+
+1) Memory/compute constrained environments
+2) Latency bound scenarios
+3) Strong reasoning (especially code, math and logic)
+
+Our model is designed to accelerate research on language and multimodal models, for use as a building block for generative AI powered features.
+
+**Use case considerations**
+
+Our models are not specifically designed or evaluated for all downstream purposes. Developers should consider common limitations of language models as they select use cases, and evaluate and mitigate for accuracy, safety, and fariness before using within a specific downstream use case, particularly for high risk scenarios. Developers should be aware of and adhere to applicable laws or regulations (including privacy, trade compliance laws, etc.) that are relevant to their use case.
+
+Nothing contained in this Model Card should be interpreted as or deemed a restriction or modification to the license the model is released under.
+
+## How to Use
+
+Phi-3 Mini-4K-Instruct has been integrated in the development version (4.41.0.dev0) of `transformers`. Until the official version is released through `pip`, ensure that you are doing one of the following:
+
+* When loading the model, ensure that `trust_remote_code=True` is passed as an argument of the `from_pretrained()` function.
+
+* Update your local `transformers` to the development version: `pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers`. The previous command is an alternative to cloning and installing from the source.
+
+The current `transformers` version can be verified with: `pip list | grep transformers`.
+
+Phi-3 Mini-4K-Instruct is also available in [HuggingChat](https://aka.ms/try-phi3-hf-chat).
+
+### Tokenizer
+
+Phi-3 Mini-4K-Instruct supports a vocabulary size of up to `32064` tokens. The [tokenizer files](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/added_tokens.json) already provide placeholder tokens that can be used for downstream fine-tuning, but they can also be extended up to the model's vocabulary size.
+
+### Chat Format
+
+Given the nature of the training data, the Phi-3 Mini-4K-Instruct model is best suited for prompts using the chat format as follows.
+You can provide the prompt as a question with a generic template as follow:
+```markdown
+<|user|>\nQuestion <|end|>\n<|assistant|>
+```
+For example:
+```markdown
+<|user|>
+How to explain Internet for a medieval knight?<|end|>
+<|assistant|>
+```
+
+where the model generates the text after `<|assistant|>` . In case of few-shots prompt, the prompt can be formatted as the following:
+
+```markdown
+<|user|>
+I am going to Paris, what should I see?<|end|>
+<|assistant|>
+Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:\n\n1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city.\n2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa.\n3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.\n\nThese are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world."<|end|>
+<|user|>
+What is so great about #1?<|end|>
+<|assistant|>
+```
+
+### Sample inference code
+
+This code snippets show how to get quickly started with running the model on a GPU:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+torch.random.manual_seed(0)
+
+model = AutoModelForCausalLM.from_pretrained(
+ "microsoft/Phi-3-mini-4k-instruct",
+ device_map="cuda",
+ torch_dtype="auto",
+ trust_remote_code=True,
+)
+tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+
+messages = [
+ {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
+ {"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."},
+ {"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"},
+]
+
+pipe = pipeline(
+ "text-generation",
+ model=model,
+ tokenizer=tokenizer,
+)
+
+generation_args = {
+ "max_new_tokens": 500,
+ "return_full_text": False,
+ "temperature": 0.0,
+ "do_sample": False,
+}
+
+output = pipe(messages, **generation_args)
+print(output[0]['generated_text'])
+```
+
+*Some applications/frameworks might not include a BOS token (``) at the start of the conversation. Please ensure that it is included since it provides more reliable results.*
+
+## Responsible AI Considerations
+
+Like other language models, the Phi series models can potentially behave in ways that are unfair, unreliable, or offensive. Some of the limiting behaviors to be aware of include:
+
++ Quality of Service: the Phi models are trained primarily on English text. Languages other than English will experience worse performance. English language varieties with less representation in the training data might experience worse performance than standard American English.
++ Representation of Harms & Perpetuation of Stereotypes: These models can over- or under-represent groups of people, erase representation of some groups, or reinforce demeaning or negative stereotypes. Despite safety post-training, these limitations may still be present due to differing levels of representation of different groups or prevalence of examples of negative stereotypes in training data that reflect real-world patterns and societal biases.
++ Inappropriate or Offensive Content: these models may produce other types of inappropriate or offensive content, which may make it inappropriate to deploy for sensitive contexts without additional mitigations that are specific to the use case.
++ Information Reliability: Language models can generate nonsensical content or fabricate content that might sound reasonable but is inaccurate or outdated.
++ Limited Scope for Code: Majority of Phi-3 training data is based in Python and use common packages such as "typing, math, random, collections, datetime, itertools". If the model generates Python scripts that utilize other packages or scripts in other languages, we strongly recommend users manually verify all API uses.
+
+Developers should apply responsible AI best practices and are responsible for ensuring that a specific use case complies with relevant laws and regulations (e.g. privacy, trade, etc.). Important areas for consideration include:
+
++ Allocation: Models may not be suitable for scenarios that could have consequential impact on legal status or the allocation of resources or life opportunities (ex: housing, employment, credit, etc.) without further assessments and additional debiasing techniques.
++ High-Risk Scenarios: Developers should assess suitability of using models in high-risk scenarios where unfair, unreliable or offensive outputs might be extremely costly or lead to harm. This includes providing advice in sensitive or expert domains where accuracy and reliability are critical (ex: legal or health advice). Additional safeguards should be implemented at the application level according to the deployment context.
++ Misinformation: Models may produce inaccurate information. Developers should follow transparency best practices and inform end-users they are interacting with an AI system. At the application level, developers can build feedback mechanisms and pipelines to ground responses in use-case specific, contextual information, a technique known as Retrieval Augmented Generation (RAG).
++ Generation of Harmful Content: Developers should assess outputs for their context and use available safety classifiers or custom solutions appropriate for their use case.
++ Misuse: Other forms of misuse such as fraud, spam, or malware production may be possible, and developers should ensure that their applications do not violate applicable laws and regulations.
+
+
+## Training
+
+### Model
+
+* Architecture: Phi-3 Mini-4K-Instruct has 3.8B parameters and is a dense decoder-only Transformer model. The model is fine-tuned with Supervised fine-tuning (SFT) and Direct Preference Optimization (DPO) to ensure alignment with human preferences and safety guidlines.
+* Inputs: Text. It is best suited for prompts using chat format.
+* Context length: 4K tokens
+* GPUs: 512 H100-80G
+* Training time: 7 days
+* Training data: 3.3T tokens
+* Outputs: Generated text in response to the input
+* Dates: Our models were trained between February and April 2024
+* Status: This is a static model trained on an offline dataset with cutoff date October 2023. Future versions of the tuned models may be released as we improve models.
+
+### Datasets
+
+Our training data includes a wide variety of sources, totaling 3.3 trillion tokens, and is a combination of
+1) Publicly available documents filtered rigorously for quality, selected high-quality educational data, and code;
+2) Newly created synthetic, โtextbook-likeโ data for the purpose of teaching math, coding, common sense reasoning, general knowledge of the world (science, daily activities, theory of mind, etc.);
+3) High quality chat format supervised data covering various topics to reflect human preferences on different aspects such as instruct-following, truthfulness, honesty and helpfulness.
+
+### Fine-tuning
+
+A basic example of multi-GPUs supervised fine-tuning (SFT) with TRL and Accelerate modules is provided [here](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/sample_finetune.py).
+
+## Benchmarks
+
+We report the results for Phi-3-Mini-4K-Instruct on standard open-source benchmarks measuring the model's reasoning ability (both common sense reasoning and logical reasoning). We compare to Phi-2, Mistral-7b-v0.1, Mixtral-8x7b, Gemma 7B, Llama-3-8B-Instruct, and GPT-3.5.
+
+All the reported numbers are produced with the exact same pipeline to ensure that the numbers are comparable. These numbers might differ from other published numbers due to slightly different choices in the evaluation.
+
+As is now standard, we use few-shot prompts to evaluate the models, at temperature 0.
+The prompts and number of shots are part of a Microsoft internal tool to evaluate language models, and in particular we did no optimization to the pipeline for Phi-3.
+More specifically, we do not change prompts, pick different few-shot examples, change prompt format, or do any other form of optimization for the model.
+
+The number of kโshot examples is listed per-benchmark.
+
+| | Phi-3-Mini-4K-In
3.8b | Phi-3-Small
7b (preview) | Phi-3-Medium
14b (preview) | Phi-2
2.7b | Mistral
7b | Gemma
7b | Llama-3-In
8b | Mixtral
8x7b | GPT-3.5
version 1106 |
+|---|---|---|---|---|---|---|---|---|---|
+| MMLU
5-Shot | 68.8 | 75.3 | 78.2 | 56.3 | 61.7 | 63.6 | 66.5 | 68.4 | 71.4 |
+| HellaSwag
5-Shot | 76.7 | 78.7 | 83.2 | 53.6 | 58.5 | 49.8 | 71.1 | 70.4 | 78.8 |
+| ANLI
7-Shot | 52.8 | 55.0 | 58.7 | 42.5 | 47.1 | 48.7 | 57.3 | 55.2 | 58.1 |
+| GSM-8K
0-Shot; CoT | 82.5 | 86.4 | 90.8 | 61.1 | 46.4 | 59.8 | 77.4 | 64.7 | 78.1 |
+| MedQA
2-Shot | 53.8 | 58.2 | 69.8 | 40.9 | 49.6 | 50.0 | 60.5 | 62.2 | 63.4 |
+| AGIEval
0-Shot | 37.5 | 45.0 | 49.7 | 29.8 | 35.1 | 42.1 | 42.0 | 45.2 | 48.4 |
+| TriviaQA
5-Shot | 64.0 | 59.1 | 73.3 | 45.2 | 72.3 | 75.2 | 67.7 | 82.2 | 85.8 |
+| Arc-C
10-Shot | 84.9 | 90.7 | 91.9 | 75.9 | 78.6 | 78.3 | 82.8 | 87.3 | 87.4 |
+| Arc-E
10-Shot | 94.6 | 97.1 | 98.0 | 88.5 | 90.6 | 91.4 | 93.4 | 95.6 | 96.3 |
+| PIQA
5-Shot | 84.2 | 87.8 | 88.2 | 60.2 | 77.7 | 78.1 | 75.7 | 86.0 | 86.6 |
+| SociQA
5-Shot | 76.6 | 79.0 | 79.4 | 68.3 | 74.6 | 65.5 | 73.9 | 75.9 | 68.3 |
+| BigBench-Hard
0-Shot | 71.7 | 75.0 | 82.5 | 59.4 | 57.3 | 59.6 | 51.5 | 69.7 | 68.32 |
+| WinoGrande
5-Shot | 70.8 | 82.5 | 81.2 | 54.7 | 54.2 | 55.6 | 65 | 62.0 | 68.8 |
+| OpenBookQA
10-Shot | 83.2 | 88.4 | 86.6 | 73.6 | 79.8 | 78.6 | 82.6 | 85.8 | 86.0 |
+| BoolQ
0-Shot | 77.6 | 82.9 | 86.5 | -- | 72.2 | 66.0 | 80.9 | 77.6 | 79.1 |
+| CommonSenseQA
10-Shot | 80.2 | 80.3 | 82.6 | 69.3 | 72.6 | 76.2 | 79 | 78.1 | 79.6 |
+| TruthfulQA
10-Shot | 65.0 | 68.1 | 74.8 | -- | 52.1 | 53.0 | 63.2 | 60.1 | 85.8 |
+| HumanEval
0-Shot | 59.1 | 59.1 | 54.7 | 47.0 | 28.0 | 34.1 | 60.4 | 37.8 | 62.2 |
+| MBPP
3-Shot | 53.8 | 71.4 | 73.7 | 60.6 | 50.8 | 51.5 | 67.7 | 60.2 | 77.8 |
+
+## Software
+
+* [PyTorch](https://github.com/pytorch/pytorch)
+* [DeepSpeed](https://github.com/microsoft/DeepSpeed)
+* [Transformers](https://github.com/huggingface/transformers)
+* [Flash-Attention](https://github.com/HazyResearch/flash-attention)
+
+## Hardware
+Note that by default, the Phi-3-mini model uses flash attention, which requires certain types of GPU hardware to run. We have tested on the following GPU types:
+* NVIDIA A100
+* NVIDIA A6000
+* NVIDIA H100
+
+If you want to run the model on:
+* NVIDIA V100 or earlier generation GPUs: call AutoModelForCausalLM.from_pretrained() with attn_implementation="eager"
+* CPU: use the **GGUF** quantized models [4K](https://aka.ms/Phi3-mini-4k-instruct-gguf)
++ Optimized inference on GPU, CPU, and Mobile: use the **ONNX** models [4K](https://aka.ms/Phi3-mini-4k-instruct-onnx)
+
+
+## Cross Platform Support
+
+ONNX runtime ecosystem now supports Phi-3 Mini models across platforms and hardware. You can find the optimized Phi-3 Mini-4K-Instruct ONNX model [here](https://aka.ms/phi3-mini-4k-instruct-onnx).
+
+Optimized Phi-3 models are also published here in ONNX format, to run with ONNX Runtime on CPU and GPU across devices, including server platforms, Windows, Linux and Mac desktops, and mobile CPUs, with the precision best suited to each of these targets. DirectML support lets developers bring hardware acceleration to Windows devices at scale across AMD, Intel, and NVIDIA GPUs.
+Along with DirectML, ONNX Runtime provides cross platform support for Phi-3 across a range of devices CPU, GPU, and mobile.
+
+Here are some of the optimized configurations we have added:
+
+1. ONNX models for int4 DML: Quantized to int4 via AWQ
+2. ONNX model for fp16 CUDA
+3. ONNX model for int4 CUDA: Quantized to int4 via RTN
+4. ONNX model for int4 CPU and Mobile: Quantized to int4 via RTN
+
+## License
+
+The model is licensed under the [MIT license](https://huggingface.co/microsoft/Phi-3-mini-4k/resolve/main/LICENSE).
+
+## Trademarks
+
+This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must followโฏ[Microsoftโs Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks). Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos are subject to those third-partyโs policies.
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/SECURITY.md b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/SECURITY.md
new file mode 100644
index 0000000000000000000000000000000000000000..b3c89efc852e22f71eabf5dfbc6ac62493425eb6
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/SECURITY.md
@@ -0,0 +1,41 @@
+
+
+## Security
+
+Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
+
+If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
+
+## Reporting Security Issues
+
+**Please do not report security vulnerabilities through public GitHub issues.**
+
+Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
+
+If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
+
+You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
+
+Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
+
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
+ * Full paths of source file(s) related to the manifestation of the issue
+ * The location of the affected source code (tag/branch/commit or direct URL)
+ * Any special configuration required to reproduce the issue
+ * Step-by-step instructions to reproduce the issue
+ * Proof-of-concept or exploit code (if possible)
+ * Impact of the issue, including how an attacker might exploit the issue
+
+This information will help us triage your report more quickly.
+
+If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
+
+## Preferred Languages
+
+We prefer all communications to be in English.
+
+## Policy
+
+Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
+
+
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/added_tokens.json b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/added_tokens.json
new file mode 100644
index 0000000000000000000000000000000000000000..001971e98fef26f2c9d4c9f00ebccf32b2d500cc
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/added_tokens.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f5b652d997cf841b1a79f2557e235b33140006dac569e76edb8ab3437d138c3
+size 293
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/config.json b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..4dd35005a4fe792ed2676d29f438f1d43911107d
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d81d79531772a0cafe4615997c659cbf53b831622daed49494b4a6937eef9dd3
+size 904
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/configuration_phi3.py b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/configuration_phi3.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4553db23ac65c608fd150a14acbd04d3ff80a0f
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/configuration_phi3.py
@@ -0,0 +1,213 @@
+# coding=utf-8
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Phi-3 model configuration"""
+
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "microsoft/Phi-3-mini-4k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json",
+ "microsoft/Phi-3-mini-128k-instruct": "https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json",
+}
+
+
+class Phi3Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the
+ [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32064):
+ Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Phi3Model`].
+ hidden_size (`int`, *optional*, defaults to 3072):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+ `num_attention_heads`.
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
+ Dropout probability for mlp outputs.
+ embd_pdrop (`int`, *optional*, defaults to 0.0):
+ The dropout ratio for the embeddings.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio after computing the attention scores.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model might ever be used with.
+ original_max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model was trained with. This is used to determine the size of the
+ original RoPE embeddings when using long scaling.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon value used for the RMSNorm.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not.
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`dict`, *optional*):
+ The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
+ contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
+ the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
+ divided by the number of attention heads divided by 2.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 32000):
+ The id of the "end-of-sequence" token.
+ pad_token_id (`int`, *optional*, defaults to 32000):
+ The id of the padding token.
+ sliding_window (`int`, *optional*):
+ Sliding window attention window size. If `None`, no sliding window is applied.
+
+ Example:
+
+ ```python
+ >>> from transformers import Phi3Model, Phi3Config
+
+ >>> # Initializing a Phi-3 style configuration
+ >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+
+ >>> # Initializing a model from the configuration
+ >>> model = Phi3Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "phi3"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=32064,
+ hidden_size=3072,
+ intermediate_size=8192,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ resid_pdrop=0.0,
+ embd_pdrop=0.0,
+ attention_dropout=0.0,
+ hidden_act="silu",
+ max_position_embeddings=4096,
+ original_max_position_embeddings=4096,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ bos_token_id=1,
+ eos_token_id=32000,
+ pad_token_id=32000,
+ sliding_window=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.resid_pdrop = resid_pdrop
+ self.embd_pdrop = embd_pdrop
+ self.attention_dropout = attention_dropout
+ self.hidden_act = hidden_act
+ self.max_position_embeddings = max_position_embeddings
+ self.original_max_position_embeddings = original_max_position_embeddings
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self._rope_scaling_validation()
+ self.sliding_window = sliding_window
+
+ super().__init__(
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
+ raise ValueError(
+ "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, "
+ f"got {self.rope_scaling}"
+ )
+ rope_scaling_type = self.rope_scaling.get("type", None)
+ rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
+ rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
+ if rope_scaling_type is None or rope_scaling_type not in ["su", "yarn"]:
+ raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
+ if not (
+ isinstance(rope_scaling_short_factor, list)
+ and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
+ ):
+ raise ValueError(
+ f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
+ )
+ if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
+ raise ValueError(
+ f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
+ )
+ if not (
+ isinstance(rope_scaling_long_factor, list)
+ and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
+ ):
+ raise ValueError(
+ f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
+ )
+ if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
+ raise ValueError(
+ f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
+ )
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/generation_config.json b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/generation_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..a80fd8d5b362482f16ab4ba3d8d3fdec6f84762a
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/generation_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:466c2f0dc6eb59aa1593d3dd30e5d6614b8bf9e5d0c3b94f268ce4e341345009
+size 172
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/model-00001-of-00002.safetensors b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/model-00001-of-00002.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..65d40bfd824c728795fb520d67fc91ff7c1a16d0
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/model-00001-of-00002.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f95c89449ba404df51c3f633df65c52cabea9dbfae7e21977d32b5daa397cc91
+size 4972489328
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/model-00002-of-00002.safetensors b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/model-00002-of-00002.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..5b5c3007674dcf16d6da1f673d3ebe4bdd022e16
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/model-00002-of-00002.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5de695669421c1e12d2d1588a4795de6a5900bd1513c06be357b65a041b3590e
+size 2669692552
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/model.safetensors.index.json b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/model.safetensors.index.json
new file mode 100644
index 0000000000000000000000000000000000000000..fb9316bdf8039b9c2456d9ebf34c1cb0ee41e73c
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/model.safetensors.index.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6eec78ee2a442445d9d342a7cd9a763da40a49bd91d8611b7db5e9a29c90a428
+size 16331
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/modeling_phi3.py b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/modeling_phi3.py
new file mode 100644
index 0000000000000000000000000000000000000000..29f4fa30295346e8926394c6cc7897eac23a1d9e
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/modeling_phi3.py
@@ -0,0 +1,1606 @@
+# coding=utf-8
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" PyTorch Phi-3 model."""
+
+import inspect
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_phi3 import Phi3Config
+
+
+logger = logging.get_logger(__name__)
+
+# Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements
+# if is_flash_attn_2_available():
+_flash_supports_window_size = False
+try:
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+except ImportError as error:
+ logger.warning(
+ f"`flash-attention` package not found, consider installing for better performance: {error}."
+ )
+ if not _flash_supports_window_size:
+ logger.warning(
+ "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
+ )
+
+_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
+_CONFIG_FOR_DOC = "Phi3Config"
+
+PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "microsoft/Phi-3-mini-4k-instruct",
+ "microsoft/Phi-3-mini-128k-instruct",
+ # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3
+]
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
+class Phi3RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Phi3RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
+class Phi3RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ self.register_buffer("inv_freq", None, persistent=False)
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if self.inv_freq is None:
+ self.inv_freq = 1.0 / (
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+ )
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
+ def __init__(self, dim, config, device=None):
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
+
+ self.short_factor = config.rope_scaling["short_factor"]
+ self.long_factor = config.rope_scaling["long_factor"]
+ self.original_max_position_embeddings = config.original_max_position_embeddings
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.original_max_position_embeddings:
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+ else:
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
+ if scale <= 1.0:
+ scaling_factor = 1.0
+ else:
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+
+ cos = emb.cos() * scaling_factor
+ sin = emb.sin() * scaling_factor
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
+ def __init__(self, dim, config, device=None):
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
+
+ self.short_factor = config.rope_scaling["short_factor"]
+ self.long_factor = config.rope_scaling["long_factor"]
+ self.original_max_position_embeddings = config.original_max_position_embeddings
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.original_max_position_embeddings:
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+ else:
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
+ if scale <= 1.0:
+ scaling_factor = 1.0
+ else:
+ scaling_factor = 0.1 * math.log(scale) + 1.0
+
+ cos = emb.cos() * scaling_factor
+ sin = emb.sin() * scaling_factor
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class Phi3MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ up_states = self.gate_up_proj(hidden_states)
+
+ gate, up_states = up_states.chunk(2, dim=-1)
+ up_states = up_states * self.activation_fn(gate)
+
+ return self.down_proj(up_states)
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class Phi3Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.original_max_position_embeddings = config.original_max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.rope_scaling = config.rope_scaling
+ self.is_causal = True
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+
+ op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
+ self._init_rope()
+
+ def _init_rope(self):
+ if self.rope_scaling is None:
+ self.rotary_emb = Phi3RotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+ else:
+ scaling_type = self.config.rope_scaling["type"]
+ if scaling_type == "su":
+ self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
+ elif scaling_type == "yarn":
+ self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
+ else:
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.")
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv = self.qkv_proj(hidden_states)
+ query_pos = self.num_heads * self.head_dim
+ query_states = qkv[..., :query_pos]
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Phi3FlashAttention2(Phi3Attention):
+ """
+ Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # Phi3FlashAttention2 attention does not support output_attentions
+
+ if not _flash_supports_window_size:
+ logger.warning_once(
+ "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library."
+ )
+ raise ValueError("The current flash attention version does not support sliding window attention.")
+
+ output_attentions = False
+
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop("padding_mask")
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv = self.qkv_proj(hidden_states)
+ query_pos = self.num_heads * self.head_dim
+ query_states = qkv[..., :query_pos]
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ use_sliding_windows = (
+ _flash_supports_window_size
+ and getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ )
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a value `sliding_windows` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_dropout = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32.
+
+ if query_states.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.qkv_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reashape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = self._flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=attn_dropout,
+ use_sliding_windows=use_sliding_windows,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
+ def _flash_attention_forward(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None,
+ use_sliding_windows=False,
+ ):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+ first unpad the input, then computes the attention scores and pad the final attention scores.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ use_sliding_windows (`bool`, *optional*):
+ Whether to activate sliding window attention.
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ if not use_sliding_windows:
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+ else:
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=(self.config.sliding_window, self.config.sliding_window),
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ if not use_sliding_windows:
+ attn_output = flash_attn_func(
+ query_states,
+ key_states,
+ value_states,
+ dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+ else:
+ attn_output = flash_attn_func(
+ query_states,
+ key_states,
+ value_states,
+ dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=(self.config.sliding_window, self.config.sliding_window),
+ )
+
+ return attn_output
+
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
+
+ # On the first iteration we need to properly re-create the padding mask
+ # by slicing it on the proper place
+ if kv_seq_len != attention_mask.shape[-1]:
+ attention_mask_num_tokens = attention_mask.shape[-1]
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
+
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
+# TODO @Arthur no longer copied from LLama after static cache
+class Phi3SdpaAttention(Phi3Attention):
+ """
+ Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Phi3Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ # Adapted from Phi3Attention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv = self.qkv_proj(hidden_states)
+ query_pos = self.num_heads * self.head_dim
+ query_states = qkv[..., :query_pos]
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+PHI3_ATTENTION_CLASSES = {
+ "eager": Phi3Attention,
+ "flash_attention_2": Phi3FlashAttention2,
+ "sdpa": Phi3SdpaAttention,
+}
+
+
+class Phi3DecoderLayer(nn.Module):
+ def __init__(self, config: Phi3Config, layer_idx: int):
+ super().__init__()
+
+ self.config = config
+ self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
+
+ self.mlp = Phi3MLP(config)
+ self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
+ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
+ self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ if "padding_mask" in kwargs:
+ warnings.warn(
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+ )
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = residual + self.resid_attn_dropout(attn_outputs)
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + self.resid_mlp_dropout(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+PHI3_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`Phi3Config`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
+ PHI3_START_DOCSTRING,
+)
+class Phi3PreTrainedModel(PreTrainedModel):
+ config_class = Phi3Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Phi3DecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = False
+ _supports_cache_class = True
+
+ _version = "0.0.5"
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+PHI3_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance;
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+ cache format.
+
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
+ PHI3_START_DOCSTRING,
+)
+class Phi3Model(Phi3PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
+
+ Args:
+ config: Phi3Config
+ """
+
+ def __init__(self, config: Phi3Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
+ self.layers = nn.ModuleList(
+ [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+ self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape[:2]
+ elif inputs_embeds is not None:
+ batch_size, seq_length = inputs_embeds.shape[:2]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ past_key_values_length = 0
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ if use_cache:
+ use_legacy_cache = not isinstance(past_key_values, Cache)
+ if use_legacy_cache:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
+ if is_padding_right:
+ raise ValueError(
+ "You are attempting to perform batched generation with padding_side='right'"
+ " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+ )
+
+ if self._attn_implementation == "flash_attention_2":
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class Phi3ForCausalLM(Phi3PreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Phi3Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
+ def get_decoder(self):
+ return self.model
+
+ # Ignore copy
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Phi3ForCausalLM
+
+ >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+
+ >>> prompt = "This is an example script ."
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ cache_length = past_key_values.get_seq_length()
+ past_length = past_key_values.seen_tokens
+ max_cache_length = past_key_values.get_max_length()
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+ max_cache_length = None
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if (
+ max_cache_length is not None
+ and attention_mask is not None
+ and cache_length + input_ids.shape[1] > max_cache_length
+ ):
+ attention_mask = attention_mask[:, -max_cache_length:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+ )
+ return reordered_past
+
+
+@add_start_docstrings(
+ """
+ The [`Phi3Model`] with a sequence classification head on top (linear layer).
+
+ [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ PHI3_START_DOCSTRING,
+)
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
+class Phi3ForSequenceClassification(Phi3PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = Phi3Model(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ model_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = model_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ sequence_lengths = sequence_lengths.to(logits.device)
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + model_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=model_outputs.past_key_values,
+ hidden_states=model_outputs.hidden_states,
+ attentions=model_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ PHI3_START_DOCSTRING,
+)
+# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
+class Phi3ForTokenClassification(Phi3PreTrainedModel):
+ def __init__(self, config: Phi3Config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.model = Phi3Model(config)
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
+ classifier_dropout = config.classifier_dropout
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ model_outputs = self.model(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = model_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ batch_size, seq_length = labels.shape
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
+ )
+
+ if not return_dict:
+ output = (logits,) + model_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=model_outputs.hidden_states,
+ attentions=model_outputs.attentions,
+ )
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/sample_finetune.py b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/sample_finetune.py
new file mode 100644
index 0000000000000000000000000000000000000000..610879e593a4333aba7b00a13c3c84bacd3f549d
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/sample_finetune.py
@@ -0,0 +1,217 @@
+import sys
+import logging
+
+import datasets
+from datasets import load_dataset
+from peft import LoraConfig
+import torch
+import transformers
+from trl import SFTTrainer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
+
+"""
+A simple example on using SFTTrainer and Accelerate to finetune Phi-3 models. For
+a more advanced example, please follow HF alignment-handbook/scripts/run_sft.py.
+This example has utilized DeepSpeed ZeRO3 offload to reduce the memory usage. The
+script can be run on V100 or later generation GPUs. Here are some suggestions on
+futher reducing memory consumption:
+ - reduce batch size
+ - decrease lora dimension
+ - restrict lora target modules
+Please follow these steps to run the script:
+1. Install dependencies:
+ conda install -c conda-forge accelerate
+ pip3 install -i https://pypi.org/simple/ bitsandbytes
+ pip3 install peft transformers trl datasets
+ pip3 install deepspeed
+2. Setup accelerate and deepspeed config based on the machine used:
+ accelerate config
+Here is a sample config for deepspeed zero3:
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ deepspeed_config:
+ gradient_accumulation_steps: 1
+ offload_optimizer_device: none
+ offload_param_device: none
+ zero3_init_flag: true
+ zero3_save_16bit_model: true
+ zero_stage: 3
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 4
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
+3. check accelerate config:
+ accelerate env
+4. Run the code:
+ accelerate launch sample_finetune.py
+"""
+
+logger = logging.getLogger(__name__)
+
+
+###################
+# Hyper-parameters
+###################
+training_config = {
+ "bf16": True,
+ "do_eval": False,
+ "learning_rate": 5.0e-06,
+ "log_level": "info",
+ "logging_steps": 20,
+ "logging_strategy": "steps",
+ "lr_scheduler_type": "cosine",
+ "num_train_epochs": 1,
+ "max_steps": -1,
+ "output_dir": "./checkpoint_dir",
+ "overwrite_output_dir": True,
+ "per_device_eval_batch_size": 4,
+ "per_device_train_batch_size": 4,
+ "remove_unused_columns": True,
+ "save_steps": 100,
+ "save_total_limit": 1,
+ "seed": 0,
+ "gradient_checkpointing": True,
+ "gradient_checkpointing_kwargs":{"use_reentrant": False},
+ "gradient_accumulation_steps": 1,
+ "warmup_ratio": 0.2,
+ }
+
+peft_config = {
+ "r": 16,
+ "lora_alpha": 32,
+ "lora_dropout": 0.05,
+ "bias": "none",
+ "task_type": "CAUSAL_LM",
+ "target_modules": "all-linear",
+ "modules_to_save": None,
+}
+train_conf = TrainingArguments(**training_config)
+peft_conf = LoraConfig(**peft_config)
+
+
+###############
+# Setup logging
+###############
+logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ handlers=[logging.StreamHandler(sys.stdout)],
+)
+log_level = train_conf.get_process_log_level()
+logger.setLevel(log_level)
+datasets.utils.logging.set_verbosity(log_level)
+transformers.utils.logging.set_verbosity(log_level)
+transformers.utils.logging.enable_default_handler()
+transformers.utils.logging.enable_explicit_format()
+
+# Log on each process a small summary
+logger.warning(
+ f"Process rank: {train_conf.local_rank}, device: {train_conf.device}, n_gpu: {train_conf.n_gpu}"
+ + f" distributed training: {bool(train_conf.local_rank != -1)}, 16-bits training: {train_conf.fp16}"
+)
+logger.info(f"Training/evaluation parameters {train_conf}")
+logger.info(f"PEFT parameters {peft_conf}")
+
+
+################
+# Modle Loading
+################
+checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"
+# checkpoint_path = "microsoft/Phi-3-mini-128k-instruct"
+model_kwargs = dict(
+ use_cache=False,
+ trust_remote_code=True,
+ attn_implementation="flash_attention_2", # loading the model with flash-attenstion support
+ torch_dtype=torch.bfloat16,
+ device_map=None
+)
+model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
+tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
+tokenizer.model_max_length = 2048
+tokenizer.pad_token = tokenizer.unk_token # use unk rather than eos token to prevent endless generation
+tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
+tokenizer.padding_side = 'right'
+
+
+##################
+# Data Processing
+##################
+def apply_chat_template(
+ example,
+ tokenizer,
+):
+ messages = example["messages"]
+ # Add an empty system message if there is none
+ if messages[0]["role"] != "system":
+ messages.insert(0, {"role": "system", "content": ""})
+ example["text"] = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=False)
+ return example
+
+raw_dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
+train_dataset = raw_dataset["train_sft"]
+test_dataset = raw_dataset["test_sft"]
+column_names = list(train_dataset.features)
+
+processed_train_dataset = train_dataset.map(
+ apply_chat_template,
+ fn_kwargs={"tokenizer": tokenizer},
+ num_proc=10,
+ remove_columns=column_names,
+ desc="Applying chat template to train_sft",
+)
+
+processed_test_dataset = test_dataset.map(
+ apply_chat_template,
+ fn_kwargs={"tokenizer": tokenizer},
+ num_proc=10,
+ remove_columns=column_names,
+ desc="Applying chat template to test_sft",
+)
+
+
+###########
+# Training
+###########
+trainer = SFTTrainer(
+ model=model,
+ args=train_conf,
+ peft_config=peft_conf,
+ train_dataset=processed_train_dataset,
+ eval_dataset=processed_test_dataset,
+ max_seq_length=2048,
+ dataset_text_field="text",
+ tokenizer=tokenizer,
+ packing=True
+)
+train_result = trainer.train()
+metrics = train_result.metrics
+trainer.log_metrics("train", metrics)
+trainer.save_metrics("train", metrics)
+trainer.save_state()
+
+
+#############
+# Evaluation
+#############
+tokenizer.padding_side = 'left'
+metrics = trainer.evaluate()
+metrics["eval_samples"] = len(processed_test_dataset)
+trainer.log_metrics("eval", metrics)
+trainer.save_metrics("eval", metrics)
+
+
+# ############
+# # Save model
+# ############
+trainer.save_model(train_conf.output_dir)
\ No newline at end of file
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/special_tokens_map.json b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/special_tokens_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..4592102136f610e23d1fd9c18b602963ca950392
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/special_tokens_map.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:51d7c72bbb0e5dbc001ba6cb799c53dee0539303d4e9c483583cf12e9fe48e48
+size 568
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/tokenizer.json b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/tokenizer.json
new file mode 100644
index 0000000000000000000000000000000000000000..ff2599b90bc00c0199c9798d87e84b99b90787a9
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/tokenizer.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bbddd4b39b594027b022cf22c47669dcd9e05ffc3b6d4a972b39a713750f823
+size 1844409
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/tokenizer.model b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/tokenizer.model
new file mode 100644
index 0000000000000000000000000000000000000000..6c00c742ce03c627d6cd5b795984876fa49fa899
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/tokenizer.model
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+size 499723
diff --git a/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/tokenizer_config.json b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..f3ba9966372f28e95d1c5491330b524b924ccfe6
--- /dev/null
+++ b/lava-vicuna_2024_4_Phi-3-mini-4k-instruct/tokenizer_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:441a655644c244ab6fb6aae4320e5b01793bc5a9ef03dd94e9f6dedf337ec01b
+size 3169
diff --git a/llava/__init__.py b/llava/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b4298cb298e00897251a01dea121732b588766c
--- /dev/null
+++ b/llava/__init__.py
@@ -0,0 +1 @@
+from .model import LlavaPhiForCausalLM
\ No newline at end of file
diff --git a/llava/__pycache__/__init__.cpython-310.pyc b/llava/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2765f1033a6bbfb80b3d8ec50694e9f061e8c001
--- /dev/null
+++ b/llava/__pycache__/__init__.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:477a627025b05218587d4685fdce9a89d1188091f8ff133905d99cbb0c91ed76
+size 191
diff --git a/llava/__pycache__/constants.cpython-310.pyc b/llava/__pycache__/constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f9c440edd15b72844715690c3f120075bc41c82c
--- /dev/null
+++ b/llava/__pycache__/constants.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dea6328de8a9b83efa7e659a28a8f2a63e9f6c2d919ce1163b018759aa706ce9
+size 534
diff --git a/llava/__pycache__/conversation.cpython-310.pyc b/llava/__pycache__/conversation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dafd04684f7d946214abe52e5bfa9c88602cb04f
--- /dev/null
+++ b/llava/__pycache__/conversation.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d16ca7eeefdfe3a131acacd089db4ca20c5a43e8cde8067e8a2c9ab76e0fc4f2
+size 10902
diff --git a/llava/__pycache__/mm_utils.cpython-310.pyc b/llava/__pycache__/mm_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aef3ae215f35390181d3a6eed46c5b037ab37f8c
--- /dev/null
+++ b/llava/__pycache__/mm_utils.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8af51e494f02363457a814aad8fb0e89241e3fcb5dbf56514ac2f242a52c6d8e
+size 8773
diff --git a/llava/bpe_simple_vocab_16e6.txt.gz b/llava/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/llava/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/llava/constants.py b/llava/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..783a0fa4de72738da29a043e2e5a1a5ad8fc35b2
--- /dev/null
+++ b/llava/constants.py
@@ -0,0 +1,14 @@
+CONTROLLER_HEART_BEAT_EXPIRATION = 30
+WORKER_HEART_BEAT_INTERVAL = 15
+
+LOGDIR = "."
+
+# Model Constants
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+DEFAULT_PC_TOKEN = ""
+DEFAULT_IMAGE_TOKEN = ""
+DEFAULT_IMAGE_PATCH_TOKEN = ""
+DEFAULT_IM_START_TOKEN = ""
+DEFAULT_IM_END_TOKEN = ""
+IMAGE_PLACEHOLDER = ""
diff --git a/llava/conversation.py b/llava/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..8328d024b826582673d4ee14f4e78aaf4d51910f
--- /dev/null
+++ b/llava/conversation.py
@@ -0,0 +1,422 @@
+# Modified from LLaVA: https://github.com/haotian-liu/LLaVA.git
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+import base64
+from io import BytesIO
+from PIL import Image
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+ MPT = auto()
+ PLAIN = auto()
+ LLAMA_2 = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+ version: str = "Unknown"
+
+ skip_next: bool = False
+
+ def get_prompt(self):
+ messages = self.messages
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
+ messages = self.messages.copy()
+ init_role, init_msg = messages[0].copy()
+ init_msg = init_msg[0].replace("", "").strip()
+ if 'mmtag' in self.version:
+ messages[0] = (init_role, init_msg)
+ messages.insert(0, (self.roles[0], ""))
+ messages.insert(1, (self.roles[1], "Received."))
+ else:
+ messages[0] = (init_role, "\n" + init_msg)
+
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ elif self.sep_style == SeparatorStyle.MPT:
+ ret = self.system + self.sep
+ for role, message in messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
+ wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
+ ret = ""
+
+ for i, (role, message) in enumerate(messages):
+ if i == 0:
+ assert message, "first message should not be none"
+ assert role == self.roles[0], "first message should come from user"
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ if i == 0: message = wrap_sys(self.system) + message
+ if i % 2 == 0:
+ message = wrap_inst(message)
+ ret += self.sep + message
+ else:
+ ret += " " + message + " " + self.sep2
+ else:
+ ret += ""
+ ret = ret.lstrip(self.sep)
+ elif self.sep_style == SeparatorStyle.PLAIN:
+ seps = [self.sep, self.sep2]
+ ret = self.system
+ for i, (role, message) in enumerate(messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += message + seps[i % 2]
+ else:
+ ret += ""
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ return ret
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
+ if image_process_mode == "Pad":
+ def expand2square(pil_img, background_color=(122, 116, 104)):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+ image = expand2square(image)
+ elif image_process_mode in ["Default", "Crop"]:
+ pass
+ elif image_process_mode == "Resize":
+ image = image.resize((336, 336))
+ else:
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+ if max(image.size) > max_len:
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ if return_pil:
+ return image
+ else:
+ buffered = BytesIO()
+ image.save(buffered, format=image_format)
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+ return img_b64_str
+
+ def get_images(self, return_pil=False):
+ images = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ msg, image, image_process_mode = msg
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
+ images.append(image)
+ return images
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ msg, image, image_process_mode = msg
+ img_b64_str = self.process_image(
+ image, "Default", return_pil=False,
+ image_format='JPEG')
+ img_str = f'
'
+ msg = img_str + msg.replace('', '').strip()
+ ret.append([msg, None])
+ else:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ version=self.version)
+
+ def dict(self):
+ if len(self.get_images()) > 0:
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+
+
+conv_vicuna_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+ ("Assistant",
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+ "renewable and non-renewable energy sources:\n"
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+ "energy sources are finite and will eventually run out.\n"
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+ "and other negative effects.\n"
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+ "have lower operational costs than non-renewable sources.\n"
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+ "locations than non-renewable sources.\n"
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_vicuna_v1 = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+conv_llama_2 = Conversation(
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+ sep2="",
+)
+
+conv_llava_llama_2 = Conversation(
+ system="You are a helpful language and vision assistant. "
+ "You are able to understand the visual content that the user provides, "
+ "and assist the user with a variety of tasks using natural language.",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+ sep2="",
+)
+
+conv_mpt = Conversation(
+ system="""<|im_start|>system
+A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_llava_plain = Conversation(
+ system="",
+ roles=("", ""),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.PLAIN,
+ sep="\n",
+)
+
+conv_llava_v0 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_llava_v0_mmtag = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "The visual content will be provided with the following format: visual content.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ),
+ offset=0,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+ version="v0_mmtag",
+)
+
+conv_llava_v1 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+conv_llava_v1_mmtag = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "The visual content will be provided with the following format: visual content.",
+ roles=("USER", "ASSISTANT"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+ version="v1_mmtag",
+)
+
+conv_mistral_instruct = Conversation(
+ system="",
+ roles=("USER", "ASSISTANT"),
+ version="llama_v2",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.LLAMA_2,
+ sep="",
+ sep2="",
+)
+
+conv_chatml_direct = Conversation(
+ system="""<|im_start|>system
+Answer the questions.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_phi3_instruct = Conversation(
+ system="""<|system|>\nYou are a helpful AI assistant.""",
+ roles=("\n<|user|>\n", "\n<|assistant|>\n"),
+ version="phi3",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|end|>",
+)
+
+
+conv_phi3_instruct_v2 = Conversation(
+ system="""<|system|>\nProvide a detailed answer,Provide a detailed answer,Provide a detailed answer""",
+ roles=("\n<|user|>\n", "\n<|assistant|>\n"),
+ version="phi3",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|end|>",
+)
+
+
+
+default_conversation = conv_vicuna_v1
+conv_templates = {
+ "default": conv_vicuna_v0,
+ "v0": conv_vicuna_v0,
+ "v1": conv_vicuna_v1,
+ "vicuna_v1": conv_vicuna_v1,
+ "llama_2": conv_llama_2,
+ "mistral_instruct": conv_mistral_instruct,
+ "chatml_direct": conv_chatml_direct,
+ "mistral_direct": conv_chatml_direct,
+
+ "plain": conv_llava_plain,
+ "v0_plain": conv_llava_plain,
+ "llava_v0": conv_llava_v0,
+ "v0_mmtag": conv_llava_v0_mmtag,
+ "llava_v1": conv_llava_v1,
+ "v1_mmtag": conv_llava_v1_mmtag,
+ "llava_llama_2": conv_llava_llama_2,
+ "phi3_instruct": conv_phi3_instruct,
+ "phi3_instruct_v2": conv_phi3_instruct_v2,
+
+ "mpt": conv_mpt,
+}
+
+
+if __name__ == "__main__":
+ print(default_conversation.get_prompt())
diff --git a/llava/eval/eval_gpt_review.py b/llava/eval/eval_gpt_review.py
new file mode 100644
index 0000000000000000000000000000000000000000..8af4559c65fc2728b11fd2097a109981ee1ef686
--- /dev/null
+++ b/llava/eval/eval_gpt_review.py
@@ -0,0 +1,113 @@
+import argparse
+import json
+import os
+
+import openai
+import tqdm
+import ray
+import time
+
+NUM_SECONDS_TO_SLEEP = 3
+
+@ray.remote(num_cpus=4)
+def get_eval(content: str, max_tokens: int):
+ while True:
+ try:
+ response = openai.ChatCompletion.create(
+ model='gpt-4',
+ messages=[{
+ 'role': 'system',
+ 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
+ }, {
+ 'role': 'user',
+ 'content': content,
+ }],
+ temperature=0.2, # TODO: figure out which temperature is best for evaluation
+ max_tokens=max_tokens,
+ )
+ break
+ except openai.error.RateLimitError:
+ pass
+ except Exception as e:
+ print(e)
+ time.sleep(NUM_SECONDS_TO_SLEEP)
+
+ print('success!')
+ return response['choices'][0]['message']['content']
+
+
+def parse_score(review):
+ try:
+ score_pair = review.split('\n')[0]
+ score_pair = score_pair.replace(',', ' ')
+ sp = score_pair.split(' ')
+ if len(sp) == 2:
+ return [float(sp[0]), float(sp[1])]
+ else:
+ print('error', review)
+ return [-1, -1]
+ except Exception as e:
+ print(e)
+ print('error', review)
+ return [-1, -1]
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
+ parser.add_argument('-q', '--question')
+ # parser.add_argument('-a', '--answer')
+ parser.add_argument('-a', '--answer-list', nargs='+', default=[])
+ parser.add_argument('-r', '--rule')
+ parser.add_argument('-o', '--output')
+ parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
+ args = parser.parse_args()
+
+ ray.init()
+
+ f_q = open(os.path.expanduser(args.question))
+ f_ans1 = open(os.path.expanduser(args.answer_list[0]))
+ f_ans2 = open(os.path.expanduser(args.answer_list[1]))
+ rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
+
+ review_file = open(f'{args.output}', 'w')
+
+ js_list = []
+ handles = []
+ idx = 0
+ for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
+ # if idx == 1:
+ # break
+
+ ques = json.loads(ques_js)
+ ans1 = json.loads(ans1_js)
+ ans2 = json.loads(ans2_js)
+
+ category = json.loads(ques_js)['category']
+ if category in rule_dict:
+ rule = rule_dict[category]
+ else:
+ rule = rule_dict['default']
+ prompt = rule['prompt']
+ role = rule['role']
+ content = (f'[Question]\n{ques["text"]}\n\n'
+ f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
+ f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
+ f'[System]\n{prompt}\n\n')
+ js_list.append({
+ 'id': idx+1,
+ 'question_id': ques['question_id'],
+ 'answer1_id': ans1['answer_id'],
+ 'answer2_id': ans2['answer_id'],
+ 'category': category})
+ idx += 1
+ handles.append(get_eval.remote(content, args.max_tokens))
+ # To avoid the rate limit set by OpenAI
+ time.sleep(NUM_SECONDS_TO_SLEEP)
+
+ reviews = ray.get(handles)
+ for idx, review in enumerate(reviews):
+ scores = parse_score(review)
+ js_list[idx]['content'] = review
+ js_list[idx]['tuple'] = scores
+ review_file.write(json.dumps(js_list[idx]) + '\n')
+ review_file.close()
diff --git a/llava/eval/eval_gpt_review_bench.py b/llava/eval/eval_gpt_review_bench.py
new file mode 100644
index 0000000000000000000000000000000000000000..06160f2422b5368f30fb967f7cae635208a1dc69
--- /dev/null
+++ b/llava/eval/eval_gpt_review_bench.py
@@ -0,0 +1,121 @@
+import argparse
+import json
+import os
+
+import openai
+import time
+
+NUM_SECONDS_TO_SLEEP = 0.5
+
+
+def get_eval(content: str, max_tokens: int):
+ while True:
+ try:
+ response = openai.ChatCompletion.create(
+ model='gpt-4-0314',
+ messages=[{
+ 'role': 'system',
+ 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
+ }, {
+ 'role': 'user',
+ 'content': content,
+ }],
+ temperature=0.2, # TODO: figure out which temperature is best for evaluation
+ max_tokens=max_tokens,
+ )
+ break
+ except openai.error.RateLimitError:
+ pass
+ except Exception as e:
+ print(e)
+ time.sleep(NUM_SECONDS_TO_SLEEP)
+
+ return response['choices'][0]['message']['content']
+
+
+def parse_score(review):
+ try:
+ score_pair = review.split('\n')[0]
+ score_pair = score_pair.replace(',', ' ')
+ sp = score_pair.split(' ')
+ if len(sp) == 2:
+ return [float(sp[0]), float(sp[1])]
+ else:
+ print('error', review)
+ return [-1, -1]
+ except Exception as e:
+ print(e)
+ print('error', review)
+ return [-1, -1]
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
+ parser.add_argument('-q', '--question')
+ parser.add_argument('-c', '--context')
+ parser.add_argument('-a', '--answer-list', nargs='+', default=[])
+ parser.add_argument('-r', '--rule')
+ parser.add_argument('-o', '--output')
+ parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
+ args = parser.parse_args()
+
+ f_q = open(os.path.expanduser(args.question))
+ f_ans1 = open(os.path.expanduser(args.answer_list[0]))
+ f_ans2 = open(os.path.expanduser(args.answer_list[1]))
+ rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
+
+ if os.path.isfile(os.path.expanduser(args.output)):
+ cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
+ else:
+ cur_reviews = []
+
+ review_file = open(f'{args.output}', 'a')
+
+ context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
+ image_to_context = {context['image']: context for context in context_list}
+
+ handles = []
+ idx = 0
+ for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
+ ques = json.loads(ques_js)
+ ans1 = json.loads(ans1_js)
+ ans2 = json.loads(ans2_js)
+
+ inst = image_to_context[ques['image']]
+
+ if isinstance(inst['caption'], list):
+ cap_str = '\n'.join(inst['caption'])
+ else:
+ cap_str = inst['caption']
+
+ category = 'llava_bench_' + json.loads(ques_js)['category']
+ if category in rule_dict:
+ rule = rule_dict[category]
+ else:
+ assert False, f"Visual QA category not found in rule file: {category}."
+ prompt = rule['prompt']
+ role = rule['role']
+ content = (f'[Context]\n{cap_str}\n\n'
+ f'[Question]\n{ques["text"]}\n\n'
+ f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
+ f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
+ f'[System]\n{prompt}\n\n')
+ cur_js = {
+ 'id': idx+1,
+ 'question_id': ques['question_id'],
+ 'answer1_id': ans1.get('answer_id', ans1['question_id']),
+ 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
+ 'category': category
+ }
+ if idx >= len(cur_reviews):
+ review = get_eval(content, args.max_tokens)
+ scores = parse_score(review)
+ cur_js['content'] = review
+ cur_js['tuple'] = scores
+ review_file.write(json.dumps(cur_js) + '\n')
+ review_file.flush()
+ else:
+ print(f'Skipping {idx} as we already have it.')
+ idx += 1
+ print(idx)
+ review_file.close()
diff --git a/llava/eval/eval_gpt_review_visual.py b/llava/eval/eval_gpt_review_visual.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6e407a400a67020d801e6c27a3c32a2ee38f30c
--- /dev/null
+++ b/llava/eval/eval_gpt_review_visual.py
@@ -0,0 +1,118 @@
+import argparse
+import json
+import os
+
+import openai
+import time
+
+NUM_SECONDS_TO_SLEEP = 0.5
+
+
+def get_eval(content: str, max_tokens: int):
+ while True:
+ try:
+ response = openai.ChatCompletion.create(
+ model='gpt-4-0314',
+ messages=[{
+ 'role': 'system',
+ 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
+ }, {
+ 'role': 'user',
+ 'content': content,
+ }],
+ temperature=0.2, # TODO: figure out which temperature is best for evaluation
+ max_tokens=max_tokens,
+ )
+ break
+ except openai.error.RateLimitError:
+ pass
+ except Exception as e:
+ print(e)
+ time.sleep(NUM_SECONDS_TO_SLEEP)
+
+ return response['choices'][0]['message']['content']
+
+
+def parse_score(review):
+ try:
+ score_pair = review.split('\n')[0]
+ score_pair = score_pair.replace(',', ' ')
+ sp = score_pair.split(' ')
+ if len(sp) == 2:
+ return [float(sp[0]), float(sp[1])]
+ else:
+ print('error', review)
+ return [-1, -1]
+ except Exception as e:
+ print(e)
+ print('error', review)
+ return [-1, -1]
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
+ parser.add_argument('-q', '--question')
+ parser.add_argument('-c', '--context')
+ parser.add_argument('-a', '--answer-list', nargs='+', default=[])
+ parser.add_argument('-r', '--rule')
+ parser.add_argument('-o', '--output')
+ parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
+ args = parser.parse_args()
+
+ f_q = open(os.path.expanduser(args.question))
+ f_ans1 = open(os.path.expanduser(args.answer_list[0]))
+ f_ans2 = open(os.path.expanduser(args.answer_list[1]))
+ rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
+
+ if os.path.isfile(os.path.expanduser(args.output)):
+ cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
+ else:
+ cur_reviews = []
+
+ review_file = open(f'{args.output}', 'a')
+
+ context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
+ image_to_context = {context['image']: context for context in context_list}
+
+ handles = []
+ idx = 0
+ for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
+ ques = json.loads(ques_js)
+ ans1 = json.loads(ans1_js)
+ ans2 = json.loads(ans2_js)
+
+ inst = image_to_context[ques['image']]
+ cap_str = '\n'.join(inst['captions'])
+ box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
+
+ category = json.loads(ques_js)['category']
+ if category in rule_dict:
+ rule = rule_dict[category]
+ else:
+ assert False, f"Visual QA category not found in rule file: {category}."
+ prompt = rule['prompt']
+ role = rule['role']
+ content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
+ f'[Question]\n{ques["text"]}\n\n'
+ f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
+ f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
+ f'[System]\n{prompt}\n\n')
+ cur_js = {
+ 'id': idx+1,
+ 'question_id': ques['question_id'],
+ 'answer1_id': ans1.get('answer_id', ans1['question_id']),
+ 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
+ 'category': category
+ }
+ if idx >= len(cur_reviews):
+ review = get_eval(content, args.max_tokens)
+ scores = parse_score(review)
+ cur_js['content'] = review
+ cur_js['tuple'] = scores
+ review_file.write(json.dumps(cur_js) + '\n')
+ review_file.flush()
+ else:
+ print(f'Skipping {idx} as we already have it.')
+ idx += 1
+ print(idx)
+ review_file.close()
diff --git a/llava/eval/eval_pope.py b/llava/eval/eval_pope.py
new file mode 100644
index 0000000000000000000000000000000000000000..b115b8f2327ea9d972f9e41bcbb03c68be6b3508
--- /dev/null
+++ b/llava/eval/eval_pope.py
@@ -0,0 +1,81 @@
+import os
+import json
+import argparse
+
+def eval_pope(answers, label_file):
+ label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]
+
+ for answer in answers:
+ text = answer['text']
+
+ # Only keep the first sentence
+ if text.find('.') != -1:
+ text = text.split('.')[0]
+
+ text = text.replace(',', '')
+ words = text.split(' ')
+ if 'No' in words or 'not' in words or 'no' in words:
+ answer['text'] = 'no'
+ else:
+ answer['text'] = 'yes'
+
+ for i in range(len(label_list)):
+ if label_list[i] == 'no':
+ label_list[i] = 0
+ else:
+ label_list[i] = 1
+
+ pred_list = []
+ for answer in answers:
+ if answer['text'] == 'no':
+ pred_list.append(0)
+ else:
+ pred_list.append(1)
+
+ pos = 1
+ neg = 0
+ yes_ratio = pred_list.count(1) / len(pred_list)
+
+ TP, TN, FP, FN = 0, 0, 0, 0
+ for pred, label in zip(pred_list, label_list):
+ if pred == pos and label == pos:
+ TP += 1
+ elif pred == pos and label == neg:
+ FP += 1
+ elif pred == neg and label == neg:
+ TN += 1
+ elif pred == neg and label == pos:
+ FN += 1
+
+ print('TP\tFP\tTN\tFN\t')
+ print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
+
+ precision = float(TP) / float(TP + FP)
+ recall = float(TP) / float(TP + FN)
+ f1 = 2*precision*recall / (precision + recall)
+ acc = (TP + TN) / (TP + TN + FP + FN)
+ print('Accuracy: {}'.format(acc))
+ print('Precision: {}'.format(precision))
+ print('Recall: {}'.format(recall))
+ print('F1 score: {}'.format(f1))
+ print('Yes ratio: {}'.format(yes_ratio))
+ print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--annotation-dir", type=str)
+ parser.add_argument("--question-file", type=str)
+ parser.add_argument("--result-file", type=str)
+ args = parser.parse_args()
+
+ questions = [json.loads(line) for line in open(args.question_file)]
+ questions = {question['question_id']: question for question in questions}
+ answers = [json.loads(q) for q in open(args.result_file)]
+ for file in os.listdir(args.annotation_dir):
+ assert file.startswith('coco_pope_')
+ assert file.endswith('.json')
+ category = file[10:-5]
+ cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
+ print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
+ eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
+ print("====================================")
diff --git a/llava/eval/eval_science_qa.py b/llava/eval/eval_science_qa.py
new file mode 100644
index 0000000000000000000000000000000000000000..ccf206bbd7a5d6376eef82d61b3ef8bbe0f71c6c
--- /dev/null
+++ b/llava/eval/eval_science_qa.py
@@ -0,0 +1,114 @@
+import argparse
+import json
+import os
+import re
+import random
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--base-dir', type=str)
+ parser.add_argument('--result-file', type=str)
+ parser.add_argument('--output-file', type=str)
+ parser.add_argument('--output-result', type=str)
+ parser.add_argument('--split', type=str, default='test')
+ parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
+ return parser.parse_args()
+
+
+def convert_caps(results):
+ fakecaps = []
+ for result in results:
+ image_id = result['question_id']
+ caption = result['text']
+ fakecaps.append({"image_id": int(image_id), "caption": caption})
+ return fakecaps
+
+
+def get_pred_idx(prediction, choices, options):
+ """
+ Get the index (e.g. 2) from the prediction (e.g. 'C')
+ """
+ if prediction in options[:len(choices)]:
+ return options.index(prediction)
+ else:
+ return -1
+ return random.choice(range(len(choices)))
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ base_dir = args.base_dir
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
+ problems = json.load(open(os.path.join(base_dir, "problems.json")))
+ predictions = [json.loads(line) for line in open(args.result_file)]
+ predictions = {pred['question_id']: pred for pred in predictions}
+ split_problems = {idx: problems[idx] for idx in split_indices}
+
+ results = {'correct': [], 'incorrect': []}
+ sqa_results = {}
+ sqa_results['acc'] = None
+ sqa_results['correct'] = None
+ sqa_results['count'] = None
+ sqa_results['results'] = {}
+ sqa_results['outputs'] = {}
+
+ for prob_id, prob in split_problems.items():
+ if prob_id not in predictions:
+ pred = {'text': 'FAILED', 'prompt': 'Unknown'}
+ pred_text = 'FAILED'
+ else:
+ pred = predictions[prob_id]
+ pred_text = pred['text']
+
+ if pred_text in args.options:
+ answer = pred_text
+ elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
+ answer = pred_text[0]
+ else:
+ pattern = re.compile(r'The answer is ([A-Z]).')
+ res = pattern.findall(pred_text)
+ if len(res) == 1:
+ answer = res[0] # 'A', 'B', ...
+ else:
+ answer = "FAILED"
+
+ pred_idx = get_pred_idx(answer, prob['choices'], args.options)
+
+ analysis = {
+ 'question_id': prob_id,
+ 'parsed_ans': answer,
+ 'ground_truth': args.options[prob['answer']],
+ 'question': pred['prompt'],
+ 'pred': pred_text,
+ 'is_multimodal': '' in pred['prompt'],
+ }
+
+ sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
+ sqa_results['outputs'][prob_id] = pred_text
+
+ if pred_idx == prob['answer']:
+ results['correct'].append(analysis)
+ else:
+ results['incorrect'].append(analysis)
+
+ correct = len(results['correct'])
+ total = len(results['correct']) + len(results['incorrect'])
+
+ ###### IMG ######
+ multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
+ multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
+ multimodal_total = multimodal_correct + multimodal_incorrect
+ ###### IMG ######
+
+ print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
+
+ sqa_results['acc'] = correct / total * 100
+ sqa_results['correct'] = correct
+ sqa_results['count'] = total
+
+ with open(args.output_file, 'w') as f:
+ json.dump(results, f, indent=2)
+ with open(args.output_result, 'w') as f:
+ json.dump(sqa_results, f, indent=2)
diff --git a/llava/eval/eval_science_qa_gpt4.py b/llava/eval/eval_science_qa_gpt4.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2ff17c915481fb556aba6ec816a9e08f519c515
--- /dev/null
+++ b/llava/eval/eval_science_qa_gpt4.py
@@ -0,0 +1,104 @@
+import argparse
+import json
+import os
+import re
+import random
+from collections import defaultdict
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--base-dir', type=str)
+ parser.add_argument('--gpt4-result', type=str)
+ parser.add_argument('--our-result', type=str)
+ parser.add_argument('--split', type=str, default='test')
+ parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
+ return parser.parse_args()
+
+
+def convert_caps(results):
+ fakecaps = []
+ for result in results:
+ image_id = result['question_id']
+ caption = result['text']
+ fakecaps.append({"image_id": int(image_id), "caption": caption})
+ return fakecaps
+
+
+def get_pred_idx(prediction, choices, options):
+ """
+ Get the index (e.g. 2) from the prediction (e.g. 'C')
+ """
+ if prediction in options[:len(choices)]:
+ return options.index(prediction)
+ else:
+ return random.choice(range(len(choices)))
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ base_dir = args.base_dir
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
+ problems = json.load(open(os.path.join(base_dir, "problems.json")))
+ our_predictions = [json.loads(line) for line in open(args.our_result)]
+ our_predictions = {pred['question_id']: pred for pred in our_predictions}
+ split_problems = {idx: problems[idx] for idx in split_indices}
+
+ gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
+
+ results = defaultdict(lambda: 0)
+
+ for prob_id, prob in split_problems.items():
+ if prob_id not in our_predictions:
+ continue
+ if prob_id not in gpt4_predictions:
+ continue
+ our_pred = our_predictions[prob_id]['text']
+ gpt4_pred = gpt4_predictions[prob_id]
+
+ pattern = re.compile(r'The answer is ([A-Z]).')
+ our_res = pattern.findall(our_pred)
+ if len(our_res) == 1:
+ our_answer = our_res[0] # 'A', 'B', ...
+ else:
+ our_answer = "FAILED"
+ gpt4_res = pattern.findall(gpt4_pred)
+ if len(gpt4_res) == 1:
+ gpt4_answer = gpt4_res[0] # 'A', 'B', ...
+ else:
+ gpt4_answer = "FAILED"
+
+ our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
+ gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
+
+ if gpt4_answer == 'FAILED':
+ results['gpt4_failed'] += 1
+ # continue
+ gpt4_pred_idx = our_pred_idx
+ # if our_pred_idx != prob['answer']:
+ # print(our_predictions[prob_id]['prompt'])
+ # print('-----------------')
+ # print(f'LECTURE: {prob["lecture"]}')
+ # print(f'SOLUTION: {prob["solution"]}')
+ # print('=====================')
+ else:
+ # continue
+ pass
+ # gpt4_pred_idx = our_pred_idx
+
+ if gpt4_pred_idx == prob['answer']:
+ results['correct'] += 1
+ else:
+ results['incorrect'] += 1
+
+
+ if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
+ results['correct_upperbound'] += 1
+
+ correct = results['correct']
+ total = results['correct'] + results['incorrect']
+ print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
+ print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
+ print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
+
diff --git a/llava/eval/eval_science_qa_gpt4_requery.py b/llava/eval/eval_science_qa_gpt4_requery.py
new file mode 100644
index 0000000000000000000000000000000000000000..698546e995d365d1ccc2c25a87e6c5cd681e6eb6
--- /dev/null
+++ b/llava/eval/eval_science_qa_gpt4_requery.py
@@ -0,0 +1,149 @@
+import argparse
+import json
+import os
+import re
+import random
+from collections import defaultdict
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--base-dir', type=str)
+ parser.add_argument('--gpt4-result', type=str)
+ parser.add_argument('--requery-result', type=str)
+ parser.add_argument('--our-result', type=str)
+ parser.add_argument('--output-result', type=str)
+ parser.add_argument('--split', type=str, default='test')
+ parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
+ return parser.parse_args()
+
+
+def convert_caps(results):
+ fakecaps = []
+ for result in results:
+ image_id = result['question_id']
+ caption = result['text']
+ fakecaps.append({"image_id": int(image_id), "caption": caption})
+ return fakecaps
+
+
+def get_pred_idx(prediction, choices, options):
+ """
+ Get the index (e.g. 2) from the prediction (e.g. 'C')
+ """
+ if prediction in options[:len(choices)]:
+ return options.index(prediction)
+ else:
+ return random.choice(range(len(choices)))
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ base_dir = args.base_dir
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
+ problems = json.load(open(os.path.join(base_dir, "problems.json")))
+ our_predictions = [json.loads(line) for line in open(args.our_result)]
+ our_predictions = {pred['question_id']: pred for pred in our_predictions}
+ split_problems = {idx: problems[idx] for idx in split_indices}
+
+ requery_predictions = [json.loads(line) for line in open(args.requery_result)]
+ requery_predictions = {pred['question_id']: pred for pred in requery_predictions}
+
+ gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
+
+ results = defaultdict(lambda: 0)
+
+ sqa_results = {}
+ sqa_results['acc'] = None
+ sqa_results['correct'] = None
+ sqa_results['count'] = None
+ sqa_results['results'] = {}
+ sqa_results['outputs'] = {}
+
+ for prob_id, prob in split_problems.items():
+ if prob_id not in our_predictions:
+ assert False
+ if prob_id not in gpt4_predictions:
+ assert False
+ our_pred = our_predictions[prob_id]['text']
+ gpt4_pred = gpt4_predictions[prob_id]
+ if prob_id not in requery_predictions:
+ results['missing_requery'] += 1
+ requery_pred = "MISSING"
+ else:
+ requery_pred = requery_predictions[prob_id]['text']
+
+ pattern = re.compile(r'The answer is ([A-Z]).')
+ our_res = pattern.findall(our_pred)
+ if len(our_res) == 1:
+ our_answer = our_res[0] # 'A', 'B', ...
+ else:
+ our_answer = "FAILED"
+
+ requery_res = pattern.findall(requery_pred)
+ if len(requery_res) == 1:
+ requery_answer = requery_res[0] # 'A', 'B', ...
+ else:
+ requery_answer = "FAILED"
+
+ gpt4_res = pattern.findall(gpt4_pred)
+ if len(gpt4_res) == 1:
+ gpt4_answer = gpt4_res[0] # 'A', 'B', ...
+ else:
+ gpt4_answer = "FAILED"
+
+ our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
+ gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
+ requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options)
+
+ results['total'] += 1
+
+ if gpt4_answer == 'FAILED':
+ results['gpt4_failed'] += 1
+ if gpt4_pred_idx == prob['answer']:
+ results['gpt4_correct'] += 1
+ if our_pred_idx == prob['answer']:
+ results['gpt4_ourvisual_correct'] += 1
+ elif gpt4_pred_idx == prob['answer']:
+ results['gpt4_correct'] += 1
+ results['gpt4_ourvisual_correct'] += 1
+
+ if our_pred_idx == prob['answer']:
+ results['our_correct'] += 1
+
+ if requery_answer == 'FAILED':
+ sqa_results['results'][prob_id] = our_pred_idx
+ if our_pred_idx == prob['answer']:
+ results['requery_correct'] += 1
+ else:
+ sqa_results['results'][prob_id] = requery_pred_idx
+ if requery_pred_idx == prob['answer']:
+ results['requery_correct'] += 1
+ else:
+ print(f"""
+Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']}
+Our ({our_answer}): {our_pred}
+GPT-4 ({gpt4_answer}): {gpt4_pred}
+Requery ({requery_answer}): {requery_pred}
+print("=====================================")
+""")
+
+ if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
+ results['correct_upperbound'] += 1
+
+ total = results['total']
+ print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%')
+ print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%')
+ print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
+ print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%')
+ print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%')
+ print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
+
+ sqa_results['acc'] = results["requery_correct"] / total * 100
+ sqa_results['correct'] = results["requery_correct"]
+ sqa_results['count'] = total
+
+ with open(args.output_result, 'w') as f:
+ json.dump(sqa_results, f, indent=2)
+
diff --git a/llava/eval/eval_textvqa.py b/llava/eval/eval_textvqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..468f4bb120448a036bd5b5c7955464fe2e13892a
--- /dev/null
+++ b/llava/eval/eval_textvqa.py
@@ -0,0 +1,65 @@
+import os
+import argparse
+import json
+import re
+
+from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--annotation-file', type=str)
+ parser.add_argument('--result-file', type=str)
+ parser.add_argument('--result-dir', type=str)
+ return parser.parse_args()
+
+
+def prompt_processor(prompt):
+ if prompt.startswith('OCR tokens: '):
+ pattern = r"Question: (.*?) Short answer:"
+ match = re.search(pattern, prompt, re.DOTALL)
+ question = match.group(1)
+ elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
+ if prompt.startswith('Reference OCR token:'):
+ question = prompt.split('\n')[1]
+ else:
+ question = prompt.split('\n')[0]
+ elif len(prompt.split('\n')) == 2:
+ question = prompt.split('\n')[0]
+ else:
+ assert False
+
+ return question.lower()
+
+
+def eval_single(annotation_file, result_file):
+ experiment_name = os.path.splitext(os.path.basename(result_file))[0]
+ print(experiment_name)
+ annotations = json.load(open(annotation_file))['data']
+ annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations}
+ results = [json.loads(line) for line in open(result_file)]
+
+ pred_list = []
+ for result in results:
+ annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
+ pred_list.append({
+ "pred_answer": result['text'],
+ "gt_answers": annotation['answers'],
+ })
+
+ evaluator = TextVQAAccuracyEvaluator()
+ print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ if args.result_file is not None:
+ eval_single(args.annotation_file, args.result_file)
+
+ if args.result_dir is not None:
+ for result_file in sorted(os.listdir(args.result_dir)):
+ if not result_file.endswith('.jsonl'):
+ print(f'Skipping {result_file}')
+ continue
+ eval_single(args.annotation_file, os.path.join(args.result_dir, result_file))
diff --git a/llava/eval/generate_webpage_data_from_table.py b/llava/eval/generate_webpage_data_from_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..92602258ccd953a1d7137056aaf15c8de8166e21
--- /dev/null
+++ b/llava/eval/generate_webpage_data_from_table.py
@@ -0,0 +1,111 @@
+"""Generate json file for webpage."""
+import json
+import os
+import re
+
+# models = ['llama', 'alpaca', 'gpt35', 'bard']
+models = ['vicuna']
+
+
+def read_jsonl(path: str, key: str=None):
+ data = []
+ with open(os.path.expanduser(path)) as f:
+ for line in f:
+ if not line:
+ continue
+ data.append(json.loads(line))
+ if key is not None:
+ data.sort(key=lambda x: x[key])
+ data = {item[key]: item for item in data}
+ return data
+
+
+def trim_hanging_lines(s: str, n: int) -> str:
+ s = s.strip()
+ for _ in range(n):
+ s = s.split('\n', 1)[1].strip()
+ return s
+
+
+if __name__ == '__main__':
+ questions = read_jsonl('table/question.jsonl', key='question_id')
+
+ # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
+ # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
+ # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
+ # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
+ vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
+ ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
+
+ review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
+ # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
+ # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
+ # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
+ # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')
+
+ records = []
+ for qid in questions.keys():
+ r = {
+ 'id': qid,
+ 'category': questions[qid]['category'],
+ 'question': questions[qid]['text'],
+ 'answers': {
+ # 'alpaca': alpaca_answers[qid]['text'],
+ # 'llama': llama_answers[qid]['text'],
+ # 'bard': bard_answers[qid]['text'],
+ # 'gpt35': gpt35_answers[qid]['text'],
+ 'vicuna': vicuna_answers[qid]['text'],
+ 'ours': ours_answers[qid]['text'],
+ },
+ 'evaluations': {
+ # 'alpaca': review_alpaca[qid]['text'],
+ # 'llama': review_llama[qid]['text'],
+ # 'bard': review_bard[qid]['text'],
+ 'vicuna': review_vicuna[qid]['content'],
+ # 'gpt35': review_gpt35[qid]['text'],
+ },
+ 'scores': {
+ 'vicuna': review_vicuna[qid]['tuple'],
+ # 'alpaca': review_alpaca[qid]['score'],
+ # 'llama': review_llama[qid]['score'],
+ # 'bard': review_bard[qid]['score'],
+ # 'gpt35': review_gpt35[qid]['score'],
+ },
+ }
+
+ # cleanup data
+ cleaned_evals = {}
+ for k, v in r['evaluations'].items():
+ v = v.strip()
+ lines = v.split('\n')
+ # trim the first line if it's a pair of numbers
+ if re.match(r'\d+[, ]+\d+', lines[0]):
+ lines = lines[1:]
+ v = '\n'.join(lines)
+ cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
+
+ r['evaluations'] = cleaned_evals
+ records.append(r)
+
+ # Reorder the records, this is optional
+ for r in records:
+ if r['id'] <= 20:
+ r['id'] += 60
+ else:
+ r['id'] -= 20
+ for r in records:
+ if r['id'] <= 50:
+ r['id'] += 10
+ elif 50 < r['id'] <= 60:
+ r['id'] -= 50
+ for r in records:
+ if r['id'] == 7:
+ r['id'] = 1
+ elif r['id'] < 7:
+ r['id'] += 1
+
+ records.sort(key=lambda x: x['id'])
+
+ # Write to file
+ with open('webpage/data.json', 'w') as f:
+ json.dump({'questions': records, 'models': models}, f, indent=2)
diff --git a/llava/eval/m4c_evaluator.py b/llava/eval/m4c_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..e30e958da061a4f0a0bfe34b12d2fcaeba7ff2f4
--- /dev/null
+++ b/llava/eval/m4c_evaluator.py
@@ -0,0 +1,334 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import re
+
+from tqdm import tqdm
+
+
+class EvalAIAnswerProcessor:
+ """
+ Processes an answer similar to Eval AI
+ copied from
+ https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
+ """
+
+ CONTRACTIONS = {
+ "aint": "ain't",
+ "arent": "aren't",
+ "cant": "can't",
+ "couldve": "could've",
+ "couldnt": "couldn't",
+ "couldn'tve": "couldn't've",
+ "couldnt've": "couldn't've",
+ "didnt": "didn't",
+ "doesnt": "doesn't",
+ "dont": "don't",
+ "hadnt": "hadn't",
+ "hadnt've": "hadn't've",
+ "hadn'tve": "hadn't've",
+ "hasnt": "hasn't",
+ "havent": "haven't",
+ "hed": "he'd",
+ "hed've": "he'd've",
+ "he'dve": "he'd've",
+ "hes": "he's",
+ "howd": "how'd",
+ "howll": "how'll",
+ "hows": "how's",
+ "Id've": "I'd've",
+ "I'dve": "I'd've",
+ "Im": "I'm",
+ "Ive": "I've",
+ "isnt": "isn't",
+ "itd": "it'd",
+ "itd've": "it'd've",
+ "it'dve": "it'd've",
+ "itll": "it'll",
+ "let's": "let's",
+ "maam": "ma'am",
+ "mightnt": "mightn't",
+ "mightnt've": "mightn't've",
+ "mightn'tve": "mightn't've",
+ "mightve": "might've",
+ "mustnt": "mustn't",
+ "mustve": "must've",
+ "neednt": "needn't",
+ "notve": "not've",
+ "oclock": "o'clock",
+ "oughtnt": "oughtn't",
+ "ow's'at": "'ow's'at",
+ "'ows'at": "'ow's'at",
+ "'ow'sat": "'ow's'at",
+ "shant": "shan't",
+ "shed've": "she'd've",
+ "she'dve": "she'd've",
+ "she's": "she's",
+ "shouldve": "should've",
+ "shouldnt": "shouldn't",
+ "shouldnt've": "shouldn't've",
+ "shouldn'tve": "shouldn't've",
+ "somebody'd": "somebodyd",
+ "somebodyd've": "somebody'd've",
+ "somebody'dve": "somebody'd've",
+ "somebodyll": "somebody'll",
+ "somebodys": "somebody's",
+ "someoned": "someone'd",
+ "someoned've": "someone'd've",
+ "someone'dve": "someone'd've",
+ "someonell": "someone'll",
+ "someones": "someone's",
+ "somethingd": "something'd",
+ "somethingd've": "something'd've",
+ "something'dve": "something'd've",
+ "somethingll": "something'll",
+ "thats": "that's",
+ "thered": "there'd",
+ "thered've": "there'd've",
+ "there'dve": "there'd've",
+ "therere": "there're",
+ "theres": "there's",
+ "theyd": "they'd",
+ "theyd've": "they'd've",
+ "they'dve": "they'd've",
+ "theyll": "they'll",
+ "theyre": "they're",
+ "theyve": "they've",
+ "twas": "'twas",
+ "wasnt": "wasn't",
+ "wed've": "we'd've",
+ "we'dve": "we'd've",
+ "weve": "we've",
+ "werent": "weren't",
+ "whatll": "what'll",
+ "whatre": "what're",
+ "whats": "what's",
+ "whatve": "what've",
+ "whens": "when's",
+ "whered": "where'd",
+ "wheres": "where's",
+ "whereve": "where've",
+ "whod": "who'd",
+ "whod've": "who'd've",
+ "who'dve": "who'd've",
+ "wholl": "who'll",
+ "whos": "who's",
+ "whove": "who've",
+ "whyll": "why'll",
+ "whyre": "why're",
+ "whys": "why's",
+ "wont": "won't",
+ "wouldve": "would've",
+ "wouldnt": "wouldn't",
+ "wouldnt've": "wouldn't've",
+ "wouldn'tve": "wouldn't've",
+ "yall": "y'all",
+ "yall'll": "y'all'll",
+ "y'allll": "y'all'll",
+ "yall'd've": "y'all'd've",
+ "y'alld've": "y'all'd've",
+ "y'all'dve": "y'all'd've",
+ "youd": "you'd",
+ "youd've": "you'd've",
+ "you'dve": "you'd've",
+ "youll": "you'll",
+ "youre": "you're",
+ "youve": "you've",
+ }
+
+ NUMBER_MAP = {
+ "none": "0",
+ "zero": "0",
+ "one": "1",
+ "two": "2",
+ "three": "3",
+ "four": "4",
+ "five": "5",
+ "six": "6",
+ "seven": "7",
+ "eight": "8",
+ "nine": "9",
+ "ten": "10",
+ }
+ ARTICLES = ["a", "an", "the"]
+ PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
+ COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
+ PUNCTUATIONS = [
+ ";",
+ r"/",
+ "[",
+ "]",
+ '"',
+ "{",
+ "}",
+ "(",
+ ")",
+ "=",
+ "+",
+ "\\",
+ "_",
+ "-",
+ ">",
+ "<",
+ "@",
+ "`",
+ ",",
+ "?",
+ "!",
+ ]
+
+ def __init__(self, *args, **kwargs):
+ pass
+
+ def word_tokenize(self, word):
+ word = word.lower()
+ word = word.replace(",", "").replace("?", "").replace("'s", " 's")
+ return word.strip()
+
+ def process_punctuation(self, in_text):
+ out_text = in_text
+ for p in self.PUNCTUATIONS:
+ if (p + " " in in_text or " " + p in in_text) or (
+ re.search(self.COMMA_STRIP, in_text) is not None
+ ):
+ out_text = out_text.replace(p, "")
+ else:
+ out_text = out_text.replace(p, " ")
+ out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
+ return out_text
+
+ def process_digit_article(self, in_text):
+ out_text = []
+ temp_text = in_text.lower().split()
+ for word in temp_text:
+ word = self.NUMBER_MAP.setdefault(word, word)
+ if word not in self.ARTICLES:
+ out_text.append(word)
+ else:
+ pass
+ for word_id, word in enumerate(out_text):
+ if word in self.CONTRACTIONS:
+ out_text[word_id] = self.CONTRACTIONS[word]
+ out_text = " ".join(out_text)
+ return out_text
+
+ def __call__(self, item):
+ item = self.word_tokenize(item)
+ item = item.replace("\n", " ").replace("\t", " ").strip()
+ item = self.process_punctuation(item)
+ item = self.process_digit_article(item)
+ return item
+
+
+class TextVQAAccuracyEvaluator:
+ def __init__(self):
+ self.answer_processor = EvalAIAnswerProcessor()
+
+ def _compute_answer_scores(self, raw_answers):
+ """
+ compute the accuracy (soft score) of human answers
+ """
+ answers = [self.answer_processor(a) for a in raw_answers]
+ assert len(answers) == 10
+ gt_answers = list(enumerate(answers))
+ unique_answers = set(answers)
+ unique_answer_scores = {}
+
+ for unique_answer in unique_answers:
+ accs = []
+ for gt_answer in gt_answers:
+ other_answers = [item for item in gt_answers if item != gt_answer]
+ matching_answers = [
+ item for item in other_answers if item[1] == unique_answer
+ ]
+ acc = min(1, float(len(matching_answers)) / 3)
+ accs.append(acc)
+ unique_answer_scores[unique_answer] = sum(accs) / len(accs)
+
+ return unique_answer_scores
+
+ def eval_pred_list(self, pred_list):
+ pred_scores = []
+ for entry in tqdm(pred_list):
+ pred_answer = self.answer_processor(entry["pred_answer"])
+ unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
+ score = unique_answer_scores.get(pred_answer, 0.0)
+ pred_scores.append(score)
+
+ accuracy = sum(pred_scores) / len(pred_scores)
+ return accuracy
+
+
+class STVQAAccuracyEvaluator:
+ def __init__(self):
+ self.answer_processor = EvalAIAnswerProcessor()
+
+ def eval_pred_list(self, pred_list):
+ pred_scores = []
+ for entry in pred_list:
+ pred_answer = self.answer_processor(entry["pred_answer"])
+ gts = [self.answer_processor(a) for a in entry["gt_answers"]]
+ score = 1.0 if pred_answer in gts else 0.0
+ pred_scores.append(score)
+
+ accuracy = sum(pred_scores) / len(pred_scores)
+ return accuracy
+
+
+class STVQAANLSEvaluator:
+ def __init__(self):
+ import editdistance # install with `pip install editdistance`
+
+ self.get_edit_distance = editdistance.eval
+
+ def get_anls(self, s1, s2):
+ s1 = s1.lower().strip()
+ s2 = s2.lower().strip()
+ iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
+ anls = iou if iou >= 0.5 else 0.0
+ return anls
+
+ def eval_pred_list(self, pred_list):
+ pred_scores = []
+ for entry in pred_list:
+ anls = max(
+ self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
+ )
+ pred_scores.append(anls)
+
+ accuracy = sum(pred_scores) / len(pred_scores)
+ return accuracy
+
+
+class TextCapsBleu4Evaluator:
+ def __init__(self):
+ # The following script requires Java 1.8.0 and pycocotools installed.
+ # The pycocoevalcap can be installed with pip as
+ # pip install git+https://github.com/ronghanghu/coco-caption.git@python23
+ # Original pycocoevalcap code is at https://github.com/tylin/coco-caption
+ # but has no python3 support yet.
+ try:
+ from pycocoevalcap.bleu.bleu import Bleu
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
+ except ModuleNotFoundError:
+ print(
+ "Please install pycocoevalcap module using "
+ "pip install git+https://github.com/ronghanghu/coco-caption.git@python23" # noqa
+ )
+ raise
+
+ self.tokenizer = PTBTokenizer()
+ self.scorer = Bleu(4)
+
+ def eval_pred_list(self, pred_list):
+ # Create reference and hypotheses captions.
+ gts = {}
+ res = {}
+ for idx, entry in enumerate(pred_list):
+ gts[idx] = [{"caption": a} for a in entry["gt_answers"]]
+ res[idx] = [{"caption": entry["pred_answer"]}]
+
+ gts = self.tokenizer.tokenize(gts)
+ res = self.tokenizer.tokenize(res)
+ score, _ = self.scorer.compute_score(gts, res)
+
+ bleu4 = score[3] # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
+ return bleu4
diff --git a/llava/eval/model_qa.py b/llava/eval/model_qa.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e254da152ac644ff54fb5fa57e625d9e6ba31d1
--- /dev/null
+++ b/llava/eval/model_qa.py
@@ -0,0 +1,64 @@
+import argparse
+from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
+import torch
+import os
+import json
+from tqdm import tqdm
+import shortuuid
+
+from llava.conversation import default_conversation
+from llava.utils import disable_torch_init
+
+
+@torch.inference_mode()
+def eval_model(model_name, questions_file, answers_file):
+ # Model
+ disable_torch_init()
+ model_name = os.path.expanduser(model_name)
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_name,
+ torch_dtype=torch.float16).cuda()
+
+
+ ques_file = open(os.path.expanduser(questions_file), "r")
+ ans_file = open(os.path.expanduser(answers_file), "w")
+ for i, line in enumerate(tqdm(ques_file)):
+ idx = json.loads(line)["question_id"]
+ qs = json.loads(line)["text"]
+ cat = json.loads(line)["category"]
+ conv = default_conversation.copy()
+ conv.append_message(conv.roles[0], qs)
+ prompt = conv.get_prompt()
+ inputs = tokenizer([prompt])
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
+ output_ids = model.generate(
+ input_ids,
+ do_sample=True,
+ use_cache=True,
+ temperature=0.7,
+ max_new_tokens=1024,)
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
+ try:
+ index = outputs.index(conv.sep, len(prompt))
+ except ValueError:
+ outputs += conv.sep
+ index = outputs.index(conv.sep, len(prompt))
+
+ outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
+ ans_id = shortuuid.uuid()
+ ans_file.write(json.dumps({"question_id": idx,
+ "text": outputs,
+ "answer_id": ans_id,
+ "model_id": model_name,
+ "metadata": {}}) + "\n")
+ ans_file.flush()
+ ans_file.close()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
+ args = parser.parse_args()
+
+ eval_model(args.model_name, args.question_file, args.answers_file)
diff --git a/llava/eval/model_vqa.py b/llava/eval/model_vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..938706438b1d332505fdd0e9670df72c31eee1b2
--- /dev/null
+++ b/llava/eval/model_vqa.py
@@ -0,0 +1,101 @@
+import argparse
+import torch
+import os
+import json
+from tqdm import tqdm
+import shortuuid
+
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.model.builder import load_pretrained_model
+from llava.utils import disable_torch_init
+from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
+
+from PIL import Image
+import math
+
+
+def split_list(lst, n):
+ """Split a list into n (roughly) equal-sized chunks"""
+ chunk_size = math.ceil(len(lst) / n) # integer division
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+def get_chunk(lst, n, k):
+ chunks = split_list(lst, n)
+ return chunks[k]
+
+
+def eval_model(args):
+ # Model
+ disable_torch_init()
+ model_path = os.path.expanduser(args.model_path)
+ model_name = get_model_name_from_path(model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
+
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
+ answers_file = os.path.expanduser(args.answers_file)
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
+ ans_file = open(answers_file, "w")
+ for line in tqdm(questions):
+ idx = line["question_id"]
+ image_file = line["image"]
+ qs = line["text"]
+ cur_prompt = qs
+ if model.config.mm_use_im_start_end:
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
+ else:
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+
+ conv = conv_templates[args.conv_mode].copy()
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+
+ image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB')
+ image_tensor = process_images([image], image_processor, model.config)[0]
+
+ with torch.inference_mode():
+ output_ids = model.generate(
+ input_ids,
+ images=image_tensor.unsqueeze(0).half().cuda(),
+ image_sizes=[image.size],
+ do_sample=True if args.temperature > 0 else False,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ num_beams=args.num_beams,
+ # no_repeat_ngram_size=3,
+ max_new_tokens=1024,
+ use_cache=True)
+
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+
+ ans_id = shortuuid.uuid()
+ ans_file.write(json.dumps({"question_id": idx,
+ "prompt": cur_prompt,
+ "text": outputs,
+ "answer_id": ans_id,
+ "model_id": model_name,
+ "metadata": {}}) + "\n")
+ ans_file.flush()
+ ans_file.close()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--image-folder", type=str, default="")
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
+ parser.add_argument("--num-chunks", type=int, default=1)
+ parser.add_argument("--chunk-idx", type=int, default=0)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--top_p", type=float, default=None)
+ parser.add_argument("--num_beams", type=int, default=1)
+ args = parser.parse_args()
+
+ eval_model(args)
diff --git a/llava/eval/model_vqa_loader.py b/llava/eval/model_vqa_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..66d8af3b5df481b0cc2dc5e509d5cda21d8bc7b4
--- /dev/null
+++ b/llava/eval/model_vqa_loader.py
@@ -0,0 +1,149 @@
+import argparse
+import torch
+import os
+import json
+from tqdm import tqdm
+import shortuuid
+
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.model.builder import load_pretrained_model
+from llava.utils import disable_torch_init
+from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
+from torch.utils.data import Dataset, DataLoader
+
+from PIL import Image
+import math
+
+
+def split_list(lst, n):
+ """Split a list into n (roughly) equal-sized chunks"""
+ chunk_size = math.ceil(len(lst) / n) # integer division
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+def get_chunk(lst, n, k):
+ chunks = split_list(lst, n)
+ return chunks[k]
+
+
+# Custom dataset class
+class CustomDataset(Dataset):
+ def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
+ self.questions = questions
+ self.image_folder = image_folder
+ self.tokenizer = tokenizer
+ self.image_processor = image_processor
+ self.model_config = model_config
+
+ def __getitem__(self, index):
+ line = self.questions[index]
+ image_file = line["image"]
+ qs = line["text"]
+ if self.model_config.mm_use_im_start_end:
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
+ else:
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+
+ conv = conv_templates[args.conv_mode].copy()
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
+ image_tensor = process_images([image], self.image_processor, self.model_config)[0]
+
+ input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
+
+ print("input_ids.size():", input_ids.size())
+ print("image_tensor.size():", image_tensor.size())
+
+ exit()
+
+ return input_ids, image_tensor, image.size
+
+ def __len__(self):
+ return len(self.questions)
+
+
+def collate_fn(batch):
+ input_ids, image_tensors, image_sizes = zip(*batch)
+ input_ids = torch.stack(input_ids, dim=0)
+ image_tensors = torch.stack(image_tensors, dim=0)
+ return input_ids, image_tensors, image_sizes
+
+
+# DataLoader
+def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
+ assert batch_size == 1, "batch_size must be 1"
+ dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
+ data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn)
+ return data_loader
+
+
+def eval_model(args):
+ # Model
+ disable_torch_init()
+ model_path = os.path.expanduser(args.model_path)
+ model_name = get_model_name_from_path(model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
+
+ questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
+ answers_file = os.path.expanduser(args.answers_file)
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
+ ans_file = open(answers_file, "w")
+
+ if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
+ args.conv_mode = args.conv_mode + '_mmtag'
+ print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
+
+ data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
+
+ for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, questions), total=len(questions)):
+ idx = line["question_id"]
+ cur_prompt = line["text"]
+
+ input_ids = input_ids.to(device='cuda', non_blocking=True)
+
+ with torch.inference_mode():
+ output_ids = model.generate(
+ input_ids,
+ images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
+ image_sizes=image_sizes,
+ do_sample=True if args.temperature > 0 else False,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ num_beams=args.num_beams,
+ max_new_tokens=args.max_new_tokens,
+ use_cache=True)
+
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+
+ ans_id = shortuuid.uuid()
+ ans_file.write(json.dumps({"question_id": idx,
+ "prompt": cur_prompt,
+ "text": outputs,
+ "answer_id": ans_id,
+ "model_id": model_name,
+ "metadata": {}}) + "\n")
+ # ans_file.flush()
+ ans_file.close()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--image-folder", type=str, default="")
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
+ parser.add_argument("--num-chunks", type=int, default=1)
+ parser.add_argument("--chunk-idx", type=int, default=0)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--top_p", type=float, default=None)
+ parser.add_argument("--num_beams", type=int, default=1)
+ parser.add_argument("--max_new_tokens", type=int, default=128)
+ args = parser.parse_args()
+
+ eval_model(args)
diff --git a/llava/eval/model_vqa_mmbench.py b/llava/eval/model_vqa_mmbench.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd7a4c8085ddb7b237b17b054e5eaa0569018178
--- /dev/null
+++ b/llava/eval/model_vqa_mmbench.py
@@ -0,0 +1,160 @@
+import argparse
+import torch
+import os
+import json
+import pandas as pd
+from tqdm import tqdm
+import shortuuid
+
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.model.builder import load_pretrained_model
+from llava.utils import disable_torch_init
+from llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path
+
+from PIL import Image
+import math
+
+
+all_options = ['A', 'B', 'C', 'D']
+
+
+def split_list(lst, n):
+ """Split a list into n (roughly) equal-sized chunks"""
+ chunk_size = math.ceil(len(lst) / n) # integer division
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+def get_chunk(lst, n, k):
+ chunks = split_list(lst, n)
+ return chunks[k]
+
+
+def is_none(value):
+ if value is None:
+ return True
+ if type(value) is float and math.isnan(value):
+ return True
+ if type(value) is str and value.lower() == 'nan':
+ return True
+ if type(value) is str and value.lower() == 'none':
+ return True
+ return False
+
+def get_options(row, options):
+ parsed_options = []
+ for option in options:
+ option_value = row[option]
+ if is_none(option_value):
+ break
+ parsed_options.append(option_value)
+ return parsed_options
+
+
+def eval_model(args):
+ # Model
+ disable_torch_init()
+ model_path = os.path.expanduser(args.model_path)
+ model_name = get_model_name_from_path(model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
+
+ questions = pd.read_table(os.path.expanduser(args.question_file))
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
+ answers_file = os.path.expanduser(args.answers_file)
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
+ ans_file = open(answers_file, "w")
+
+ if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
+ args.conv_mode = args.conv_mode + '_mmtag'
+ print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
+
+ for index, row in tqdm(questions.iterrows(), total=len(questions)):
+ options = get_options(row, all_options)
+ cur_option_char = all_options[:len(options)]
+
+ if args.all_rounds:
+ num_rounds = len(options)
+ else:
+ num_rounds = 1
+
+ for round_idx in range(num_rounds):
+ idx = row['index']
+ question = row['question']
+ hint = row['hint']
+ image = load_image_from_base64(row['image'])
+ if not is_none(hint):
+ question = hint + '\n' + question
+ for option_char, option in zip(all_options[:len(options)], options):
+ question = question + '\n' + option_char + '. ' + option
+ qs = cur_prompt = question
+ if model.config.mm_use_im_start_end:
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
+ else:
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+
+ if args.single_pred_prompt:
+ if args.lang == 'cn':
+ qs = qs + '\n' + "่ฏท็ดๆฅๅ็ญ้้กนๅญๆฏใ"
+ else:
+ qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
+
+ conv = conv_templates[args.conv_mode].copy()
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+
+ image_tensor = process_images([image], image_processor, model.config)[0]
+
+ with torch.inference_mode():
+ output_ids = model.generate(
+ input_ids,
+ images=image_tensor.unsqueeze(0).half().cuda(),
+ image_sizes=[image.size],
+ do_sample=True if args.temperature > 0 else False,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ num_beams=args.num_beams,
+ # no_repeat_ngram_size=3,
+ max_new_tokens=1024,
+ use_cache=True)
+
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+
+ ans_id = shortuuid.uuid()
+ ans_file.write(json.dumps({"question_id": idx,
+ "round_id": round_idx,
+ "prompt": cur_prompt,
+ "text": outputs,
+ "options": options,
+ "option_char": cur_option_char,
+ "answer_id": ans_id,
+ "model_id": model_name,
+ "metadata": {}}) + "\n")
+ ans_file.flush()
+
+ # rotate options
+ options = options[1:] + options[:1]
+ cur_option_char = cur_option_char[1:] + cur_option_char[:1]
+ ans_file.close()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--image-folder", type=str, default="")
+ parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
+ parser.add_argument("--conv-mode", type=str, default="llava_v1")
+ parser.add_argument("--num-chunks", type=int, default=1)
+ parser.add_argument("--chunk-idx", type=int, default=0)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--top_p", type=float, default=None)
+ parser.add_argument("--num_beams", type=int, default=1)
+ parser.add_argument("--all-rounds", action="store_true")
+ parser.add_argument("--single-pred-prompt", action="store_true")
+ parser.add_argument("--lang", type=str, default="en")
+ args = parser.parse_args()
+
+ eval_model(args)
diff --git a/llava/eval/model_vqa_science.py b/llava/eval/model_vqa_science.py
new file mode 100644
index 0000000000000000000000000000000000000000..90fc681a20ee72131862772107f6be572f010c99
--- /dev/null
+++ b/llava/eval/model_vqa_science.py
@@ -0,0 +1,111 @@
+import argparse
+import torch
+import os
+import json
+from tqdm import tqdm
+import shortuuid
+
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.model.builder import load_pretrained_model
+from llava.utils import disable_torch_init
+from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
+
+from PIL import Image
+import math
+
+
+def split_list(lst, n):
+ """Split a list into n (roughly) equal-sized chunks"""
+ chunk_size = math.ceil(len(lst) / n) # integer division
+ return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+def get_chunk(lst, n, k):
+ chunks = split_list(lst, n)
+ return chunks[k]
+
+
+def eval_model(args):
+ # Model
+ disable_torch_init()
+ model_path = os.path.expanduser(args.model_path)
+ model_name = get_model_name_from_path(model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
+
+ questions = json.load(open(os.path.expanduser(args.question_file), "r"))
+ questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
+ answers_file = os.path.expanduser(args.answers_file)
+ os.makedirs(os.path.dirname(answers_file), exist_ok=True)
+ ans_file = open(answers_file, "w")
+ for i, line in enumerate(tqdm(questions)):
+ idx = line["id"]
+ question = line['conversations'][0]
+ qs = question['value'].replace('', '').strip()
+ cur_prompt = qs
+
+ if 'image' in line:
+ image_file = line["image"]
+ image = Image.open(os.path.join(args.image_folder, image_file))
+ image_tensor = process_images([image], image_processor, model.config)[0]
+ images = image_tensor.unsqueeze(0).half().cuda()
+ image_sizes = [image.size]
+ if getattr(model.config, 'mm_use_im_start_end', False):
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
+ else:
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
+ cur_prompt = '' + '\n' + cur_prompt
+ else:
+ images = None
+ image_sizes = None
+
+ if args.single_pred_prompt:
+ qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
+ cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
+
+ conv = conv_templates[args.conv_mode].copy()
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
+
+ with torch.inference_mode():
+ output_ids = model.generate(
+ input_ids,
+ images=images,
+ image_sizes=image_sizes,
+ do_sample=True if args.temperature > 0 else False,
+ temperature=args.temperature,
+ max_new_tokens=1024,
+ use_cache=True,
+ )
+
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+
+ ans_id = shortuuid.uuid()
+ ans_file.write(json.dumps({"question_id": idx,
+ "prompt": cur_prompt,
+ "text": outputs,
+ "answer_id": ans_id,
+ "model_id": model_name,
+ "metadata": {}}) + "\n")
+ ans_file.flush()
+ ans_file.close()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--image-folder", type=str, default="")
+ parser.add_argument("--question-file", type=str, default="tables/question.json")
+ parser.add_argument("--answers-file", type=str, default="answer.jsonl")
+ parser.add_argument("--conv-mode", type=str, default="llava_v0")
+ parser.add_argument("--num-chunks", type=int, default=1)
+ parser.add_argument("--chunk-idx", type=int, default=0)
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--answer-prompter", action="store_true")
+ parser.add_argument("--single-pred-prompt", action="store_true")
+ args = parser.parse_args()
+
+ eval_model(args)
diff --git a/llava/eval/qa_baseline_gpt35.py b/llava/eval/qa_baseline_gpt35.py
new file mode 100644
index 0000000000000000000000000000000000000000..babab6e12b4bb8cfa74a7edfa5e56cd1b3e2bf6c
--- /dev/null
+++ b/llava/eval/qa_baseline_gpt35.py
@@ -0,0 +1,74 @@
+"""Generate answers with GPT-3.5"""
+# Note: you need to be using OpenAI Python v0.27.0 for the code below to work
+import argparse
+import json
+import os
+import time
+import concurrent.futures
+
+import openai
+import tqdm
+import shortuuid
+
+MODEL = 'gpt-3.5-turbo'
+MODEL_ID = 'gpt-3.5-turbo:20230327'
+
+def get_answer(question_id: int, question: str, max_tokens: int):
+ ans = {
+ 'answer_id': shortuuid.uuid(),
+ 'question_id': question_id,
+ 'model_id': MODEL_ID,
+ }
+ for _ in range(3):
+ try:
+ response = openai.ChatCompletion.create(
+ model=MODEL,
+ messages=[{
+ 'role': 'system',
+ 'content': 'You are a helpful assistant.'
+ }, {
+ 'role': 'user',
+ 'content': question,
+ }],
+ max_tokens=max_tokens,
+ )
+ ans['text'] = response['choices'][0]['message']['content']
+ return ans
+ except Exception as e:
+ print('[ERROR]', e)
+ ans['text'] = '#ERROR#'
+ time.sleep(1)
+ return ans
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
+ parser.add_argument('-q', '--question')
+ parser.add_argument('-o', '--output')
+ parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
+ args = parser.parse_args()
+
+ questions_dict = {}
+ with open(os.path.expanduser(args.question)) as f:
+ for line in f:
+ if not line:
+ continue
+ q = json.loads(line)
+ questions_dict[q['question_id']] = q['text']
+
+ answers = []
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+ futures = []
+ for qid, question in questions_dict.items():
+ future = executor.submit(get_answer, qid, question, args.max_tokens)
+ futures.append(future)
+
+ for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
+ answers.append(future.result())
+
+ answers.sort(key=lambda x: x['question_id'])
+
+ with open(os.path.expanduser(args.output), 'w') as f:
+ table = [json.dumps(ans) for ans in answers]
+ f.write('\n'.join(table))
diff --git a/llava/eval/run_llava.py b/llava/eval/run_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..24b0fffcc11a2045dfc7f5ac6cae4f057aaba6d6
--- /dev/null
+++ b/llava/eval/run_llava.py
@@ -0,0 +1,145 @@
+import argparse
+import torch
+
+from llava.constants import (
+ IMAGE_TOKEN_INDEX,
+ DEFAULT_IMAGE_TOKEN,
+ DEFAULT_IM_START_TOKEN,
+ DEFAULT_IM_END_TOKEN,
+ IMAGE_PLACEHOLDER,
+)
+from llava.conversation import conv_templates, SeparatorStyle
+from llava.model.builder import load_pretrained_model
+from llava.utils import disable_torch_init
+from llava.mm_utils import (
+ process_images,
+ tokenizer_image_token,
+ get_model_name_from_path,
+)
+
+from PIL import Image
+
+import requests
+from PIL import Image
+from io import BytesIO
+import re
+
+
+def image_parser(args):
+ out = args.image_file.split(args.sep)
+ return out
+
+
+def load_image(image_file):
+ if image_file.startswith("http") or image_file.startswith("https"):
+ response = requests.get(image_file)
+ image = Image.open(BytesIO(response.content)).convert("RGB")
+ else:
+ image = Image.open(image_file).convert("RGB")
+ return image
+
+
+def load_images(image_files):
+ out = []
+ for image_file in image_files:
+ image = load_image(image_file)
+ out.append(image)
+ return out
+
+
+def eval_model(args):
+ # Model
+ disable_torch_init()
+
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
+ args.model_path, args.model_base, model_name
+ )
+
+ qs = args.query
+ image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
+ if IMAGE_PLACEHOLDER in qs:
+ if model.config.mm_use_im_start_end:
+ qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
+ else:
+ qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
+ else:
+ if model.config.mm_use_im_start_end:
+ qs = image_token_se + "\n" + qs
+ else:
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
+ if "llama-2" in model_name.lower():
+ conv_mode = "llava_llama_2"
+ elif "mistral" in model_name.lower():
+ conv_mode = "mistral_instruct"
+ elif "v1.6-34b" in model_name.lower():
+ conv_mode = "chatml_direct"
+ elif "v1" in model_name.lower():
+ conv_mode = "llava_v1"
+ elif "mpt" in model_name.lower():
+ conv_mode = "mpt"
+ else:
+ conv_mode = "llava_v0"
+
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
+ print(
+ "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
+ conv_mode, args.conv_mode, args.conv_mode
+ )
+ )
+ else:
+ args.conv_mode = conv_mode
+
+ conv = conv_templates[args.conv_mode].copy()
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ prompt = conv.get_prompt()
+
+ image_files = image_parser(args)
+ images = load_images(image_files)
+ image_sizes = [x.size for x in images]
+ images_tensor = process_images(
+ images,
+ image_processor,
+ model.config
+ ).to(model.device, dtype=torch.float16)
+
+ input_ids = (
+ tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+ .unsqueeze(0)
+ .cuda()
+ )
+
+ with torch.inference_mode():
+ output_ids = model.generate(
+ input_ids,
+ images=images_tensor,
+ image_sizes=image_sizes,
+ do_sample=True if args.temperature > 0 else False,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ num_beams=args.num_beams,
+ max_new_tokens=args.max_new_tokens,
+ use_cache=True,
+ )
+
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
+ print(outputs)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--image-file", type=str, required=True)
+ parser.add_argument("--query", type=str, required=True)
+ parser.add_argument("--conv-mode", type=str, default=None)
+ parser.add_argument("--sep", type=str, default=",")
+ parser.add_argument("--temperature", type=float, default=0.2)
+ parser.add_argument("--top_p", type=float, default=None)
+ parser.add_argument("--num_beams", type=int, default=1)
+ parser.add_argument("--max_new_tokens", type=int, default=512)
+ args = parser.parse_args()
+
+ eval_model(args)
diff --git a/llava/eval/summarize_gpt_review.py b/llava/eval/summarize_gpt_review.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f796a3880341739677a5fe3bfbcc90515a0f324
--- /dev/null
+++ b/llava/eval/summarize_gpt_review.py
@@ -0,0 +1,60 @@
+import json
+import os
+from collections import defaultdict
+
+import numpy as np
+
+import argparse
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
+ parser.add_argument('-d', '--dir', default=None)
+ parser.add_argument('-v', '--version', default=None)
+ parser.add_argument('-s', '--select', nargs='*', default=None)
+ parser.add_argument('-f', '--files', nargs='*', default=[])
+ parser.add_argument('-i', '--ignore', nargs='*', default=[])
+ return parser.parse_args()
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ if args.ignore is not None:
+ args.ignore = [int(x) for x in args.ignore]
+
+ if len(args.files) > 0:
+ review_files = args.files
+ else:
+ review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]
+
+ for review_file in sorted(review_files):
+ config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
+ if args.select is not None and any(x not in config for x in args.select):
+ continue
+ if '0613' in config:
+ version = '0613'
+ else:
+ version = '0314'
+ if args.version is not None and args.version != version:
+ continue
+ scores = defaultdict(list)
+ print(config)
+ with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
+ for review_str in f:
+ review = json.loads(review_str)
+ if review['question_id'] in args.ignore:
+ continue
+ if 'category' in review:
+ scores[review['category']].append(review['tuple'])
+ scores['all'].append(review['tuple'])
+ else:
+ if 'tuple' in review:
+ scores['all'].append(review['tuple'])
+ else:
+ scores['all'].append(review['score'])
+ for k, v in sorted(scores.items()):
+ stats = np.asarray(v).mean(0).tolist()
+ stats = [round(x, 3) for x in stats]
+ # print(k, stats, round(stats[1]/stats[0]*100, 1))
+ print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
+ print('=================================')
diff --git a/llava/eval/webpage/figures/alpaca.png b/llava/eval/webpage/figures/alpaca.png
new file mode 100644
index 0000000000000000000000000000000000000000..a36f3d99f975e8cd31767c342e5bbeb678b3ea1e
--- /dev/null
+++ b/llava/eval/webpage/figures/alpaca.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f61148aa7dfb6d17b1f341bc0dc852ed1270f2ee128d77fda57ab7df185a051d
+size 96061
diff --git a/llava/eval/webpage/figures/bard.jpg b/llava/eval/webpage/figures/bard.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..785afd017c17283ca6de0dbdc0e41a95a9a30b55
--- /dev/null
+++ b/llava/eval/webpage/figures/bard.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd86b3a25383ea2cca3b4147036a00d6125ce51ffb26daf384e814ecd90a2df6
+size 15309
diff --git a/llava/eval/webpage/figures/chatgpt.svg b/llava/eval/webpage/figures/chatgpt.svg
new file mode 100644
index 0000000000000000000000000000000000000000..bd7af3e6985d359025437eb95c5e77075db59443
--- /dev/null
+++ b/llava/eval/webpage/figures/chatgpt.svg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a362957d7e9682d46bb0e9a488a1039b506be2a85ca836986c455c77c5fe4c4e
+size 1694
diff --git a/llava/eval/webpage/figures/llama.jpg b/llava/eval/webpage/figures/llama.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..dfe5a492b1ba0ec7f64b69fd0ed26c60001c4d83
--- /dev/null
+++ b/llava/eval/webpage/figures/llama.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c1aeb2844dabea9bede23d998b8eb3c4d09a3f1abe0a1fedbc501e7ef079e6b
+size 56537
diff --git a/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg b/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg
new file mode 100644
index 0000000000000000000000000000000000000000..28336a22bef699a474e7e59546b736503e65bcb0
--- /dev/null
+++ b/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4e1ccfd86dec456fdff29d645ba502bbaabb32e5917d1692e3ff5396508c669
+size 1083
diff --git a/llava/eval/webpage/figures/vicuna.jpeg b/llava/eval/webpage/figures/vicuna.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..a399b6955186c3bbbbd8671d206016b27b00059d
--- /dev/null
+++ b/llava/eval/webpage/figures/vicuna.jpeg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:56c088a38183ba47599d9d9829b578eeed606c99761ed7dda05a302b803e34f5
+size 53975
diff --git a/llava/eval/webpage/index.html b/llava/eval/webpage/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..f8cf97a7d5eb0c2025c998e8d8b2a9da06da3a89
--- /dev/null
+++ b/llava/eval/webpage/index.html
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8710c86bcdfa4782e8555cf2a63b1c44395046875554343a58e39d0833b48864
+size 7669
diff --git a/llava/eval/webpage/script.js b/llava/eval/webpage/script.js
new file mode 100644
index 0000000000000000000000000000000000000000..a055eb9d37962ef2e6e7e1d838295e62cb906dd9
--- /dev/null
+++ b/llava/eval/webpage/script.js
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a868b812cefde66271def08ad84154c8750a11bc9f95a390dc7f7f97eda3302a
+size 9991
diff --git a/llava/eval/webpage/styles.css b/llava/eval/webpage/styles.css
new file mode 100644
index 0000000000000000000000000000000000000000..1d0b035f3723afd9fb2ec6b48b27fd042ade3939
--- /dev/null
+++ b/llava/eval/webpage/styles.css
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e4f99a8e0714130707a1454a62abc0867e28c012acbbdcbd892fb14f112ab3b7
+size 1822
diff --git a/llava/mm_utils.py b/llava/mm_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..de97345cf424fe72cc90de30f42d127ff20b99ef
--- /dev/null
+++ b/llava/mm_utils.py
@@ -0,0 +1,247 @@
+from PIL import Image
+from io import BytesIO
+import base64
+import torch
+import math
+import ast
+
+from transformers import StoppingCriteria
+from llava.constants import IMAGE_TOKEN_INDEX
+
+
+def select_best_resolution(original_size, possible_resolutions):
+ """
+ Selects the best resolution from a list of possible resolutions based on the original size.
+
+ Args:
+ original_size (tuple): The original size of the image in the format (width, height).
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+
+ Returns:
+ tuple: The best fit resolution in the format (width, height).
+ """
+ original_width, original_height = original_size
+ best_fit = None
+ max_effective_resolution = 0
+ min_wasted_resolution = float('inf')
+
+ for width, height in possible_resolutions:
+ scale = min(width / original_width, height / original_height)
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+ wasted_resolution = (width * height) - effective_resolution
+
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
+ max_effective_resolution = effective_resolution
+ min_wasted_resolution = wasted_resolution
+ best_fit = (width, height)
+
+ return best_fit
+
+
+def resize_and_pad_image(image, target_resolution):
+ """
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
+
+ Args:
+ image (PIL.Image.Image): The input image.
+ target_resolution (tuple): The target resolution (width, height) of the image.
+
+ Returns:
+ PIL.Image.Image: The resized and padded image.
+ """
+ original_width, original_height = image.size
+ target_width, target_height = target_resolution
+
+ scale_w = target_width / original_width
+ scale_h = target_height / original_height
+
+ if scale_w < scale_h:
+ new_width = target_width
+ new_height = min(math.ceil(original_height * scale_w), target_height)
+ else:
+ new_height = target_height
+ new_width = min(math.ceil(original_width * scale_h), target_width)
+
+ # Resize the image
+ resized_image = image.resize((new_width, new_height))
+
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
+ paste_x = (target_width - new_width) // 2
+ paste_y = (target_height - new_height) // 2
+ new_image.paste(resized_image, (paste_x, paste_y))
+
+ return new_image
+
+
+def divide_to_patches(image, patch_size):
+ """
+ Divides an image into patches of a specified size.
+
+ Args:
+ image (PIL.Image.Image): The input image.
+ patch_size (int): The size of each patch.
+
+ Returns:
+ list: A list of PIL.Image.Image objects representing the patches.
+ """
+ patches = []
+ width, height = image.size
+ for i in range(0, height, patch_size):
+ for j in range(0, width, patch_size):
+ box = (j, i, j + patch_size, i + patch_size)
+ patch = image.crop(box)
+ patches.append(patch)
+
+ return patches
+
+
+def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
+ """
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
+
+ Args:
+ image_size (tuple): The size of the input image in the format (width, height).
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
+ patch_size (int): The size of each image patch.
+
+ Returns:
+ tuple: The shape of the image patch grid in the format (width, height).
+ """
+ if type(grid_pinpoints) is list:
+ possible_resolutions = grid_pinpoints
+ else:
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
+ width, height = select_best_resolution(image_size, possible_resolutions)
+ return width // patch_size, height // patch_size
+
+
+def process_anyres_image(image, processor, grid_pinpoints):
+ """
+ Process an image with variable resolutions.
+
+ Args:
+ image (PIL.Image.Image): The input image to be processed.
+ processor: The image processor object.
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
+
+ Returns:
+ torch.Tensor: A tensor containing the processed image patches.
+ """
+ if type(grid_pinpoints) is list:
+ possible_resolutions = grid_pinpoints
+ else:
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
+ image_padded = resize_and_pad_image(image, best_resolution)
+
+ patches = divide_to_patches(image_padded, processor.crop_size['height'])
+
+ image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
+
+ image_patches = [image_original_resize] + patches
+ image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
+ for image_patch in image_patches]
+ return torch.stack(image_patches, dim=0)
+
+
+def load_image_from_base64(image):
+ return Image.open(BytesIO(base64.b64decode(image)))
+
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+def process_images(images, image_processor, model_cfg):
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+ new_images = []
+ if image_aspect_ratio == 'pad':
+ for image in images:
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ new_images.append(image)
+ elif image_aspect_ratio == "anyres":
+ for image in images:
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
+ new_images.append(image)
+ else:
+ return image_processor(images, return_tensors='pt')['pixel_values']
+ if all(x.shape == new_images[0].shape for x in new_images):
+ new_images = torch.stack(new_images, dim=0)
+ return new_images
+
+
+def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')]
+
+ def insert_separator(X, sep):
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
+
+ input_ids = []
+ offset = 0
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
+ offset = 1
+ input_ids.append(prompt_chunks[0][0])
+
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
+ input_ids.extend(x[offset:])
+
+ if return_tensors is not None:
+ if return_tensors == 'pt':
+ return torch.tensor(input_ids, dtype=torch.long)
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
+ return input_ids
+
+
+def get_model_name_from_path(model_path):
+ model_path = model_path.strip("/")
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith('checkpoint-'):
+ return model_paths[-2] + "_" + model_paths[-1]
+ else:
+ return model_paths[-1]
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+ def __init__(self, keywords, tokenizer, input_ids):
+ self.keywords = keywords
+ self.keyword_ids = []
+ self.max_keyword_len = 0
+ for keyword in keywords:
+ cur_keyword_ids = tokenizer(keyword).input_ids
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
+ cur_keyword_ids = cur_keyword_ids[1:]
+ if len(cur_keyword_ids) > self.max_keyword_len:
+ self.max_keyword_len = len(cur_keyword_ids)
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
+ self.tokenizer = tokenizer
+ self.start_len = input_ids.shape[1]
+
+ def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
+ for keyword_id in self.keyword_ids:
+ truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
+ if torch.equal(truncated_output_ids, keyword_id):
+ return True
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
+ for keyword in self.keywords:
+ if keyword in outputs:
+ return True
+ return False
+
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ outputs = []
+ for i in range(output_ids.shape[0]):
+ outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
+ return all(outputs)
diff --git a/llava/model/__init__.py b/llava/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c74f6cc2e9701f341cf484d5ca0a5d4d09d904d3
--- /dev/null
+++ b/llava/model/__init__.py
@@ -0,0 +1,9 @@
+from .language_model.llava_phi3 import LlavaPhiForCausalLM, LlavaPhiConfig
+
+# try:
+# from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
+# from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
+# from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
+# from .language_model.llava_phi3 import LlavaPhiForCausalLM, LlavaPhiConfig
+# except:
+# pass
diff --git a/llava/model/__pycache__/__init__.cpython-310.pyc b/llava/model/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05a61cba296ffeb7404428fd32901ea2e68aaa3f
--- /dev/null
+++ b/llava/model/__pycache__/__init__.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ccc74fb7d5a786959b1684d921e01038433f8105ebc8f21183edc310ec1c55f4
+size 242
diff --git a/llava/model/__pycache__/builder.cpython-310.pyc b/llava/model/__pycache__/builder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8c3f2139f3dc1c3afbf20f813afa79e547ef885
--- /dev/null
+++ b/llava/model/__pycache__/builder.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3fe8a76b11766408258e0f60102ff2a0522fcc240b8ccbff6dda006648eaff1
+size 4786
diff --git a/llava/model/__pycache__/llava_arch.cpython-310.pyc b/llava/model/__pycache__/llava_arch.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6b4caab0093f5ef78c34d13c1ac81168d5df0e5c
--- /dev/null
+++ b/llava/model/__pycache__/llava_arch.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b49e8d013dad067cab3f8960ad2149e28ac76ca92f65af1886cb39c88f0883d
+size 10901
diff --git a/llava/model/__pycache__/utils.cpython-310.pyc b/llava/model/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc8f86a61d13d74a95332de15a2986890c69d798
--- /dev/null
+++ b/llava/model/__pycache__/utils.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:748c81e4d25f789b72bf94c34c31b50d2e3340078359f98bfd7457dd48f4b79a
+size 10189
diff --git a/llava/model/apply_delta.py b/llava/model/apply_delta.py
new file mode 100644
index 0000000000000000000000000000000000000000..666dd9691bde7d54ddf2871e311d6f621e29f099
--- /dev/null
+++ b/llava/model/apply_delta.py
@@ -0,0 +1,48 @@
+"""
+Usage:
+python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
+"""
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llava import LlavaLlamaForCausalLM
+
+
+def apply_delta(base_model_path, target_model_path, delta_path):
+ print("Loading base model")
+ base = AutoModelForCausalLM.from_pretrained(
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Loading delta")
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
+
+ print("Applying delta")
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
+ if name not in base.state_dict():
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
+ continue
+ if param.data.shape == base.state_dict()[name].shape:
+ param.data += base.state_dict()[name]
+ else:
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
+ f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
+ bparam = base.state_dict()[name]
+ param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
+
+ print("Saving target model")
+ delta.save_pretrained(target_model_path)
+ delta_tokenizer.save_pretrained(target_model_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
diff --git a/llava/model/bpe_simple_vocab_16e6.txt.gz b/llava/model/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/llava/model/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/llava/model/builder.py b/llava/model/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..75fe7b535a9a5a06cc78e6fcafab29e5ab885baa
--- /dev/null
+++ b/llava/model/builder.py
@@ -0,0 +1,169 @@
+# Modified from LLaVA: https://github.com/haotian-liu/LLaVA.git
+#
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import warnings
+import shutil
+
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
+import torch
+from llava.model import *
+from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+
+
+def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
+ kwargs = {"device_map": device_map, **kwargs}
+
+ if device != "cuda":
+ kwargs['device_map'] = {"": device}
+
+ if load_8bit:
+ kwargs['load_in_8bit'] = True
+ elif load_4bit:
+ kwargs['load_in_4bit'] = True
+ kwargs['quantization_config'] = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_compute_dtype=torch.float16,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type='nf4'
+ )
+ else:
+ kwargs['torch_dtype'] = torch.float16
+
+ if use_flash_attn:
+ kwargs['attn_implementation'] = 'flash_attention_2'
+
+ if 'llava' in model_name.lower():
+ # Load LLaVA model
+ if 'lora' in model_name.lower() and model_base is None:
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
+ if 'lora' in model_name.lower() and model_base is not None:
+ from llava.model.language_model.llava_phi3 import LlavaPhiConfig
+ lora_cfg_pretrained = LlavaPhiConfig.from_pretrained(model_path)
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ print('Loading LLaVA from base model...')
+ model = LlavaPhiForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
+ if model.lm_head.weight.shape[0] != token_num:
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
+
+ print('Loading additional LLaVA weights...')
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
+ else:
+ # this is probably from HF Hub
+ from huggingface_hub import hf_hub_download
+ def load_from_hf(repo_id, filename, subfolder=None):
+ cache_file = hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ subfolder=subfolder)
+ return torch.load(cache_file, map_location='cpu')
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
+ model.load_state_dict(non_lora_trainables, strict=False)
+
+ from peft import PeftModel
+ print('Loading LoRA weights...')
+ model = PeftModel.from_pretrained(model, model_path)
+ print('Merging LoRA weights...')
+ model = model.merge_and_unload()
+ print('Model is loaded...')
+ elif model_base is not None:
+ # this may be mm projector only
+ print('Loading LLaVA from base model...')
+ if 'mpt' in model_name.lower():
+ if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
+ shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
+ cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+ model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
+ model = LlavaPhiForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
+ model.load_state_dict(mm_projector_weights, strict=False)
+ else:
+ if 'mpt' in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
+ model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+ elif 'mistral' in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model = LlavaMistralForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ **kwargs
+ )
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ model = LlavaPhiForCausalLM.from_pretrained(
+ model_path,
+ low_cpu_mem_usage=True,
+ **kwargs
+ )
+ else:
+ # Load language model
+ if model_base is not None:
+ # PEFT model
+ from peft import PeftModel
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
+ print(f"Loading LoRA weights from {model_path}")
+ model = PeftModel.from_pretrained(model, model_path)
+ print(f"Merging weights")
+ model = model.merge_and_unload()
+ print('Convert to FP16...')
+ model.to(torch.float16)
+ else:
+ use_fast = False
+ if 'mpt' in model_name.lower():
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
+
+ image_processor = None
+
+ if 'llava' in model_name.lower():
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
+ if mm_use_im_patch_token:
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ if mm_use_im_start_end:
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ model.resize_token_embeddings(len(tokenizer))
+
+ # vision_tower = model.get_vision_tower()
+ # if not vision_tower.is_loaded:
+ # vision_tower.load_model(device_map=device_map)
+ # if device_map != 'auto':
+ # vision_tower.to(device=device_map, dtype=torch.float16)
+ # image_processor = vision_tower.image_processor
+
+ if hasattr(model.config, "max_sequence_length"):
+ context_len = model.config.max_sequence_length
+ else:
+ context_len = 2048
+
+ return tokenizer, model, context_len
diff --git a/llava/model/consolidate.py b/llava/model/consolidate.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e324210e229eeba23b75791bba82df7c6e639eb
--- /dev/null
+++ b/llava/model/consolidate.py
@@ -0,0 +1,29 @@
+"""
+Usage:
+python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
+"""
+import argparse
+
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llava.model import *
+from llava.model.utils import auto_upgrade
+
+
+def consolidate_ckpt(src_path, dst_path):
+ print("Loading model")
+ auto_upgrade(src_path)
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
+ src_model.save_pretrained(dst_path)
+ src_tokenizer.save_pretrained(dst_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--src", type=str, required=True)
+ parser.add_argument("--dst", type=str, required=True)
+
+ args = parser.parse_args()
+
+ consolidate_ckpt(args.src, args.dst)
diff --git a/llava/model/language_model/__pycache__/llava_phi3.cpython-310.pyc b/llava/model/language_model/__pycache__/llava_phi3.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fee5d7318219bf5405bbfed65c1d7152f0a535ad
--- /dev/null
+++ b/llava/model/language_model/__pycache__/llava_phi3.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b53993e450a9c71cb50d527bf2b9dc3d79865fedcfc94d261d84d8057bae7aa
+size 3815
diff --git a/llava/model/language_model/llava_llama.py b/llava/model/language_model/llava_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..069d0d1c10da42f5d278598e8534f166d1f9f5ff
--- /dev/null
+++ b/llava/model/language_model/llava_llama.py
@@ -0,0 +1,158 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from transformers import AutoConfig, AutoModelForCausalLM, \
+ LlamaConfig, LlamaModel, LlamaForCausalLM
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerateOutput
+
+from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaConfig(LlamaConfig):
+ model_type = "llava_llama"
+
+
+class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
+ config_class = LlavaConfig
+
+ def __init__(self, config: LlamaConfig):
+ super(LlavaLlamaModel, self).__init__(config)
+
+
+class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaConfig
+
+ def __init__(self, config):
+ super(LlamaForCausalLM, self).__init__(config)
+ self.model = LlavaLlamaModel(config)
+ self.pretraining_tp = config.pretraining_tp
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[List[List[int]]] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+ if inputs_embeds is None:
+ (
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ inputs_embeds,
+ labels
+ ) = self.prepare_inputs_labels_for_multimodal(
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ labels,
+ images,
+ image_sizes
+ )
+
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ images: Optional[torch.Tensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ position_ids = kwargs.pop("position_ids", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+ if "inputs_embeds" in kwargs:
+ raise NotImplementedError("`inputs_embeds` is not supported")
+
+ if images is not None:
+ (
+ inputs,
+ position_ids,
+ attention_mask,
+ _,
+ inputs_embeds,
+ _
+ ) = self.prepare_inputs_labels_for_multimodal(
+ inputs,
+ position_ids,
+ attention_mask,
+ None,
+ None,
+ images,
+ image_sizes=image_sizes
+ )
+ else:
+ inputs_embeds = self.get_model().embed_tokens(inputs)
+
+ return super().generate(
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ **kwargs
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
+ inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ image_sizes = kwargs.pop("image_sizes", None)
+ inputs = super().prepare_inputs_for_generation(
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+ )
+ if images is not None:
+ inputs['images'] = images
+ if image_sizes is not None:
+ inputs['image_sizes'] = image_sizes
+ return inputs
+
+AutoConfig.register("llava_llama", LlavaConfig)
+AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
diff --git a/llava/model/language_model/llava_mistral.py b/llava/model/language_model/llava_mistral.py
new file mode 100644
index 0000000000000000000000000000000000000000..0def682ea3c497e36aa85f1c53eb2cfab6e2fb87
--- /dev/null
+++ b/llava/model/language_model/llava_mistral.py
@@ -0,0 +1,158 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+
+from transformers import AutoConfig, AutoModelForCausalLM, \
+ MistralConfig, MistralModel, MistralForCausalLM
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerateOutput
+
+from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaMistralConfig(MistralConfig):
+ model_type = "llava_mistral"
+
+
+class LlavaMistralModel(LlavaMetaModel, MistralModel):
+ config_class = LlavaMistralConfig
+
+ def __init__(self, config: MistralConfig):
+ super(LlavaMistralModel, self).__init__(config)
+
+
+class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaMistralConfig
+
+ def __init__(self, config):
+ super(MistralForCausalLM, self).__init__(config)
+ self.model = LlavaMistralModel(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[List[List[int]]] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+ if inputs_embeds is None:
+ (
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ inputs_embeds,
+ labels
+ ) = self.prepare_inputs_labels_for_multimodal(
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ labels,
+ images,
+ image_sizes
+ )
+
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict
+ )
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ images: Optional[torch.Tensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ position_ids = kwargs.pop("position_ids", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+ if "inputs_embeds" in kwargs:
+ raise NotImplementedError("`inputs_embeds` is not supported")
+
+ if images is not None:
+ (
+ inputs,
+ position_ids,
+ attention_mask,
+ _,
+ inputs_embeds,
+ _
+ ) = self.prepare_inputs_labels_for_multimodal(
+ inputs,
+ position_ids,
+ attention_mask,
+ None,
+ None,
+ images,
+ image_sizes=image_sizes
+ )
+ else:
+ inputs_embeds = self.get_model().embed_tokens(inputs)
+
+ return super().generate(
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ **kwargs
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
+ inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ image_sizes = kwargs.pop("image_sizes", None)
+ inputs = super().prepare_inputs_for_generation(
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+ )
+ if images is not None:
+ inputs['images'] = images
+ if image_sizes is not None:
+ inputs['image_sizes'] = image_sizes
+ return inputs
+
+AutoConfig.register("llava_mistral", LlavaMistralConfig)
+AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
diff --git a/llava/model/language_model/llava_mpt.py b/llava/model/language_model/llava_mpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..02e5237ece031af23fcd76b5b4e0d9b0bc5f55cc
--- /dev/null
+++ b/llava/model/language_model/llava_mpt.py
@@ -0,0 +1,97 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Optional, Tuple
+
+import torch
+
+from transformers import AutoConfig, AutoModelForCausalLM, \
+ MptConfig, MptForCausalLM, MptModel
+from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaMptConfig(MptConfig):
+ model_type = "llava_mpt"
+
+
+class LlavaMptModel(LlavaMetaModel, MptModel):
+ config_class = LlavaMptConfig
+
+ def __init__(self, config: MptConfig):
+ config.hidden_size = config.d_model
+ super(LlavaMptModel, self).__init__(config)
+
+ def embed_tokens(self, x):
+ return self.wte(x)
+
+
+class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaMptConfig
+ supports_gradient_checkpointing = True
+
+ def __init__(self, config):
+ super(MptForCausalLM, self).__init__(config)
+
+ self.transformer = LlavaMptModel(config)
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.transformer
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, LlavaMptModel):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ images=None):
+
+ input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
+
+ return super().forward(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ _inputs = super().prepare_inputs_for_generation(
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+ )
+ _inputs['images'] = images
+ return _inputs
+
+
+AutoConfig.register("llava_mpt", LlavaMptConfig)
+AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
diff --git a/llava/model/language_model/llava_phi3.py b/llava/model/language_model/llava_phi3.py
new file mode 100644
index 0000000000000000000000000000000000000000..93377914ea715efe62791195a367b0bbb42a6942
--- /dev/null
+++ b/llava/model/language_model/llava_phi3.py
@@ -0,0 +1,165 @@
+# Modified from LLaVA: https://github.com/haotian-liu/LLaVA.git
+#
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from transformers import AutoConfig, AutoModelForCausalLM, \
+ Phi3Model, Phi3Config, Phi3ForCausalLM
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.generation.utils import GenerateOutput
+
+from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
+
+
+class LlavaPhiConfig(Phi3Config):
+ model_type = "llava_phi"
+
+
+class LlavaPhiModel(LlavaMetaModel, Phi3Model):
+ config_class = LlavaPhiConfig
+
+ def __init__(self, config: Phi3Config):
+ super(LlavaPhiModel, self).__init__(config)
+
+
+class LlavaPhiForCausalLM(Phi3ForCausalLM, LlavaMetaForCausalLM):
+ config_class = LlavaPhiConfig
+
+ def __init__(self, config):
+ super(Phi3ForCausalLM, self).__init__(config)
+ self.model = LlavaPhiModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ images: Optional[torch.FloatTensor] = None,
+ image_sizes: Optional[List[List[int]]] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+
+ if inputs_embeds is None:
+ (
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ inputs_embeds,
+ labels,
+
+ ) = self.prepare_inputs_labels_for_multimodal(
+ input_ids,
+ position_ids,
+ attention_mask,
+ past_key_values,
+ labels,
+ images,
+ image_sizes
+ )
+ with self.maybe_autocast():
+ return super().forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict
+ )
+
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ images: Optional[torch.Tensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ position_ids = kwargs.pop("position_ids", None)
+ attention_mask = kwargs.pop("attention_mask", None)
+ if "inputs_embeds" in kwargs:
+ raise NotImplementedError("`inputs_embeds` is not supported")
+
+
+
+ if images is not None:
+ (
+ inputs,
+ position_ids,
+ attention_mask,
+ _,
+ inputs_embeds,
+ _,
+ ) = self.prepare_inputs_labels_for_multimodal(
+ inputs,
+ position_ids,
+ attention_mask,
+ None,
+ None,
+ images,
+ )
+ else:
+ inputs_embeds = self.get_model().embed_tokens(inputs)
+
+ return super().generate(
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ **kwargs
+ )
+
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
+ inputs_embeds=None, **kwargs):
+ images = kwargs.pop("images", None)
+ image_sizes = kwargs.pop("image_sizes", None)
+ inputs = super().prepare_inputs_for_generation(
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
+ )
+ if images is not None:
+ inputs['images'] = images
+ if image_sizes is not None:
+ inputs['image_sizes'] = image_sizes
+ return inputs
+
+
+AutoConfig.register("llava_phi", LlavaPhiConfig)
+AutoModelForCausalLM.register(LlavaPhiConfig, LlavaPhiForCausalLM)
diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba0cb78e5c6453e8954f08fc37b48ac467731bb0
--- /dev/null
+++ b/llava/model/llava_arch.py
@@ -0,0 +1,395 @@
+# Copyright 2023 Haotian Liu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch.nn.init as init
+from abc import ABC, abstractmethod
+import torch
+from .multimodal_encoder.builder import build_text_encoder, build_pc_encoder
+from .multimodal_projector.builder import build_vision_projector, Mlp
+from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, \
+ DEFAULT_IM_END_TOKEN
+
+import contextlib
+
+
+
+class LlavaMetaModel:
+
+ def __init__(self, config):
+ super(LlavaMetaModel, self).__init__(config)
+
+ self.pc_encoder_float_flag =False
+
+
+ def get_vision_tower(self):
+ vision_tower = getattr(self, 'vision_tower', None)
+ if type(vision_tower) is list:
+ vision_tower = vision_tower[0]
+ return vision_tower
+
+ # ไป่ฟๅๅงๅ็encoder ๅ projector
+ def initialize_other_modules(self, model_args, fsdp=None):
+ # ๆต่ฏไธmm_vision_tower
+ # vision_tower = model_args.vision_tower
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
+
+ self.model_args = model_args
+ # self.config.mm_vision_tower = vision_tower
+
+ # ๅ็จไธvision_tower๏ผ
+ # text encoder ๅpc encoder ้ฝ็จ่ฟไธชๅฝๆฐๅ
+ if model_args.encoder_type == 'text_encoder':
+ print("loading text encoder pre-treained weight")
+ self.vision_tower = build_text_encoder()
+
+ elif model_args.encoder_type == 'pc_encoder':
+ self.vision_tower = build_pc_encoder(model_args)
+
+ checkpoint = torch.load(model_args.pc_ckpt_path, map_location='cpu')
+ sd = checkpoint['module']
+ if next(iter(sd.items()))[0].startswith('module'):
+ sd = {k[len('module'):]: v for k, v in sd.items()}
+
+ base_ckpt = sd
+ for k in list(base_ckpt.keys()):
+ if k.startswith('point_encoder'):
+ base_ckpt[k[len('point_encoder.'):]] = base_ckpt[k]
+ del base_ckpt[k]
+
+ self.vision_tower.load_state_dict(base_ckpt, strict=False)
+ else:
+ raise NotImplementedError
+
+ # ๅๅงๅ ๆฅๅฐLLM็MLP
+ self.mm_projector = build_vision_projector(self.config, self.model_args)
+
+ if pretrain_mm_mlp_adapter is not None:
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
+
+ def get_w(weights, keyword):
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
+
+ def get_w_4_query_token(weights, keyword):
+ for k, v in weights.items():
+ if keyword in k:
+ return v
+
+ aa = self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=False)
+ print("projctor")
+ print(aa)
+
+ aa = self.vision_tower.load_state_dict(get_w(mm_projector_weights, 'vision_tower'), strict=False)
+ print("vision_tower.unexpected_keys")
+ print(aa.unexpected_keys)
+
+ else:
+ self.random_initialize_model(self.mm_projector)
+
+
+ def random_initialize_model(self, model, mean=0.0, std=0.02):
+ """
+ ้ๆบๅๅงๅ็ปๅฎๆจกๅ็ๆๆๅๆฐใ
+
+ ๅๆฐ:
+ - model (nn.Module): ่ฆๅๅงๅ็PyTorchๆจกๅๅฎไพใ
+ - mean (float): ๆ้ๅๅงๅ็ๅๅผ๏ผ้ป่ฎคไธบ0.0ใ
+ - std (float): ๆ้ๅๅงๅ็ๆ ๅๅทฎ๏ผ้ป่ฎคไธบ0.02ใ
+ """
+ for name, param in model.named_parameters():
+ if 'weight' in name:
+ # ๅฏนไบๆ้๏ผไฝฟ็จๆญฃๆๅๅธ่ฟ่กๅๅงๅ
+ init.normal_(param.data, mean=mean, std=std)
+ elif 'bias' in name:
+ # ๅฏนไบๅ็ฝฎ้กน๏ผ้ป่ฎคๅๅงๅไธบ0๏ผไนๅฏไปฅๆ นๆฎ้่ฆ่ฐๆด
+ init.constant_(param.data, val=0)
+
+
+def unpad_image(tensor, original_size):
+ """
+ Unpads a PyTorch tensor of a padded and resized image.
+
+ Args:
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
+ original_size (tuple): The original size of PIL image (width, height).
+
+ Returns:
+ torch.Tensor: The unpadded image tensor.
+ """
+ original_width, original_height = original_size
+ current_height, current_width = tensor.shape[1:]
+
+ original_aspect_ratio = original_width / original_height
+ current_aspect_ratio = current_width / current_height
+
+ if original_aspect_ratio > current_aspect_ratio:
+ scale_factor = current_width / original_width
+ new_height = int(original_height * scale_factor)
+ padding = (current_height - new_height) // 2
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
+ else:
+ scale_factor = current_height / original_height
+ new_width = int(original_width * scale_factor)
+ padding = (current_width - new_width) // 2
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
+
+ return unpadded_tensor
+
+
+class LlavaMetaForCausalLM(ABC):
+
+ def maybe_autocast(self, dtype=torch.float16):
+ # if on cpu, don't use autocast
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
+ enable_autocast = self.device != torch.device("cpu")
+ self.csc_loss = torch.nn.SmoothL1Loss()
+
+ if enable_autocast:
+ return torch.cuda.amp.autocast(dtype=dtype)
+ else:
+ return contextlib.nullcontext()
+
+ @abstractmethod
+ def get_model(self):
+ pass
+
+ def get_vision_tower(self):
+ return self.get_model().get_vision_tower()
+
+ # ่ฟ้็image ๅฏ่ฝๆฏtext ไนๅฏ่ฝๆฏpc
+ def encode_datas(self, images):
+
+ with self.maybe_autocast():
+
+
+ if self.get_model().model_args.encoder_type == "text_encoder":
+
+ class_embeddings = self.get_model().vision_tower.encode_text(images)
+
+ class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
+
+ std = self.get_model().model_args.std
+
+ class_embeddings = class_embeddings + std * torch.randn_like(class_embeddings)
+
+ image_features = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
+ image_features = image_features.unsqueeze(1)
+
+
+
+ elif self.get_model().model_args.encoder_type == "pc_encoder":
+
+ if self.get_model().get_vision_tower().cls_token.dtype!=torch.float:
+ self.get_model().get_vision_tower().to(dtype=torch.float)
+
+ image_features = self.get_model().get_vision_tower()(images, self.get_model().model_args.get_pc_tokens_way)
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+
+ std = self.get_model().model_args.std
+ image_features = image_features + std * torch.randn_like(image_features)
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+
+ else:
+ raise NotImplementedError
+
+ image_features = self.get_model().mm_projector(image_features)
+
+ return image_features
+
+
+ def prepare_inputs_labels_for_multimodal(
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
+ images, image_sizes=None
+ ):
+
+ if images is None or input_ids.shape[1] == 1:
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
+
+ image_features= self.encode_datas(images)
+
+ # TODO: image start / end is not implemented here to support pretraining.
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
+ raise NotImplementedError
+
+ # Let's just add dummy tensors if they do not exist,
+ # it is a headache to deal with None all the time.
+ # But it is not ideal, and if you have a better idea,
+ # please open an issue / submit a PR, thanks.
+ _labels = labels
+ _position_ids = position_ids
+ _attention_mask = attention_mask
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
+ else:
+ attention_mask = attention_mask.bool()
+ if position_ids is None:
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
+ if labels is None:
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
+
+ # remove the padding using attention_mask -- FIXME
+ _input_ids = input_ids
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
+ zip(input_ids, attention_mask)]
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
+
+ new_input_embeds = []
+ new_labels = []
+ cur_image_idx = 0
+ for batch_idx, cur_input_ids in enumerate(input_ids):
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
+ if num_images == 0:
+ cur_image_features = image_features[cur_image_idx]
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
+ new_input_embeds.append(cur_input_embeds)
+ new_labels.append(labels[batch_idx])
+ cur_image_idx += 1
+ continue
+
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
+ cur_input_ids.shape[0]]
+ cur_input_ids_noim = []
+ cur_labels = labels[batch_idx]
+ cur_labels_noim = []
+ for i in range(len(image_token_indices) - 1):
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
+ cur_new_input_embeds = []
+ cur_new_labels = []
+
+ for i in range(num_images + 1):
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
+ cur_new_labels.append(cur_labels_noim[i])
+ if i < num_images:
+ cur_image_features = image_features[cur_image_idx]
+ cur_image_idx += 1
+ cur_new_input_embeds.append(cur_image_features)
+ cur_new_labels.append(
+ torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device,
+ dtype=cur_labels.dtype))
+
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
+
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
+ cur_new_labels = torch.cat(cur_new_labels)
+
+ new_input_embeds.append(cur_new_input_embeds)
+ new_labels.append(cur_new_labels)
+
+ # Truncate sequences to max length as image embeddings can make the sequence longer
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
+ if tokenizer_model_max_length is not None:
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
+
+ # Combine them
+ max_len = max(x.shape[0] for x in new_input_embeds)
+ batch_size = len(new_input_embeds)
+
+ new_input_embeds_padded = []
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype,
+ device=new_labels[0].device)
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
+
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
+ cur_len = cur_new_embed.shape[0]
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
+ new_input_embeds_padded.append(torch.cat((
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
+ device=cur_new_embed.device),
+ cur_new_embed
+ ), dim=0))
+ if cur_len > 0:
+ new_labels_padded[i, -cur_len:] = cur_new_labels
+ attention_mask[i, -cur_len:] = True
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype,
+ device=position_ids.device)
+ else:
+ new_input_embeds_padded.append(torch.cat((
+ cur_new_embed,
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
+ device=cur_new_embed.device)
+ ), dim=0))
+ if cur_len > 0:
+ new_labels_padded[i, :cur_len] = cur_new_labels
+ attention_mask[i, :cur_len] = True
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype,
+ device=position_ids.device)
+
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
+
+ if _labels is None:
+ new_labels = None
+ else:
+ new_labels = new_labels_padded
+
+ if _attention_mask is None:
+ attention_mask = None
+ else:
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
+
+ if _position_ids is None:
+ position_ids = None
+
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
+
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
+ if model_args.mm_use_im_patch_token:
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+
+ if model_args.mm_use_im_start_end:
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = self.get_input_embeddings().weight.data
+ output_embeddings = self.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+ if model_args.tune_mm_mlp_adapter:
+ for p in self.get_input_embeddings().parameters():
+ p.requires_grad = True
+ for p in self.get_output_embeddings().parameters():
+ p.requires_grad = False
+
+ if model_args.pretrain_mm_mlp_adapter:
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
+ assert num_new_tokens == 2
+ if input_embeddings.shape == embed_tokens_weight.shape:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
+ else:
+ raise ValueError(
+ f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
+ elif model_args.mm_use_im_patch_token:
+ if model_args.tune_mm_mlp_adapter:
+ for p in self.get_input_embeddings().parameters():
+ p.requires_grad = False
+ for p in self.get_output_embeddings().parameters():
+ p.requires_grad = False
diff --git a/llava/model/make_delta.py b/llava/model/make_delta.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ae55d59c2c8bab80299272314a41bbeb959d8ed
--- /dev/null
+++ b/llava/model/make_delta.py
@@ -0,0 +1,52 @@
+"""
+Usage:
+python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
+"""
+import argparse
+
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from llava.model.utils import auto_upgrade
+
+
+def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
+ print("Loading base model")
+ base = AutoModelForCausalLM.from_pretrained(
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Loading target model")
+ auto_upgrade(target_model_path)
+ target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
+ print("Calculating delta")
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
+ if name not in base.state_dict():
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
+ continue
+ if param.data.shape == base.state_dict()[name].shape:
+ param.data -= base.state_dict()[name]
+ else:
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
+ bparam = base.state_dict()[name]
+ param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
+
+ print("Saving delta")
+ if hub_repo_id:
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
+ else:
+ kwargs = {}
+ target.save_pretrained(delta_path, **kwargs)
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--base-model-path", type=str, required=True)
+ parser.add_argument("--target-model-path", type=str, required=True)
+ parser.add_argument("--delta-path", type=str, required=True)
+ parser.add_argument("--hub-repo-id", type=str, default=None)
+ args = parser.parse_args()
+
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
diff --git a/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc b/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..feee26fca5e9dbd24927f5dfef3b22ed896dd686
--- /dev/null
+++ b/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:178c12b3393e0a12df711219c7608039d527487e5c2a7c4746cb188cb4379a03
+size 1228
diff --git a/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc b/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6853dd0d088b27f41ed2990b038d5cf5355524b
--- /dev/null
+++ b/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87b021a0f58f3ec2d95e7eb66e0a445cf6f1a443d6b90a3ce860a888ece60418
+size 5323
diff --git a/llava/model/multimodal_encoder/__pycache__/point_encoder.cpython-310.pyc b/llava/model/multimodal_encoder/__pycache__/point_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e151aaec81f3c3a072cb3a24de13973f490a08c0
--- /dev/null
+++ b/llava/model/multimodal_encoder/__pycache__/point_encoder.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:08ba5cdf252e09f45486df0e167d170efb7f9715f84644a69334a983bc060e81
+size 9096
diff --git a/llava/model/multimodal_encoder/builder.py b/llava/model/multimodal_encoder/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dedda9986c468204e025df522a98e2e673c86a2
--- /dev/null
+++ b/llava/model/multimodal_encoder/builder.py
@@ -0,0 +1,45 @@
+import open_clip
+from .point_encoder import PointcloudEncoder
+import timm
+
+
+def build_text_encoder():
+ clip_model_type = "EVA02-E-14-plus"
+ pretrained = "./pretrained_weight/clip_used_in_Uni3D/open_clip_pytorch_model.bin"
+ clip_model, _, _ = open_clip.create_model_and_transforms(model_name=clip_model_type, pretrained=pretrained)
+
+ return clip_model
+
+
+
+def build_pc_encoder(args):
+
+ pretrained_pc = ''
+ drop_path_rate = 0.0
+
+ pc_encoder_type = getattr(args, 'pc_encoder_type', 'small')
+
+ if pc_encoder_type == "giant":
+ pc_model = "eva_giant_patch14_560"
+ args.pc_feat_dim = 1408
+ elif pc_encoder_type == "large":
+ pc_model = "eva02_large_patch14_448"
+ args.pc_feat_dim = 1024
+ elif pc_encoder_type == "base":
+ pc_model = "eva02_base_patch14_448"
+ args.pc_feat_dim = 768
+ elif pc_encoder_type == "small":
+ pc_model = "eva02_small_patch14_224"
+ args.pc_feat_dim = 384
+ elif pc_encoder_type == "tiny":
+ pc_model = "eva02_tiny_patch14_224"
+ args.pc_feat_dim = 192
+
+
+ # create transformer blocks for point cloud via timm
+ point_transformer = timm.create_model(pc_model, checkpoint_path=pretrained_pc, drop_path_rate= drop_path_rate)
+
+ # create whole point cloud encoder
+ point_encoder = PointcloudEncoder(point_transformer, args)
+
+ return point_encoder
\ No newline at end of file
diff --git a/llava/model/multimodal_encoder/clip_encoder.py b/llava/model/multimodal_encoder/clip_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c81415cd0f4ebbbe66385450236c427f5e8fb02
--- /dev/null
+++ b/llava/model/multimodal_encoder/clip_encoder.py
@@ -0,0 +1,147 @@
+import torch
+import torch.nn as nn
+
+from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
+
+
+class CLIPVisionTower(nn.Module):
+ def __init__(self, vision_tower, args, delay_load=False):
+ super().__init__()
+
+ self.is_loaded = False
+
+ self.vision_tower_name = vision_tower
+ self.select_layer = args.mm_vision_select_layer
+ self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
+
+ if not delay_load:
+ self.load_model()
+ elif getattr(args, 'unfreeze_mm_vision_tower', False):
+ self.load_model()
+ else:
+ self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
+
+ def load_model(self, device_map=None):
+ if self.is_loaded:
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
+ return
+
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
+ self.vision_tower.requires_grad_(False)
+
+ self.is_loaded = True
+
+ def feature_select(self, image_forward_outs):
+ image_features = image_forward_outs.hidden_states[self.select_layer]
+ if self.select_feature == 'patch':
+ image_features = image_features[:, 1:]
+ elif self.select_feature == 'cls_patch':
+ image_features = image_features
+ else:
+ raise ValueError(f'Unexpected select feature: {self.select_feature}')
+ return image_features
+
+ @torch.no_grad()
+ def forward(self, images):
+ if type(images) is list:
+ image_features = []
+ for image in images:
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
+ image_features.append(image_feature)
+ else:
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
+
+ return image_features
+
+ @property
+ def dummy_feature(self):
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
+
+ @property
+ def dtype(self):
+ return self.vision_tower.dtype
+
+ @property
+ def device(self):
+ return self.vision_tower.device
+
+ @property
+ def config(self):
+ if self.is_loaded:
+ return self.vision_tower.config
+ else:
+ return self.cfg_only
+
+ @property
+ def hidden_size(self):
+ return self.config.hidden_size
+
+ @property
+ def num_patches_per_side(self):
+ return self.config.image_size // self.config.patch_size
+
+ @property
+ def num_patches(self):
+ return (self.config.image_size // self.config.patch_size) ** 2
+
+
+
+class CLIPVisionTowerS2(CLIPVisionTower):
+ def __init__(self, vision_tower, args, delay_load=False):
+ super().__init__(vision_tower, args, delay_load)
+
+ self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
+ self.s2_scales = list(map(int, self.s2_scales.split(',')))
+ self.s2_scales.sort()
+ self.s2_split_size = self.s2_scales[0]
+ self.s2_image_size = self.s2_scales[-1]
+
+ try:
+ from s2wrapper import forward as multiscale_forward
+ except ImportError:
+ raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
+ self.multiscale_forward = multiscale_forward
+
+ # change resize/crop size in preprocessing to the largest image size in s2_scale
+ if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
+ self.image_processor.size['shortest_edge'] = self.s2_image_size
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
+
+ def load_model(self, device_map=None):
+ if self.is_loaded:
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
+ return
+
+ self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
+ self.vision_tower.requires_grad_(False)
+
+ self.image_processor.size['shortest_edge'] = self.s2_image_size
+ self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
+
+ self.is_loaded = True
+
+ @torch.no_grad()
+ def forward_feature(self, images):
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
+ return image_features
+
+ @torch.no_grad()
+ def forward(self, images):
+ if type(images) is list:
+ image_features = []
+ for image in images:
+ image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
+ image_features.append(image_feature)
+ else:
+ image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
+
+ return image_features
+
+ @property
+ def hidden_size(self):
+ return self.config.hidden_size * len(self.s2_scales)
diff --git a/llava/model/multimodal_encoder/point_encoder.py b/llava/model/multimodal_encoder/point_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b951ef0275b2f568f64a3f3d2d0c88cd526a60a1
--- /dev/null
+++ b/llava/model/multimodal_encoder/point_encoder.py
@@ -0,0 +1,346 @@
+import torch
+import torch.nn as nn
+from pointnet2_ops import pointnet2_utils
+import contextlib
+import logging
+import torch.nn.functional as F
+
+def fps(data, number):
+ '''
+ data B N 3
+ number int
+ '''
+ fps_idx = pointnet2_utils.furthest_point_sample(data, number)
+ fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
+ return fps_data
+
+
+def index_points(points, idx):
+ """
+ Input:
+ points: input points data, [B, N, C]
+ idx: sample index data, [B, S]
+ Return:
+ new_points:, indexed points data, [B, S, C]
+ """
+ device = points.device
+ B = points.shape[0]
+ view_shape = list(idx.shape)
+ view_shape[1:] = [1] * (len(view_shape) - 1)
+ repeat_shape = list(idx.shape)
+ repeat_shape[0] = 1
+ batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
+ new_points = points[batch_indices, idx, :]
+ return new_points
+
+
+# https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py
+def knn_point(nsample, xyz, new_xyz):
+ """
+ Input:
+ nsample: max sample number in local region
+ xyz: all points, [B, N, C]
+ new_xyz: query points, [B, S, C]
+ Return:
+ group_idx: grouped points index, [B, S, nsample]
+ """
+ sqrdists = square_distance(new_xyz, xyz)
+ _, group_idx = torch.topk(sqrdists, nsample, dim = -1, largest=False, sorted=False)
+ return group_idx
+
+def square_distance(src, dst):
+ """
+ Calculate Euclid distance between each two points.
+ src^T * dst = xn * xm + yn * ym + zn * zm;
+ sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
+ sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
+ dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
+ = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
+ Input:
+ src: source points, [B, N, C]
+ dst: target points, [B, M, C]
+ Output:
+ dist: per-point square distance, [B, N, M]
+ """
+ B, N, _ = src.shape
+ _, M, _ = dst.shape
+ dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
+ dist += torch.sum(src ** 2, -1).view(B, N, 1)
+ dist += torch.sum(dst ** 2, -1).view(B, 1, M)
+ return dist
+
+
+class PatchDropout(nn.Module):
+ """
+ https://arxiv.org/abs/2212.00794
+ """
+
+ def __init__(self, prob, exclude_first_token=True):
+ super().__init__()
+ assert 0 <= prob < 1.
+ self.prob = prob
+ self.exclude_first_token = exclude_first_token # exclude CLS token
+ logging.info("patch dropout prob is {}".format(prob))
+
+ def forward(self, x):
+ # if not self.training or self.prob == 0.:
+ # return x
+
+ if self.exclude_first_token:
+ cls_tokens, x = x[:, :1], x[:, 1:]
+ else:
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
+
+ batch = x.size()[0]
+ num_tokens = x.size()[1]
+
+ batch_indices = torch.arange(batch)
+ batch_indices = batch_indices[..., None]
+
+ keep_prob = 1 - self.prob
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
+
+ rand = torch.randn(batch, num_tokens)
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
+
+ x = x[batch_indices, patch_indices_keep]
+
+ if self.exclude_first_token:
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ return x
+
+
+class Group(nn.Module):
+ def __init__(self, num_group, group_size):
+ super().__init__()
+ self.num_group = num_group
+ self.group_size = group_size
+
+ def forward(self, xyz, color):
+ '''
+ input: B N 3
+ ---------------------------
+ output: B G M 3
+ center : B G 3
+ '''
+ batch_size, num_points, _ = xyz.shape
+ # fps the centers out
+ center = fps(xyz, self.num_group) # B G 3
+ # knn to get the neighborhood
+ # _, idx = self.knn(xyz, center) # B G M
+ idx = knn_point(self.group_size, xyz, center) # B G M
+ assert idx.size(1) == self.num_group
+ assert idx.size(2) == self.group_size
+ idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
+ idx = idx + idx_base
+ idx = idx.view(-1)
+ neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
+ neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous()
+
+ neighborhood_color = color.view(batch_size * num_points, -1)[idx, :]
+ neighborhood_color = neighborhood_color.view(batch_size, self.num_group, self.group_size, 3).contiguous()
+
+ # normalize
+ neighborhood = neighborhood - center.unsqueeze(2)
+
+ features = torch.cat((neighborhood, neighborhood_color), dim=-1)
+ return neighborhood, center, features
+
+class Encoder(nn.Module):
+ def __init__(self, encoder_channel):
+ super().__init__()
+ self.encoder_channel = encoder_channel
+ self.first_conv = nn.Sequential(
+ nn.Conv1d(6, 128, 1),
+ nn.BatchNorm1d(128),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(128, 256, 1)
+ )
+ self.second_conv = nn.Sequential(
+ nn.Conv1d(512, 512, 1),
+ nn.BatchNorm1d(512),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(512, self.encoder_channel, 1)
+ )
+ def forward(self, point_groups):
+ '''
+ point_groups : B G N 3
+ -----------------
+ feature_global : B G C
+ '''
+ bs, g, n , _ = point_groups.shape
+ point_groups = point_groups.reshape(bs * g, n, 6)
+ # encoder
+ feature = self.first_conv(point_groups.transpose(2,1)) # BG 256 n
+ feature_global = torch.max(feature,dim=2,keepdim=True)[0] # BG 256 1
+ feature = torch.cat([feature_global.expand(-1,-1,n), feature], dim=1)# BG 512 n
+ feature = self.second_conv(feature) # BG 1024 n
+ feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024
+ return feature_global.reshape(bs, g, self.encoder_channel)
+
+
+
+class skeleton_Group(nn.Module):
+ def __init__(self, num_group=32, group_size=8):
+ super().__init__()
+ self.num_group = num_group
+ self.group_size = group_size
+
+ def forward(self, xyz, token_feat, num_group=32, group_size=8):
+ '''
+ xyz: ๆๆtoken็xyz
+
+ input: B N 3
+ ---------------------------
+ output: B G M 3
+ center : B G 3
+ '''
+ self.num_group = num_group
+ self.group_size = group_size
+
+ batch_size, num_points, _ = xyz.shape
+ _, _, C_ = token_feat.shape
+ # fps the centers out
+ center = fps(xyz, self.num_group) # B G 3
+ # knn to get the neighborhood
+ # _, idx = self.knn(xyz, center) # B G M
+ idx = knn_point(self.group_size, xyz, center) # B G M
+
+ assert idx.size(1) == self.num_group
+ assert idx.size(2) == self.group_size
+ idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
+ idx = idx + idx_base
+ idx = idx.view(-1)
+
+ token_feat = token_feat.contiguous()
+ neighborhood_token_feat = token_feat.view(batch_size * num_points, -1)[idx, :]
+ # T_p: B, 32, 8,384 -> B, M, K, C
+ neighborhood_token_feat = neighborhood_token_feat.view(batch_size, self.num_group, self.group_size, C_).contiguous()
+
+
+
+ # T_m: # B๏ผ32, 384 -> B, M, C
+ center_token, _ = torch.max(neighborhood_token_feat, dim=2)
+
+
+ # #### fix OMโโpooling no resnet, no dim norm
+ # T_m: B๏ผ32,1, 384 -> B, M, 1, C
+ center_token = center_token.unsqueeze(2)
+ # T_m (T): B๏ผ32,384, 1 -> B, M, C, 1
+ center_token = center_token.permute(0, 1, 3, 2)
+
+ neighborhood_token_feat_ = torch.nn.functional.normalize(neighborhood_token_feat, p=2, dim=-1)
+ center_token_ = torch.nn.functional.normalize(center_token, p=2, dim=-2)
+
+ # A: B, 32, 8, 1 -> B, M, K, 1
+ router_weights = torch.einsum('btnj,btjm->btnm', [neighborhood_token_feat_, center_token_])
+ # A(T): B, 32, 1, 8 -> B, M, 1, K
+ router_weights = router_weights.permute(0, 1, 3, 2)
+ # A(T): B, 32, 1, 8 -> B, M, 1, K
+ router_weights = F.softmax(router_weights, dim=-1)
+
+
+ # T^p_pc: B, 32, 1, 384 -> B, M, 1, C
+ skeleton_token = torch.einsum('bmok, bmkc->bmoc', [router_weights, neighborhood_token_feat])
+ skeleton_token = skeleton_token.squeeze(dim=-2)
+ ####
+
+ return skeleton_token
+
+
+
+
+
+class PointcloudEncoder(nn.Module):
+ def __init__(self, point_transformer, args):
+ super().__init__()
+ from easydict import EasyDict
+ self.trans_dim = args.pc_feat_dim # 768
+ self.embed_dim = args.embed_dim # 512
+ self.group_size = args.group_size # 32
+ self.num_group = args.num_group # 512
+ # grouper
+ self.group_divider = Group(num_group = self.num_group, group_size = self.group_size)
+ self.skeleton_Group = skeleton_Group()
+ # define the encoder
+ self.encoder_dim = args.pc_encoder_dim # 256
+ self.encoder = Encoder(encoder_channel = self.encoder_dim)
+
+ # bridge encoder and transformer
+ self.encoder2trans = nn.Linear(self.encoder_dim, self.trans_dim)
+
+ # bridge transformer and clip embedding
+ self.trans2embed = nn.Linear(self.trans_dim, self.embed_dim)
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
+ self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
+
+ self.pos_embed = nn.Sequential(
+ nn.Linear(3, 128),
+ nn.GELU(),
+ nn.Linear(128, self.trans_dim)
+ )
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
+ self.patch_dropout = PatchDropout(args.patch_dropout) if args.patch_dropout > 0. else nn.Identity()
+ self.visual = point_transformer
+
+ self.float_flag =False #
+
+
+ def forward(self, pc, get_pc_tokens_way=None):
+
+ pc = pc.to(dtype=torch.float)
+
+ pts = pc[:,:,:3].contiguous()
+ colors = pc[:,:,3:].contiguous()
+
+ # divide the point cloud in the same form. This is important
+ _, center, features = self.group_divider(pts, colors)
+
+
+ # encoder the input cloud patches
+ group_input_tokens = self.encoder(features) # B G N
+
+ group_input_tokens = self.encoder2trans(group_input_tokens)
+ # prepare cls
+ cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
+ cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
+ # add pos embedding
+ pos = self.pos_embed(center)
+ # final input
+ x = torch.cat((cls_tokens, group_input_tokens), dim=1)
+ pos = torch.cat((cls_pos, pos), dim=1)
+ # transformer
+ x = x + pos
+ # x = x.half()
+
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+ x = self.patch_dropout(x)
+
+ x = self.visual.pos_drop(x)
+
+ # ModuleList not support forward
+ block_len = len(self.visual.blocks)
+ for i, blk in enumerate(self.visual.blocks):
+ x = blk(x)
+ if block_len-2 == i:
+ last_sec_x =x
+
+
+ if get_pc_tokens_way=="CLS": # CLS
+ x = self.visual.norm(x[:, 0, :])
+ x = self.visual.fc_norm(x)
+ x = self.trans2embed(x)
+ x = x.unsqueeze(1)
+
+ elif get_pc_tokens_way=="OM_Pooling":
+
+
+ pc_skeleton = self.skeleton_Group(center, last_sec_x[:, 1:])
+
+ x = torch.cat([x[:, 0].unsqueeze(1), x[:, 1:].max(1)[0].unsqueeze(1), x[:, 1:].mean(1).unsqueeze(1), torch.sum(x[:, 1:], dim=1).unsqueeze(1), pc_skeleton], dim=-2)
+ x = self.visual.norm(x)
+ x = self.visual.fc_norm(x)
+ x = self.trans2embed(x)
+
+ return x
\ No newline at end of file
diff --git a/llava/model/multimodal_projector/__pycache__/Qformer.cpython-310.pyc b/llava/model/multimodal_projector/__pycache__/Qformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ab913de4cfb60371fb96c520f6a1dd2ff2229907
--- /dev/null
+++ b/llava/model/multimodal_projector/__pycache__/Qformer.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d5a3558e8885f5f870741cfc504c00deba54938fc57c340edf096e7b3646e8d
+size 30642
diff --git a/llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc b/llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d30ff331c671142d8d7a7b9854b37a99f7ff372
--- /dev/null
+++ b/llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7239be19b37eab9b518971a1ca5ffe4e738d2c72824054b94942cd7b16b0169a
+size 2534
diff --git a/llava/model/multimodal_projector/builder.py b/llava/model/multimodal_projector/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb25cbedd7a3ca773e8f03cc703d73996c598fea
--- /dev/null
+++ b/llava/model/multimodal_projector/builder.py
@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+import re
+
+
+class IdentityMap(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
+ @property
+ def config(self):
+ return {"mm_projector_type": 'identity'}
+
+
+class SimpleResBlock(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.pre_norm = nn.LayerNorm(channels)
+
+ self.proj = nn.Sequential(
+ nn.Linear(channels, channels),
+ nn.GELU(),
+ nn.Linear(channels, channels)
+ )
+ def forward(self, x):
+ x = self.pre_norm(x)
+ return x + self.proj(x)
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., norm_layer=nn.LayerNorm):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ self.norm1 = norm_layer(in_features)
+
+ def forward(self, x):
+ x = self.norm1(x)
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+def build_vision_projector(config, model_args , delay_load=False, **kwargs):
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
+
+ mm_hidden_size = 1024
+
+ modules = [nn.Linear(mm_hidden_size, config.hidden_size)]
+ modules.append(nn.GELU())
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
+ proj_mlp = nn.Sequential(*modules)
+
+ return proj_mlp
+
+
+
+
diff --git a/llava/model/utils.py b/llava/model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..925a3f4e82ae12c9e3fcd39e3f68cf8c853e8c5a
--- /dev/null
+++ b/llava/model/utils.py
@@ -0,0 +1,281 @@
+import datetime
+import logging
+import logging.handlers
+import os
+import sys
+
+import requests
+
+from llava.constants import LOGDIR
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+handler = None
+
+
+def build_logger(logger_name, logger_filename):
+ global handler
+
+ formatter = logging.Formatter(
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ # Set the format of root handlers
+ if not logging.getLogger().handlers:
+ logging.basicConfig(level=logging.INFO)
+ logging.getLogger().handlers[0].setFormatter(formatter)
+
+ # Redirect stdout and stderr to loggers
+ stdout_logger = logging.getLogger("stdout")
+ stdout_logger.setLevel(logging.INFO)
+ sl = StreamToLogger(stdout_logger, logging.INFO)
+ sys.stdout = sl
+
+ stderr_logger = logging.getLogger("stderr")
+ stderr_logger.setLevel(logging.ERROR)
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
+ sys.stderr = sl
+
+ # Get logger
+ logger = logging.getLogger(logger_name)
+ logger.setLevel(logging.INFO)
+
+ # Add a file handler for all loggers
+ if handler is None:
+ os.makedirs(LOGDIR, exist_ok=True)
+ filename = os.path.join(LOGDIR, logger_filename)
+ handler = logging.handlers.TimedRotatingFileHandler(
+ filename, when='D', utc=True, encoding='UTF-8')
+ handler.setFormatter(formatter)
+
+ for name, item in logging.root.manager.loggerDict.items():
+ if isinstance(item, logging.Logger):
+ item.addHandler(handler)
+
+ return logger
+
+
+class StreamToLogger(object):
+ """
+ Fake file-like stream object that redirects writes to a logger instance.
+ """
+ def __init__(self, logger, log_level=logging.INFO):
+ self.terminal = sys.stdout
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ''
+
+ def __getattr__(self, attr):
+ return getattr(self.terminal, attr)
+
+ def write(self, buf):
+ temp_linebuf = self.linebuf + buf
+ self.linebuf = ''
+ for line in temp_linebuf.splitlines(True):
+ # From the io.TextIOWrapper docs:
+ # On output, if newline is None, any '\n' characters written
+ # are translated to the system default line separator.
+ # By default sys.stdout.write() expects '\n' newlines and then
+ # translates them so this is still cross platform.
+ if line[-1] == '\n':
+ self.logger.log(self.log_level, line.rstrip())
+ else:
+ self.linebuf += line
+
+ def flush(self):
+ if self.linebuf != '':
+ self.logger.log(self.log_level, self.linebuf.rstrip())
+ self.linebuf = ''
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def violates_moderation(text):
+ """
+ Check whether the text violates OpenAI moderation API.
+ """
+ url = "https://api.openai.com/v1/moderations"
+ headers = {"Content-Type": "application/json",
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+ text = text.replace("\n", "")
+ data = "{" + '"input": ' + f'"{text}"' + "}"
+ data = data.encode("utf-8")
+ try:
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
+ flagged = ret.json()["results"][0]["flagged"]
+ except requests.exceptions.RequestException as e:
+ flagged = False
+ except KeyError as e:
+ flagged = False
+
+ return flagged
+
+
+def pretty_print_semaphore(semaphore):
+ if semaphore is None:
+ return "None"
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
+
+
+
+
+# Modified from github.com/openai/CLIP
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+import torch
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("ยก"), ord("ยฌ")+1))+list(range(ord("ยฎ"), ord("รฟ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v+'' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + ( token[-1] + '',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token+''
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ')
+ return text
+
+ def __call__(self, texts, context_length=77):
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = self.encoder["<|startoftext|>"]
+ eot_token = self.encoder["<|endoftext|>"]
+ all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ tokens = tokens[:context_length]
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ if len(result) == 1:
+ return result[0]
+ return result
\ No newline at end of file
diff --git a/llava/serve/__init__.py b/llava/serve/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/llava/serve/__pycache__/__init__.cpython-310.pyc b/llava/serve/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1e2e33bfbedf0874365ae00efc2a484a4035d46
--- /dev/null
+++ b/llava/serve/__pycache__/__init__.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc49de6dc8e7b21b604f1656c83403b9fd4cdfac7df84b6d599565312c9d6cda
+size 148
diff --git a/llava/serve/__pycache__/cli.cpython-310.pyc b/llava/serve/__pycache__/cli.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d073b7a7573e708a9754d67f21d1c38769e3e301
--- /dev/null
+++ b/llava/serve/__pycache__/cli.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4debee0234ab6dc6d578d4acf65777a3a171ab9f3f21f10ece0e94c624f20d3e
+size 4987
diff --git a/llava/serve/controller.py b/llava/serve/controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4bf1b4c47ccdb1401b18f8397868ec016d1c43a
--- /dev/null
+++ b/llava/serve/controller.py
@@ -0,0 +1,298 @@
+"""
+A controller manages distributed workers.
+It sends worker addresses to clients.
+"""
+import argparse
+import asyncio
+import dataclasses
+from enum import Enum, auto
+import json
+import logging
+import time
+from typing import List, Union
+import threading
+
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+import numpy as np
+import requests
+import uvicorn
+
+from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
+from llava.utils import build_logger, server_error_msg
+
+
+logger = build_logger("controller", "controller.log")
+
+
+class DispatchMethod(Enum):
+ LOTTERY = auto()
+ SHORTEST_QUEUE = auto()
+
+ @classmethod
+ def from_str(cls, name):
+ if name == "lottery":
+ return cls.LOTTERY
+ elif name == "shortest_queue":
+ return cls.SHORTEST_QUEUE
+ else:
+ raise ValueError(f"Invalid dispatch method")
+
+
+@dataclasses.dataclass
+class WorkerInfo:
+ model_names: List[str]
+ speed: int
+ queue_length: int
+ check_heart_beat: bool
+ last_heart_beat: str
+
+
+def heart_beat_controller(controller):
+ while True:
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
+ controller.remove_stable_workers_by_expiration()
+
+
+class Controller:
+ def __init__(self, dispatch_method: str):
+ # Dict[str -> WorkerInfo]
+ self.worker_info = {}
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
+
+ self.heart_beat_thread = threading.Thread(
+ target=heart_beat_controller, args=(self,), daemon=True)
+ self.heart_beat_thread.start()
+
+ logger.info("Init controller")
+
+ def register_worker(self, worker_name: str, check_heart_beat: bool,
+ worker_status: dict):
+ if worker_name not in self.worker_info:
+ logger.info(f"Register a new worker: {worker_name}")
+ else:
+ logger.info(f"Register an existing worker: {worker_name}")
+
+ if not worker_status:
+ worker_status = self.get_worker_status(worker_name)
+ if not worker_status:
+ return False
+
+ self.worker_info[worker_name] = WorkerInfo(
+ worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
+ check_heart_beat, time.time())
+
+ logger.info(f"Register done: {worker_name}, {worker_status}")
+ return True
+
+ def get_worker_status(self, worker_name: str):
+ try:
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
+ except requests.exceptions.RequestException as e:
+ logger.error(f"Get status fails: {worker_name}, {e}")
+ return None
+
+ if r.status_code != 200:
+ logger.error(f"Get status fails: {worker_name}, {r}")
+ return None
+
+ return r.json()
+
+ def remove_worker(self, worker_name: str):
+ del self.worker_info[worker_name]
+
+ def refresh_all_workers(self):
+ old_info = dict(self.worker_info)
+ self.worker_info = {}
+
+ for w_name, w_info in old_info.items():
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
+ logger.info(f"Remove stale worker: {w_name}")
+
+ def list_models(self):
+ model_names = set()
+
+ for w_name, w_info in self.worker_info.items():
+ model_names.update(w_info.model_names)
+
+ return list(model_names)
+
+ def get_worker_address(self, model_name: str):
+ if self.dispatch_method == DispatchMethod.LOTTERY:
+ worker_names = []
+ worker_speeds = []
+ for w_name, w_info in self.worker_info.items():
+ if model_name in w_info.model_names:
+ worker_names.append(w_name)
+ worker_speeds.append(w_info.speed)
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
+ norm = np.sum(worker_speeds)
+ if norm < 1e-4:
+ return ""
+ worker_speeds = worker_speeds / norm
+ if True: # Directly return address
+ pt = np.random.choice(np.arange(len(worker_names)),
+ p=worker_speeds)
+ worker_name = worker_names[pt]
+ return worker_name
+
+ # Check status before returning
+ while True:
+ pt = np.random.choice(np.arange(len(worker_names)),
+ p=worker_speeds)
+ worker_name = worker_names[pt]
+
+ if self.get_worker_status(worker_name):
+ break
+ else:
+ self.remove_worker(worker_name)
+ worker_speeds[pt] = 0
+ norm = np.sum(worker_speeds)
+ if norm < 1e-4:
+ return ""
+ worker_speeds = worker_speeds / norm
+ continue
+ return worker_name
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
+ worker_names = []
+ worker_qlen = []
+ for w_name, w_info in self.worker_info.items():
+ if model_name in w_info.model_names:
+ worker_names.append(w_name)
+ worker_qlen.append(w_info.queue_length / w_info.speed)
+ if len(worker_names) == 0:
+ return ""
+ min_index = np.argmin(worker_qlen)
+ w_name = worker_names[min_index]
+ self.worker_info[w_name].queue_length += 1
+ logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
+ return w_name
+ else:
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
+
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
+ if worker_name not in self.worker_info:
+ logger.info(f"Receive unknown heart beat. {worker_name}")
+ return False
+
+ self.worker_info[worker_name].queue_length = queue_length
+ self.worker_info[worker_name].last_heart_beat = time.time()
+ logger.info(f"Receive heart beat. {worker_name}")
+ return True
+
+ def remove_stable_workers_by_expiration(self):
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
+ to_delete = []
+ for worker_name, w_info in self.worker_info.items():
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
+ to_delete.append(worker_name)
+
+ for worker_name in to_delete:
+ self.remove_worker(worker_name)
+
+ def worker_api_generate_stream(self, params):
+ worker_addr = self.get_worker_address(params["model"])
+ if not worker_addr:
+ logger.info(f"no worker: {params['model']}")
+ ret = {
+ "text": server_error_msg,
+ "error_code": 2,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+ try:
+ response = requests.post(worker_addr + "/worker_generate_stream",
+ json=params, stream=True, timeout=5)
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ yield chunk + b"\0"
+ except requests.exceptions.RequestException as e:
+ logger.info(f"worker timeout: {worker_addr}")
+ ret = {
+ "text": server_error_msg,
+ "error_code": 3,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+
+ # Let the controller act as a worker to achieve hierarchical
+ # management. This can be used to connect isolated sub networks.
+ def worker_api_get_status(self):
+ model_names = set()
+ speed = 0
+ queue_length = 0
+
+ for w_name in self.worker_info:
+ worker_status = self.get_worker_status(w_name)
+ if worker_status is not None:
+ model_names.update(worker_status["model_names"])
+ speed += worker_status["speed"]
+ queue_length += worker_status["queue_length"]
+
+ return {
+ "model_names": list(model_names),
+ "speed": speed,
+ "queue_length": queue_length,
+ }
+
+
+app = FastAPI()
+
+
+@app.post("/register_worker")
+async def register_worker(request: Request):
+ data = await request.json()
+ controller.register_worker(
+ data["worker_name"], data["check_heart_beat"],
+ data.get("worker_status", None))
+
+
+@app.post("/refresh_all_workers")
+async def refresh_all_workers():
+ models = controller.refresh_all_workers()
+
+
+@app.post("/list_models")
+async def list_models():
+ models = controller.list_models()
+ return {"models": models}
+
+
+@app.post("/get_worker_address")
+async def get_worker_address(request: Request):
+ data = await request.json()
+ addr = controller.get_worker_address(data["model"])
+ return {"address": addr}
+
+
+@app.post("/receive_heart_beat")
+async def receive_heart_beat(request: Request):
+ data = await request.json()
+ exist = controller.receive_heart_beat(
+ data["worker_name"], data["queue_length"])
+ return {"exist": exist}
+
+
+@app.post("/worker_generate_stream")
+async def worker_api_generate_stream(request: Request):
+ params = await request.json()
+ generator = controller.worker_api_generate_stream(params)
+ return StreamingResponse(generator)
+
+
+@app.post("/worker_get_status")
+async def worker_api_get_status(request: Request):
+ return controller.worker_api_get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21001)
+ parser.add_argument("--dispatch-method", type=str, choices=[
+ "lottery", "shortest_queue"], default="shortest_queue")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ controller = Controller(args.dispatch_method)
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/llava/serve/examples/extreme_ironing.jpg b/llava/serve/examples/extreme_ironing.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cf1071a1fbfa904309335e3521cecbcec341b37f
--- /dev/null
+++ b/llava/serve/examples/extreme_ironing.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a54caa21bc513ed25c8ca7f5747555c05dfd4e33f6a3cf5c08b3d9138a4da1d9
+size 62587
diff --git a/llava/serve/examples/waterview.jpg b/llava/serve/examples/waterview.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5ea03ee6fa60f4025999012b817e674984c706cd
--- /dev/null
+++ b/llava/serve/examples/waterview.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d092764cc9f21b9bc535ff5284b5add4d8256148bab1bc2f5b5ab3fd32759a36
+size 95499
diff --git a/llava/serve/gradio_web_server.py b/llava/serve/gradio_web_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..c07efc122950da37455608b609dcf1f2b4103d56
--- /dev/null
+++ b/llava/serve/gradio_web_server.py
@@ -0,0 +1,479 @@
+import argparse
+import datetime
+import json
+import os
+import time
+
+import gradio as gr
+import requests
+
+from llava.conversation import (default_conversation, conv_templates,
+ SeparatorStyle)
+from llava.constants import LOGDIR
+from llava.utils import (build_logger, server_error_msg,
+ violates_moderation, moderation_msg)
+import hashlib
+
+
+logger = build_logger("gradio_web_server", "gradio_web_server.log")
+
+headers = {"User-Agent": "LLaVA Client"}
+
+no_change_btn = gr.Button()
+enable_btn = gr.Button(interactive=True)
+disable_btn = gr.Button(interactive=False)
+
+priority = {
+ "vicuna-13b": "aaaaaaa",
+ "koala-13b": "aaaaaab",
+}
+
+
+def get_conv_log_filename():
+ t = datetime.datetime.now()
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
+ return name
+
+
+def get_model_list():
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
+ assert ret.status_code == 200
+ ret = requests.post(args.controller_url + "/list_models")
+ models = ret.json()["models"]
+ models.sort(key=lambda x: priority.get(x, x))
+ logger.info(f"Models: {models}")
+ return models
+
+
+get_window_url_params = """
+function() {
+ const params = new URLSearchParams(window.location.search);
+ url_params = Object.fromEntries(params);
+ console.log(url_params);
+ return url_params;
+ }
+"""
+
+
+def load_demo(url_params, request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
+
+ dropdown_update = gr.Dropdown(visible=True)
+ if "model" in url_params:
+ model = url_params["model"]
+ if model in models:
+ dropdown_update = gr.Dropdown(value=model, visible=True)
+
+ state = default_conversation.copy()
+ return state, dropdown_update
+
+
+def load_demo_refresh_model_list(request: gr.Request):
+ logger.info(f"load_demo. ip: {request.client.host}")
+ models = get_model_list()
+ state = default_conversation.copy()
+ dropdown_update = gr.Dropdown(
+ choices=models,
+ value=models[0] if len(models) > 0 else ""
+ )
+ return state, dropdown_update
+
+
+def vote_last_response(state, vote_type, model_selector, request: gr.Request):
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(time.time(), 4),
+ "type": vote_type,
+ "model": model_selector,
+ "state": state.dict(),
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+
+def upvote_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"upvote. ip: {request.client.host}")
+ vote_last_response(state, "upvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def downvote_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"downvote. ip: {request.client.host}")
+ vote_last_response(state, "downvote", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def flag_last_response(state, model_selector, request: gr.Request):
+ logger.info(f"flag. ip: {request.client.host}")
+ vote_last_response(state, "flag", model_selector, request)
+ return ("",) + (disable_btn,) * 3
+
+
+def regenerate(state, image_process_mode, request: gr.Request):
+ logger.info(f"regenerate. ip: {request.client.host}")
+ state.messages[-1][-1] = None
+ prev_human_msg = state.messages[-2]
+ if type(prev_human_msg[1]) in (tuple, list):
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
+ state.skip_next = False
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def clear_history(request: gr.Request):
+ logger.info(f"clear_history. ip: {request.client.host}")
+ state = default_conversation.copy()
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def add_text(state, text, image, image_process_mode, request: gr.Request):
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
+ if len(text) <= 0 and image is None:
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
+ if args.moderate:
+ flagged = violates_moderation(text)
+ if flagged:
+ state.skip_next = True
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
+ no_change_btn,) * 5
+
+ text = text[:1536] # Hard cut-off
+ if image is not None:
+ text = text[:1200] # Hard cut-off for images
+ if '' not in text:
+ # text = '' + text
+ text = text + '\n'
+ text = (text, image, image_process_mode)
+ state = default_conversation.copy()
+ state.append_message(state.roles[0], text)
+ state.append_message(state.roles[1], None)
+ state.skip_next = False
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
+
+
+def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
+ logger.info(f"http_bot. ip: {request.client.host}")
+ start_tstamp = time.time()
+ model_name = model_selector
+
+ if state.skip_next:
+ # This generate call is skipped due to invalid inputs
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
+ return
+
+ if len(state.messages) == state.offset + 2:
+ # First round of conversation
+ if "llava" in model_name.lower():
+ if 'llama-2' in model_name.lower():
+ template_name = "llava_llama_2"
+ elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
+ if 'orca' in model_name.lower():
+ template_name = "mistral_orca"
+ elif 'hermes' in model_name.lower():
+ template_name = "chatml_direct"
+ else:
+ template_name = "mistral_instruct"
+ elif 'llava-v1.6-34b' in model_name.lower():
+ template_name = "chatml_direct"
+ elif "v1" in model_name.lower():
+ if 'mmtag' in model_name.lower():
+ template_name = "v1_mmtag"
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+ template_name = "v1_mmtag"
+ else:
+ template_name = "llava_v1"
+ elif "mpt" in model_name.lower():
+ template_name = "mpt"
+ else:
+ if 'mmtag' in model_name.lower():
+ template_name = "v0_mmtag"
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
+ template_name = "v0_mmtag"
+ else:
+ template_name = "llava_v0"
+ elif "mpt" in model_name:
+ template_name = "mpt_text"
+ elif "llama-2" in model_name:
+ template_name = "llama_2"
+ else:
+ template_name = "vicuna_v1"
+ new_state = conv_templates[template_name].copy()
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
+ new_state.append_message(new_state.roles[1], None)
+ state = new_state
+
+ # Query worker address
+ controller_url = args.controller_url
+ ret = requests.post(controller_url + "/get_worker_address",
+ json={"model": model_name})
+ worker_addr = ret.json()["address"]
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
+
+ # No available worker
+ if worker_addr == "":
+ state.messages[-1][-1] = server_error_msg
+ yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+
+ # Construct prompt
+ prompt = state.get_prompt()
+
+ all_images = state.get_images(return_pil=True)
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
+ for image, hash in zip(all_images, all_image_hash):
+ t = datetime.datetime.now()
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
+ if not os.path.isfile(filename):
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ image.save(filename)
+
+ # Make requests
+ pload = {
+ "model": model_name,
+ "prompt": prompt,
+ "temperature": float(temperature),
+ "top_p": float(top_p),
+ "max_new_tokens": min(int(max_new_tokens), 1536),
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
+ "images": f'List of {len(state.get_images())} images: {all_image_hash}',
+ }
+ logger.info(f"==== request ====\n{pload}")
+
+ pload['images'] = state.get_images()
+
+ state.messages[-1][-1] = "โ"
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+
+ try:
+ # Stream output
+ response = requests.post(worker_addr + "/worker_generate_stream",
+ headers=headers, json=pload, stream=True, timeout=10)
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ data = json.loads(chunk.decode())
+ if data["error_code"] == 0:
+ output = data["text"][len(prompt):].strip()
+ state.messages[-1][-1] = output + "โ"
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
+ else:
+ output = data["text"] + f" (error_code: {data['error_code']})"
+ state.messages[-1][-1] = output
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+ time.sleep(0.03)
+ except requests.exceptions.RequestException as e:
+ state.messages[-1][-1] = server_error_msg
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ return
+
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
+
+ finish_tstamp = time.time()
+ logger.info(f"{output}")
+
+ with open(get_conv_log_filename(), "a") as fout:
+ data = {
+ "tstamp": round(finish_tstamp, 4),
+ "type": "chat",
+ "model": model_name,
+ "start": round(start_tstamp, 4),
+ "finish": round(finish_tstamp, 4),
+ "state": state.dict(),
+ "images": all_image_hash,
+ "ip": request.client.host,
+ }
+ fout.write(json.dumps(data) + "\n")
+
+title_markdown = ("""
+# ๐ LLaVA: Large Language and Vision Assistant
+[[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | ๐ [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)] [[LLaVA-v1.6](https://llava-vl.github.io/blog/2024-01-30-llava-1-6/)]
+""")
+
+tos_markdown = ("""
+### Terms of use
+By using this service, users are required to agree to the following terms:
+The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
+Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
+For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
+""")
+
+
+learn_more_markdown = ("""
+### License
+The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
+""")
+
+block_css = """
+
+#buttons button {
+ min-width: min(120px,100%);
+}
+
+"""
+
+def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
+ with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
+ state = gr.State()
+
+ if not embed_mode:
+ gr.Markdown(title_markdown)
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ with gr.Row(elem_id="model_selector_row"):
+ model_selector = gr.Dropdown(
+ choices=models,
+ value=models[0] if len(models) > 0 else "",
+ interactive=True,
+ show_label=False,
+ container=False)
+
+ imagebox = gr.Image(type="pil")
+ image_process_mode = gr.Radio(
+ ["Crop", "Resize", "Pad", "Default"],
+ value="Default",
+ label="Preprocess for non-square image", visible=False)
+
+ if cur_dir is None:
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
+ gr.Examples(examples=[
+ [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
+ [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
+ ], inputs=[imagebox, textbox])
+
+ with gr.Accordion("Parameters", open=False) as parameter_row:
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
+
+ with gr.Column(scale=8):
+ chatbot = gr.Chatbot(
+ elem_id="chatbot",
+ label="LLaVA Chatbot",
+ height=650,
+ layout="panel",
+ )
+ with gr.Row():
+ with gr.Column(scale=8):
+ textbox.render()
+ with gr.Column(scale=1, min_width=50):
+ submit_btn = gr.Button(value="Send", variant="primary")
+ with gr.Row(elem_id="buttons") as button_row:
+ upvote_btn = gr.Button(value="๐ Upvote", interactive=False)
+ downvote_btn = gr.Button(value="๐ Downvote", interactive=False)
+ flag_btn = gr.Button(value="โ ๏ธ Flag", interactive=False)
+ #stop_btn = gr.Button(value="โน๏ธ Stop Generation", interactive=False)
+ regenerate_btn = gr.Button(value="๐ Regenerate", interactive=False)
+ clear_btn = gr.Button(value="๐๏ธ Clear", interactive=False)
+
+ if not embed_mode:
+ gr.Markdown(tos_markdown)
+ gr.Markdown(learn_more_markdown)
+ url_params = gr.JSON(visible=False)
+
+ # Register listeners
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
+ upvote_btn.click(
+ upvote_last_response,
+ [state, model_selector],
+ [textbox, upvote_btn, downvote_btn, flag_btn]
+ )
+ downvote_btn.click(
+ downvote_last_response,
+ [state, model_selector],
+ [textbox, upvote_btn, downvote_btn, flag_btn]
+ )
+ flag_btn.click(
+ flag_last_response,
+ [state, model_selector],
+ [textbox, upvote_btn, downvote_btn, flag_btn]
+ )
+
+ regenerate_btn.click(
+ regenerate,
+ [state, image_process_mode],
+ [state, chatbot, textbox, imagebox] + btn_list
+ ).then(
+ http_bot,
+ [state, model_selector, temperature, top_p, max_output_tokens],
+ [state, chatbot] + btn_list,
+ concurrency_limit=concurrency_count
+ )
+
+ clear_btn.click(
+ clear_history,
+ None,
+ [state, chatbot, textbox, imagebox] + btn_list,
+ queue=False
+ )
+
+ textbox.submit(
+ add_text,
+ [state, textbox, imagebox, image_process_mode],
+ [state, chatbot, textbox, imagebox] + btn_list,
+ queue=False
+ ).then(
+ http_bot,
+ [state, model_selector, temperature, top_p, max_output_tokens],
+ [state, chatbot] + btn_list,
+ concurrency_limit=concurrency_count
+ )
+
+ submit_btn.click(
+ add_text,
+ [state, textbox, imagebox, image_process_mode],
+ [state, chatbot, textbox, imagebox] + btn_list
+ ).then(
+ http_bot,
+ [state, model_selector, temperature, top_p, max_output_tokens],
+ [state, chatbot] + btn_list,
+ concurrency_limit=concurrency_count
+ )
+
+ if args.model_list_mode == "once":
+ demo.load(
+ load_demo,
+ [url_params],
+ [state, model_selector],
+ js=get_window_url_params
+ )
+ elif args.model_list_mode == "reload":
+ demo.load(
+ load_demo_refresh_model_list,
+ None,
+ [state, model_selector],
+ queue=False
+ )
+ else:
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
+
+ return demo
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="0.0.0.0")
+ parser.add_argument("--port", type=int)
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
+ parser.add_argument("--concurrency-count", type=int, default=16)
+ parser.add_argument("--model-list-mode", type=str, default="once",
+ choices=["once", "reload"])
+ parser.add_argument("--share", action="store_true")
+ parser.add_argument("--moderate", action="store_true")
+ parser.add_argument("--embed", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ models = get_model_list()
+
+ logger.info(args)
+ demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
+ demo.queue(
+ api_open=False
+ ).launch(
+ server_name=args.host,
+ server_port=args.port,
+ share=args.share
+ )
diff --git a/llava/serve/model_worker.py b/llava/serve/model_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..9144329893c51f402ff2e2f65d9fb7baf177bd52
--- /dev/null
+++ b/llava/serve/model_worker.py
@@ -0,0 +1,288 @@
+"""
+A model worker executes the model.
+"""
+import argparse
+import asyncio
+import json
+import time
+import threading
+import uuid
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+import requests
+import torch
+import uvicorn
+from functools import partial
+
+from llava.constants import WORKER_HEART_BEAT_INTERVAL
+from llava.utils import (build_logger, server_error_msg,
+ pretty_print_semaphore)
+from llava.model.builder import load_pretrained_model
+from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+from transformers import TextIteratorStreamer
+from threading import Thread
+
+
+GB = 1 << 30
+
+worker_id = str(uuid.uuid4())[:6]
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+global_counter = 0
+
+model_semaphore = None
+
+
+def heart_beat_worker(controller):
+
+ while True:
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
+ controller.send_heart_beat()
+
+
+class ModelWorker:
+ def __init__(self, controller_addr, worker_addr,
+ worker_id, no_register,
+ model_path, model_base, model_name,
+ load_8bit, load_4bit, device, use_flash_attn=False):
+ self.controller_addr = controller_addr
+ self.worker_addr = worker_addr
+ self.worker_id = worker_id
+ if model_path.endswith("/"):
+ model_path = model_path[:-1]
+ if model_name is None:
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith('checkpoint-'):
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
+ else:
+ self.model_name = model_paths[-1]
+ else:
+ self.model_name = model_name
+
+ self.device = device
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
+ model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device, use_flash_attn=use_flash_attn)
+ self.is_multimodal = 'llava' in self.model_name.lower()
+
+ if not no_register:
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(
+ target=heart_beat_worker, args=(self,), daemon=True)
+ self.heart_beat_thread.start()
+
+ def register_to_controller(self):
+ logger.info("Register to controller")
+
+ url = self.controller_addr + "/register_worker"
+ data = {
+ "worker_name": self.worker_addr,
+ "check_heart_beat": True,
+ "worker_status": self.get_status()
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
+
+ def send_heart_beat(self):
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. "
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
+ f"global_counter: {global_counter}")
+
+ url = self.controller_addr + "/receive_heart_beat"
+
+ while True:
+ try:
+ ret = requests.post(url, json={
+ "worker_name": self.worker_addr,
+ "queue_length": self.get_queue_length()}, timeout=5)
+ exist = ret.json()["exist"]
+ break
+ except requests.exceptions.RequestException as e:
+ logger.error(f"heart beat error: {e}")
+ time.sleep(5)
+
+ if not exist:
+ self.register_to_controller()
+
+ def get_queue_length(self):
+ if model_semaphore is None:
+ return 0
+ else:
+ return args.limit_model_concurrency - model_semaphore._value + (len(
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
+
+ def get_status(self):
+ return {
+ "model_names": [self.model_name],
+ "speed": 1,
+ "queue_length": self.get_queue_length(),
+ }
+
+ @torch.inference_mode()
+ def generate_stream(self, params):
+ tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
+
+ prompt = params["prompt"]
+ ori_prompt = prompt
+ images = params.get("images", None)
+ num_image_tokens = 0
+ if images is not None and len(images) > 0 and self.is_multimodal:
+ if len(images) > 0:
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+ raise ValueError("Number of images does not match number of tokens in prompt")
+
+ images = [load_image_from_base64(image) for image in images]
+ image_sizes = [image.size for image in images]
+ images = process_images(images, image_processor, model.config)
+
+ if type(images) is list:
+ images = [image.to(self.model.device, dtype=torch.float16) for image in images]
+ else:
+ images = images.to(self.model.device, dtype=torch.float16)
+
+ replace_token = DEFAULT_IMAGE_TOKEN
+ if getattr(self.model.config, 'mm_use_im_start_end', False):
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
+ else:
+ images = None
+ image_sizes = None
+ image_args = {"images": images, "image_sizes": image_sizes}
+ else:
+ images = None
+ image_args = {}
+
+ temperature = float(params.get("temperature", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
+ stop_str = params.get("stop", None)
+ do_sample = True if temperature > 0.001 else False
+
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
+ keywords = [stop_str]
+ # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
+
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
+
+ if max_new_tokens < 1:
+ yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
+ return
+
+ thread = Thread(target=model.generate, kwargs=dict(
+ inputs=input_ids,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ max_new_tokens=max_new_tokens,
+ streamer=streamer,
+ use_cache=True,
+ **image_args
+ ))
+ thread.start()
+
+ generated_text = ori_prompt
+ for new_text in streamer:
+ generated_text += new_text
+ if generated_text.endswith(stop_str):
+ generated_text = generated_text[:-len(stop_str)]
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
+
+ def generate_stream_gate(self, params):
+ try:
+ for x in self.generate_stream(params):
+ yield x
+ except ValueError as e:
+ print("Caught ValueError:", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except torch.cuda.CudaError as e:
+ print("Caught torch.cuda.CudaError:", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except Exception as e:
+ print("Caught Unknown Error", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+
+app = FastAPI()
+
+
+def release_model_semaphore(fn=None):
+ model_semaphore.release()
+ if fn is not None:
+ fn()
+
+
+@app.post("/worker_generate_stream")
+async def generate_stream(request: Request):
+ global model_semaphore, global_counter
+ global_counter += 1
+ params = await request.json()
+
+ if model_semaphore is None:
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
+ await model_semaphore.acquire()
+ worker.send_heart_beat()
+ generator = worker.generate_stream_gate(params)
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_get_status")
+async def get_status(request: Request):
+ return worker.get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str,
+ default="http://localhost:21002")
+ parser.add_argument("--controller-address", type=str,
+ default="http://localhost:21001")
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
+ parser.add_argument("--model-base", type=str, default=None)
+ parser.add_argument("--model-name", type=str)
+ parser.add_argument("--device", type=str, default="cuda")
+ parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
+ parser.add_argument("--stream-interval", type=int, default=1)
+ parser.add_argument("--no-register", action="store_true")
+ parser.add_argument("--load-8bit", action="store_true")
+ parser.add_argument("--load-4bit", action="store_true")
+ parser.add_argument("--use-flash-attn", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ if args.multi_modal:
+ logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
+
+ worker = ModelWorker(args.controller_address,
+ args.worker_address,
+ worker_id,
+ args.no_register,
+ args.model_path,
+ args.model_base,
+ args.model_name,
+ args.load_8bit,
+ args.load_4bit,
+ args.device,
+ use_flash_attn=args.use_flash_attn)
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/llava/serve/register_worker.py b/llava/serve/register_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c2c40295e0351f25709ba25554c9329f15bf0d2
--- /dev/null
+++ b/llava/serve/register_worker.py
@@ -0,0 +1,26 @@
+"""
+Manually register workers.
+
+Usage:
+python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--controller-address", type=str)
+ parser.add_argument("--worker-name", type=str)
+ parser.add_argument("--check-heart-beat", action="store_true")
+ args = parser.parse_args()
+
+ url = args.controller_address + "/register_worker"
+ data = {
+ "worker_name": args.worker_name,
+ "check_heart_beat": args.check_heart_beat,
+ "worker_status": None,
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
diff --git a/llava/serve/sglang_worker.py b/llava/serve/sglang_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3297b7c295abddedfaac7f6fbe882d7b672487d
--- /dev/null
+++ b/llava/serve/sglang_worker.py
@@ -0,0 +1,244 @@
+"""
+A model worker executes the model.
+"""
+import argparse
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+import json
+import time
+import threading
+import uuid
+
+from fastapi import FastAPI, Request, BackgroundTasks
+from fastapi.responses import StreamingResponse
+import requests
+import re
+import uvicorn
+from functools import partial
+
+from llava.constants import WORKER_HEART_BEAT_INTERVAL
+from llava.utils import (build_logger, server_error_msg,
+ pretty_print_semaphore)
+from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square
+from llava.constants import DEFAULT_IMAGE_TOKEN
+
+import sglang as sgl
+from sglang.backend.runtime_endpoint import RuntimeEndpoint
+
+
+GB = 1 << 30
+
+worker_id = str(uuid.uuid4())[:6]
+logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
+global_counter = 0
+
+model_semaphore = None
+
+
+def heart_beat_worker(controller):
+ while True:
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
+ controller.send_heart_beat()
+
+
+@sgl.function
+def pipeline(s, prompt, max_tokens):
+ for p in prompt:
+ if type(p) is str:
+ s += p
+ else:
+ s += sgl.image(p)
+ s += sgl.gen("response", max_tokens=max_tokens)
+
+
+class ModelWorker:
+ def __init__(self, controller_addr, worker_addr, sgl_endpoint,
+ worker_id, no_register, model_name):
+ self.controller_addr = controller_addr
+ self.worker_addr = worker_addr
+ self.worker_id = worker_id
+
+ # Select backend
+ backend = RuntimeEndpoint(sgl_endpoint)
+ sgl.set_default_backend(backend)
+ model_path = backend.model_info["model_path"]
+
+ if model_path.endswith("/"):
+ model_path = model_path[:-1]
+ if model_name is None:
+ model_paths = model_path.split("/")
+ if model_paths[-1].startswith('checkpoint-'):
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
+ else:
+ self.model_name = model_paths[-1]
+ else:
+ self.model_name = model_name
+
+ logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...")
+
+ if not no_register:
+ self.register_to_controller()
+ self.heart_beat_thread = threading.Thread(
+ target=heart_beat_worker, args=(self,), daemon=True)
+ self.heart_beat_thread.start()
+
+ def register_to_controller(self):
+ logger.info("Register to controller")
+
+ url = self.controller_addr + "/register_worker"
+ data = {
+ "worker_name": self.worker_addr,
+ "check_heart_beat": True,
+ "worker_status": self.get_status()
+ }
+ r = requests.post(url, json=data)
+ assert r.status_code == 200
+
+ def send_heart_beat(self):
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. "
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
+ f"global_counter: {global_counter}")
+
+ url = self.controller_addr + "/receive_heart_beat"
+
+ while True:
+ try:
+ ret = requests.post(url, json={
+ "worker_name": self.worker_addr,
+ "queue_length": self.get_queue_length()}, timeout=5)
+ exist = ret.json()["exist"]
+ break
+ except requests.exceptions.RequestException as e:
+ logger.error(f"heart beat error: {e}")
+ time.sleep(5)
+
+ if not exist:
+ self.register_to_controller()
+
+ def get_queue_length(self):
+ if model_semaphore is None:
+ return 0
+ else:
+ return args.limit_model_concurrency - model_semaphore._value + (len(
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
+
+ def get_status(self):
+ return {
+ "model_names": [self.model_name],
+ "speed": 1,
+ "queue_length": self.get_queue_length(),
+ }
+
+ async def generate_stream(self, params):
+ ori_prompt = prompt = params["prompt"]
+ images = params.get("images", None)
+ if images is not None and len(images) > 0:
+ if len(images) > 0:
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
+ raise ValueError("Number of images does not match number of tokens in prompt")
+
+ images = [load_image_from_base64(image) for image in images]
+
+ # FIXME: for image-start/end token
+ # replace_token = DEFAULT_IMAGE_TOKEN
+ # if getattr(self.model.config, 'mm_use_im_start_end', False):
+ # replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
+ prompt = prompt.replace(' ' + DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN)
+ prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN)
+ prompt = []
+ for i in range(len(prompt_split)):
+ prompt.append(prompt_split[i])
+ if i < len(images):
+ prompt.append(images[i])
+ else:
+ prompt = [prompt]
+
+ temperature = float(params.get("temperature", 1.0))
+ top_p = float(params.get("top_p", 1.0))
+ # max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
+ stop_str = params.get("stop", None)
+ stop_str = [stop_str] if stop_str is not None else None
+
+ print({'prompt': prompt, 'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p})
+ state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True)
+
+ generated_text = ori_prompt
+ async for text_outputs in state.text_async_iter(var_name="response"):
+ generated_text += text_outputs
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
+
+ async def generate_stream_gate(self, params):
+ try:
+ async for x in self.generate_stream(params):
+ yield x
+ except ValueError as e:
+ print("Caught ValueError:", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+ except Exception as e:
+ print("Caught Unknown Error", e)
+ ret = {
+ "text": server_error_msg,
+ "error_code": 1,
+ }
+ yield json.dumps(ret).encode() + b"\0"
+
+
+app = FastAPI()
+
+
+def release_model_semaphore(fn=None):
+ model_semaphore.release()
+ if fn is not None:
+ fn()
+
+
+@app.post("/worker_generate_stream")
+async def generate_stream(request: Request):
+ global model_semaphore, global_counter
+ global_counter += 1
+ params = await request.json()
+
+ if model_semaphore is None:
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
+ await model_semaphore.acquire()
+ worker.send_heart_beat()
+ generator = worker.generate_stream_gate(params)
+ background_tasks = BackgroundTasks()
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
+ return StreamingResponse(generator, background=background_tasks)
+
+
+@app.post("/worker_get_status")
+async def get_status(request: Request):
+ return worker.get_status()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=21002)
+ parser.add_argument("--worker-address", type=str,
+ default="http://localhost:21002")
+ parser.add_argument("--controller-address", type=str,
+ default="http://localhost:21001")
+ parser.add_argument("--model-name", type=str)
+ parser.add_argument("--sgl-endpoint", type=str)
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
+ parser.add_argument("--stream-interval", type=int, default=1)
+ parser.add_argument("--no-register", action="store_true")
+ args = parser.parse_args()
+ logger.info(f"args: {args}")
+
+ worker = ModelWorker(args.controller_address,
+ args.worker_address,
+ args.sgl_endpoint,
+ worker_id,
+ args.no_register,
+ args.model_name)
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/llava/serve/test_message.py b/llava/serve/test_message.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b090faed0e630b03b2294545050f1f4f5032cad
--- /dev/null
+++ b/llava/serve/test_message.py
@@ -0,0 +1,62 @@
+import argparse
+import json
+
+import requests
+
+from llava.conversation import default_conversation
+
+
+def main():
+ if args.worker_address:
+ worker_addr = args.worker_address
+ else:
+ controller_addr = args.controller_address
+ ret = requests.post(controller_addr + "/refresh_all_workers")
+ ret = requests.post(controller_addr + "/list_models")
+ models = ret.json()["models"]
+ models.sort()
+ print(f"Models: {models}")
+
+ ret = requests.post(controller_addr + "/get_worker_address",
+ json={"model": args.model_name})
+ worker_addr = ret.json()["address"]
+ print(f"worker_addr: {worker_addr}")
+
+ if worker_addr == "":
+ return
+
+ conv = default_conversation.copy()
+ conv.append_message(conv.roles[0], args.message)
+ prompt = conv.get_prompt()
+
+ headers = {"User-Agent": "LLaVA Client"}
+ pload = {
+ "model": args.model_name,
+ "prompt": prompt,
+ "max_new_tokens": args.max_new_tokens,
+ "temperature": 0.7,
+ "stop": conv.sep,
+ }
+ response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
+ json=pload, stream=True)
+
+ print(prompt.replace(conv.sep, "\n"), end="")
+ for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
+ if chunk:
+ data = json.loads(chunk.decode("utf-8"))
+ output = data["text"].split(conv.sep)[-1]
+ print(output, end="\r")
+ print("")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
+ parser.add_argument("--worker-address", type=str)
+ parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
+ parser.add_argument("--max-new-tokens", type=int, default=32)
+ parser.add_argument("--message", type=str, default=
+ "Tell me a story with more than 1000 words.")
+ args = parser.parse_args()
+
+ main()
diff --git a/llava/train/__pycache__/llava_trainer.cpython-310.pyc b/llava/train/__pycache__/llava_trainer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c17d9270c4076bfbf8f9e8da89eeb2f95948be0e
--- /dev/null
+++ b/llava/train/__pycache__/llava_trainer.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:912e362a2d9e522479398f4f72dc57638ba9d88feca93ca44cf64637244fc737
+size 13308
diff --git a/llava/train/__pycache__/train.cpython-310.pyc b/llava/train/__pycache__/train.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c59e6b889ea934373d87212d4827e63b60f9ba33
--- /dev/null
+++ b/llava/train/__pycache__/train.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a9c02f931ce23d50de939b911b28968ab10d14e3118270ac474b106c6b509d1
+size 30209
diff --git a/llava/train/llama_flash_attn_monkey_patch.py b/llava/train/llama_flash_attn_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..31db2eff8d1c4b3ae645583dfc5e156e818b6f1c
--- /dev/null
+++ b/llava/train/llama_flash_attn_monkey_patch.py
@@ -0,0 +1,115 @@
+from typing import Optional, Tuple
+import warnings
+
+import torch
+
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+try:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+except ImportError:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ warnings.warn(
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ ) # shape: (b, num_heads, s, head_dim)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+
+ if past_key_value is not None:
+ # reuse k, v
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # Transform the data into the format required by flash attention
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
+ qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
+ key_padding_mask = attention_mask
+
+ if key_padding_mask is None:
+ qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
+ cu_q_lens = torch.arange(
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+ )
+ max_s = q_len
+ output = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output = output.view(bsz, q_len, -1)
+ else:
+ qkv = qkv.reshape(bsz, q_len, -1)
+ qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
+ output = pad_input(output_unpad, indices, bsz, q_len)
+
+ return self.o_proj(output), None, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+ # [bsz, seq_len]
+ return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
+ if cuda_major < 8:
+ warnings.warn(
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
+ )
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+ _prepare_decoder_attention_mask
+ )
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
diff --git a/llava/train/llama_xformers_attn_monkey_patch.py b/llava/train/llama_xformers_attn_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8351e41ccd4a64dca237bd8f8be0702b23989dc
--- /dev/null
+++ b/llava/train/llama_xformers_attn_monkey_patch.py
@@ -0,0 +1,129 @@
+"""
+Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
+"""
+
+import logging
+import math
+from typing import Optional, Tuple
+
+import torch
+import transformers.models.llama.modeling_llama
+from torch import nn
+
+try:
+ import xformers.ops
+except ImportError:
+ logging.error("xformers not found! Please install it before trying to use it.")
+
+
+def replace_llama_attn_with_xformers_attn():
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
+
+
+def xformers_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # pylint: disable=duplicate-code
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ (
+ query_states,
+ key_states,
+ ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # We only apply xformers optimizations if we don't need to output the whole attention matrix
+ if not output_attentions:
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+ # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+ if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
+ attn_output = xformers.ops.memory_efficient_attention(
+ query_states, key_states, value_states, attn_bias=None
+ )
+ else:
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
+ attn_output = xformers.ops.memory_efficient_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_bias=xformers.ops.LowerTriangularMask(),
+ )
+ attn_weights = None
+ else:
+ attn_weights = torch.matmul(
+ query_states, key_states.transpose(2, 3)
+ ) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+ )
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights, past_key_value
diff --git a/llava/train/llava_trainer.py b/llava/train/llava_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b5c0ccbe407107d26f49241d6cbb11fbd28aefe
--- /dev/null
+++ b/llava/train/llava_trainer.py
@@ -0,0 +1,314 @@
+import os
+import torch
+import torch.nn as nn
+
+from torch.utils.data import Sampler
+
+from transformers import Trainer
+from transformers.trainer import (
+ is_sagemaker_mp_enabled,
+ get_parameter_names,
+ has_length,
+ ALL_LAYERNORM_LAYERS,
+ logger,
+)
+from typing import List, Optional
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+ for k, t in maybe_lora_bias:
+ if bias_name in lora_bias_names:
+ to_return[bias_name] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
+ return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
+ if require_grad_only:
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ print(name, 'no ignore status')
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+
+
+def split_to_even_chunks(indices, lengths, num_chunks):
+ """
+ Split a list of indices into `chunks` chunks of roughly equal lengths.
+ """
+
+ if len(indices) % num_chunks != 0:
+ return [indices[i::num_chunks] for i in range(num_chunks)]
+
+ num_indices_per_chunk = len(indices) // num_chunks
+
+ chunks = [[] for _ in range(num_chunks)]
+ chunks_lengths = [0 for _ in range(num_chunks)]
+ for index in indices:
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
+ chunks[shortest_chunk].append(index)
+ chunks_lengths[shortest_chunk] += lengths[index]
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
+ chunks_lengths[shortest_chunk] = float("inf")
+
+ return chunks
+
+
+def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ assert all(l != 0 for l in lengths), "Should not have zero length."
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
+ # all samples are in the same modality
+ return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
+
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
+ megabatch_size = world_size * batch_size
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
+
+ last_mm = mm_megabatches[-1]
+ last_lang = lang_megabatches[-1]
+ additional_batch = last_mm + last_lang
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
+ megabatches = [megabatches[i] for i in megabatch_indices]
+
+ if len(additional_batch) > 0:
+ megabatches.append(sorted(additional_batch))
+
+ return [i for megabatch in megabatches for i in megabatch]
+
+
+def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ indices = torch.randperm(len(lengths), generator=generator)
+ megabatch_size = world_size * batch_size
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
+
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
+
+
+class LengthGroupedSampler(Sampler):
+ r"""
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
+ keeping a bit of randomness.
+ """
+
+ def __init__(
+ self,
+ batch_size: int,
+ world_size: int,
+ lengths: Optional[List[int]] = None,
+ generator=None,
+ group_by_modality: bool = False,
+ ):
+ if lengths is None:
+ raise ValueError("Lengths must be provided.")
+
+ self.batch_size = batch_size
+ self.world_size = world_size
+ self.lengths = lengths
+ self.generator = generator
+ self.group_by_modality = group_by_modality
+
+ def __len__(self):
+ return len(self.lengths)
+
+ def __iter__(self):
+ if self.group_by_modality:
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ else:
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ return iter(indices)
+
+
+class LLaVATrainer(Trainer):
+
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+
+ if self.args.group_by_modality_length:
+ lengths = self.train_dataset.modality_lengths
+ return LengthGroupedSampler(
+ self.args.train_batch_size,
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps,
+ lengths=lengths,
+ group_by_modality=True,
+ )
+ else:
+ return super()._get_train_sampler()
+
+ def create_optimizer(self):
+ """
+ Setup the optimizer.
+
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+ """
+ if is_sagemaker_mp_enabled():
+ return super().create_optimizer()
+
+ opt_model = self.model
+
+ if self.optimizer is None:
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
+ if self.args.mm_projector_lr is not None:
+ projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ "lr": self.args.mm_projector_lr,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ "lr": self.args.mm_projector_lr,
+ },
+ ]
+ else:
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+ if optimizer_cls.__name__ == "Adam8bit":
+ import bitsandbytes
+
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+ skipped = 0
+ for module in opt_model.modules():
+ if isinstance(module, nn.Embedding):
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+ logger.info(f"skipped: {skipped/2**20}M params")
+
+ return self.optimizer
+
+
+
+ def _save_checkpoint(self, model, trial, metrics=None):
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+ run_dir = self._get_output_dir(trial=trial)
+ output_dir = os.path.join(run_dir, checkpoint_folder)
+
+ # Only save Adapter
+ keys_to_match = ['mm_projector', 'vision_resampler', 'Qformer', 'query_tokens', 'de_noise', 'vision_tower.trans2embed', 'vision_tower.visual.fc_norm.']
+ if getattr(self.args, "use_im_start_end", False):
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
+
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
+
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
+ self.model.config.save_pretrained(output_dir)
+ torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
+ else:
+ if self.args.lora_enable :
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+ run_dir = self._get_output_dir(trial=trial)
+ output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
+ os.makedirs(output_dir, exist_ok=True)
+ state_dict = get_peft_state_maybe_zero_3(
+ self.model.named_parameters(), self.args.lora_bias
+ )
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
+ self.model.named_parameters()
+ )
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
+ print(f"save models to {output_dir} ")
+ self.model.config.save_pretrained(output_dir)
+ # self.model.save_pretrained(output_dir, state_dict=state_dict)
+ super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
+ torch.save(non_lora_state_dict, os.path.join(output_dir, 'non_lora_trainables.bin'))
+
+ else:
+ super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
+ pass
+ else:
+ super(LLaVATrainer, self)._save(output_dir, state_dict)
diff --git a/llava/train/train.py b/llava/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..754109ec5ebe3be9fb307ef999ae224fbe58e693
--- /dev/null
+++ b/llava/train/train.py
@@ -0,0 +1,1196 @@
+# Modified from LLaVA: https://github.com/haotian-liu/LLaVA.git
+#
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import copy
+from dataclasses import dataclass, field
+import json
+import logging
+import pathlib
+from typing import Dict, Optional, Sequence, List
+import numpy as np
+import torch
+
+import transformers
+import tokenizers
+
+from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_PC_TOKEN, DEFAULT_IM_END_TOKEN
+from torch.utils.data import Dataset
+
+from llava.model.utils import SimpleTokenizer
+from llava.train.llava_trainer import LLaVATrainer
+
+from llava import conversation as conversation_lib
+from llava.model import *
+from llava.mm_utils import tokenizer_image_token
+
+import random
+local_rank = None
+
+
+def pc_norm(pc):
+ """ pc: NxC, return NxC """
+ xyz = pc[:, :3]
+ other_feature = pc[:, 3:]
+
+ centroid = np.mean(xyz, axis=0)
+ xyz = xyz - centroid
+ m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
+ xyz = xyz / m
+
+ pc = np.concatenate((xyz, other_feature), axis=1)
+ return pc
+
+def read_pc_2tesnor(object_id):
+ data_path = './dataset/Objaverse/8192_npy'
+
+ filename = f"{object_id}_8192.npy"
+ point_cloud = np.load(os.path.join(data_path, filename))
+
+ point_cloud = pc_norm(point_cloud)
+ point_cloud = torch.from_numpy(point_cloud.astype(np.float32))
+ return point_cloud
+
+
+def rank0_print(*args):
+ if local_rank == 0:
+ print(*args)
+
+
+from packaging import version
+
+IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+ version: Optional[str] = field(default="v0")
+ freeze_backbone: bool = field(default=False)
+ tune_mm_mlp_adapter: bool = field(default=False)
+ vision_tower: Optional[str] = field(default=True)
+ mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
+ pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
+ mm_projector_type: Optional[str] = field(default='linear')
+ mm_use_im_start_end: bool = field(default=False)
+ mm_use_im_patch_token: bool = field(default=True)
+ mm_patch_merge_type: Optional[str] = field(default='flat')
+ mm_vision_select_feature: Optional[str] = field(default="patch")
+ encoder_type: Optional[str] = field(default="text_encoder") # text_encoder, pc encoder
+ std: Optional[float] = field(default=0.0) # Add Gaussian noise to the feature oftext encoder
+
+
+
+ # for pc encoder
+ # "CLS": inference of stage 1, 2
+ # "OM_Pooling": training and inference of stage 3
+ get_pc_tokens_way: Optional[str] = field(default=1)
+
+
+ pc_feat_dim: Optional[int] = field(default=192)
+ embed_dim: Optional[int] = field(default=1024)
+ group_size: Optional[int] = field(default=64)
+ num_group: Optional[int] = field(default=512)
+ pc_encoder_dim: Optional[int] = field(default=512)
+ patch_dropout: Optional[float] = field(default=0.0)
+ pc_ckpt_path: Optional[str] = field(
+ default="./pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-s/model.pt")
+ pc_encoder_type: Optional[str] = field(default="small")
+
+
+@dataclass
+class DataArguments:
+ data_path: str = field(default=None,
+ metadata={"help": "Path to the training data."})
+ lazy_preprocess: bool = False
+ is_multimodal: bool = False
+
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ remove_unused_columns: bool = field(default=False)
+ freeze_mm_mlp_adapter: bool = field(default=False)
+ mpt_attn_impl: Optional[str] = field(default="triton")
+ model_max_length: int = field(
+ default=512,
+ metadata={
+ "help":
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ },
+ )
+ double_quant: bool = field(
+ default=True,
+ metadata={"help": "Compress the quantization statistics through double quantization."}
+ )
+ quant_type: str = field(
+ default="nf4",
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+ )
+ bits: int = field(
+ default=16,
+ metadata={"help": "How many bits to use."}
+ )
+ lora_enable: bool = False
+ lora_r: int = 64
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+ mm_projector_lr: Optional[float] = None
+ group_by_modality_length: bool = field(default=False)
+ lora_path: str = field(default=None, metadata={"help": "Path to the previous lora folder."})
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+ for k, t in maybe_lora_bias:
+ if bias_name in lora_bias_names:
+ to_return[bias_name] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
+ return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
+ if require_grad_only:
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def find_all_linear_names(model):
+ cls = torch.nn.Linear
+ lora_module_names = set()
+ multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
+ for name, module in model.named_modules():
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+ continue
+ if isinstance(module, cls):
+ names = name.split('.')
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+ if 'lm_head' in lora_module_names: # needed for 16-bit
+ lora_module_names.remove('lm_head')
+ return list(lora_module_names)
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
+ output_dir: str):
+ """Collects the state dict and dump to disk."""
+
+ if getattr(trainer.args, "tune_mm_mlp_adapter", False):
+ # Only save Adapter
+ keys_to_match = ['mm_projector', 'Qformer', 'query_tokens', 'de_noise']
+ if getattr(trainer.args, "use_im_start_end", False):
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
+
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
+ trainer.model.config.save_pretrained(output_dir)
+
+ current_folder = output_dir.split('/')[-1]
+ parent_folder = os.path.dirname(output_dir)
+ if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
+ if current_folder.startswith('checkpoint-'):
+ mm_projector_folder = os.path.join(parent_folder, "mm_projector")
+ os.makedirs(mm_projector_folder, exist_ok=True)
+ torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
+ else:
+ torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
+ return
+
+ if trainer.deepspeed:
+ torch.cuda.synchronize()
+ trainer.save_model(output_dir)
+ return
+
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {
+ key: value.cpu()
+ for key, value in state_dict.items()
+ }
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+def _tokenize_fn(strings: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ) for text in strings
+ ]
+ input_ids = labels = [
+ tokenized.input_ids[0] for tokenized in tokenized_list
+ ]
+ input_ids_lens = labels_lens = [
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
+ for tokenized in tokenized_list
+ ]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+
+def _mask_targets(target, tokenized_lens, speakers):
+ # cur_idx = 0
+ cur_idx = tokenized_lens[0]
+ tokenized_lens = tokenized_lens[1:]
+ target[:cur_idx] = IGNORE_INDEX
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
+ if speaker == "human":
+ target[cur_idx + 2:cur_idx + tokenized_len] = IGNORE_INDEX
+ cur_idx += tokenized_len
+
+
+def _add_speaker_and_signal(header, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "### "
+ END_SIGNAL = "\n"
+ conversation = header
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() == "human":
+ from_str = conversation_lib.default_conversation.roles[0]
+ elif from_str.lower() == "gpt":
+ from_str = conversation_lib.default_conversation.roles[1]
+ else:
+ from_str = 'unknown'
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
+ sentence["value"] + END_SIGNAL)
+ if get_conversation:
+ conversation += sentence["value"]
+ conversation += BEGIN_SIGNAL
+ return conversation
+
+
+def preprocess_multimodal(
+ sources: Sequence[str],
+ data_args: DataArguments
+) -> Dict:
+ is_multimodal = data_args.is_multimodal
+ if not is_multimodal:
+ return sources
+
+ for source in sources:
+ for sentence in source:
+ if DEFAULT_PC_TOKEN in sentence['value']:
+ # Adapt Llava code
+ sentence['value'] = sentence['value'].replace(DEFAULT_PC_TOKEN, DEFAULT_IMAGE_TOKEN)
+ # Adapt Llava code
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
+ sentence['value'] = sentence['value'].strip()
+ if "mmtag" in conversation_lib.default_conversation.version:
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN,
+ '' + DEFAULT_IMAGE_TOKEN + '')
+ replace_token = DEFAULT_IMAGE_TOKEN
+ if data_args.mm_use_im_start_end:
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ return sources
+
+
+def preprocess_llama_2(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack(
+ [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
+
+ # Mask targets
+ sep = "[/INST] "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_v1(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack(
+ [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1] + ": "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
+ round_len -= 1
+ instruction_len -= 1
+
+ target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_mpt(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack(
+ [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1]
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep)
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
+ for conv_idx in range(3, len(rounds), 2):
+ re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx + 2])) # user + gpt
+ cur_len = 0
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(re_rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
+
+ if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+ round_len += 1
+ instruction_len += 1
+
+ target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_phi3(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack(
+ [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1]
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep)
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
+ for conv_idx in range(3, len(rounds), 2):
+ re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx + 2])) # user + gpt
+ cur_len = 0
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(re_rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
+
+ if i == 0:
+ round_len += 1
+ instruction_len += 1
+ else:
+ round_len -= 2
+ instruction_len -= 2
+
+ if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+ round_len += 1
+ instruction_len += 1
+
+ target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_plain(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ assert len(source) == 2
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN
+ conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
+ conversations.append(conversation)
+ # tokenize conversations
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
+ target[:tokenized_len] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+def preprocess(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ """
+ Given a list of sources, each is a conversation list. This transform:
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
+ 2. Concatenate conversations together;
+ 3. Tokenize the concatenated conversation;
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
+ """
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
+ return preprocess_plain(sources, tokenizer)
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
+ return preprocess_llama_2(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version.startswith("v1"):
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version == "mpt":
+ return preprocess_mpt(sources, tokenizer, has_image=has_image)
+ if conversation_lib.default_conversation.version == "phi3":
+ return preprocess_phi3(sources, tokenizer, has_image=has_image)
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ header = f"{conversation_lib.default_conversation.system}\n\n"
+ conversation = _add_speaker_and_signal(header, source)
+ conversations.append(conversation)
+
+ # tokenize conversations
+ def get_tokenize_len(prompts):
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
+
+ if has_image:
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
+ else:
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
+ input_ids = conversations_tokenized["input_ids"]
+
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ if has_image:
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
+ else:
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
+ speakers = [sentence["from"] for sentence in source]
+ _mask_targets(target, tokenized_lens, speakers)
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+class LazySupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments,
+ model_args: dict):
+ super(LazySupervisedDataset, self).__init__()
+ list_data_dict = json.load(open(data_path, "r"))
+ random.shuffle(list_data_dict)
+
+ rank0_print("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.list_data_dict = list_data_dict
+ self.data_args = data_args
+ self.model_args = model_args
+ self.clip_tokenizer = SimpleTokenizer()
+
+ def __len__(self):
+ return len(self.list_data_dict)
+
+ @property
+ def lengths(self): # no need
+ length_list = []
+ for sample in self.list_data_dict:
+ if self.model_args.encoder_type == 'text_encoder':
+ img_tokens = 77 # need check
+ elif self.model_args.encoder_type == 'pc_encoder':
+ img_tokens = 513
+ # img_tokens = 128 if 'image' in sample else 0
+ length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
+ return length_list
+
+ @property
+ def modality_lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
+ cur_len = cur_len if 'image' in sample else -cur_len
+ length_list.append(cur_len)
+ return length_list
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ sources = self.list_data_dict[i]
+ if isinstance(i, int):
+ sources = [sources]
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+
+ if self.model_args.encoder_type == 'text_encoder':
+ if 'caption' in sources[0]:
+ image = self.list_data_dict[i]['caption']
+
+ image = self.clip_tokenizer(image).to(non_blocking=True)
+
+ elif self.model_args.encoder_type == 'pc_encoder':
+
+ if 'object_id' in sources[0]:
+ object_id = self.list_data_dict[i]['object_id']
+ image = read_pc_2tesnor(object_id)
+ # image = image.to(dtype=torch.bfloat16)
+
+
+ sources = preprocess_multimodal(
+ copy.deepcopy([e["conversations"] for e in sources]),
+ self.data_args)
+
+
+ data_dict = preprocess(
+ sources,
+ self.tokenizer,
+ has_image=(True))
+
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
+ labels=data_dict["labels"][0])
+
+ data_dict['image'] = image
+
+ return data_dict
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+ """Collate examples for supervised fine-tuning."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple([instance[key] for instance in instances]
+ for key in ("input_ids", "labels"))
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id)
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
+ batch_first=True,
+ padding_value=IGNORE_INDEX)
+ input_ids = input_ids[:, :self.tokenizer.model_max_length]
+ labels = labels[:, :self.tokenizer.model_max_length]
+ batch = dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
+
+ if 'image' in instances[0]:
+ images = [instance['image'] for instance in instances]
+
+ if all(x is not None and x.shape == images[0].shape for x in images):
+ batch['images'] = torch.stack(images)
+ else:
+ batch['images'] = images
+
+ return batch
+
+
+def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
+ data_args, model_args) -> Dict:
+ """Make dataset and collator for supervised fine-tuning."""
+ train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
+ data_path=data_args.data_path,
+ data_args=data_args,
+ model_args=model_args)
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+ return dict(train_dataset=train_dataset,
+ eval_dataset=None,
+ data_collator=data_collator)
+
+
+def train(attn_implementation=None):
+ global local_rank
+
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ local_rank = training_args.local_rank
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+
+ bnb_model_from_pretrained_args = {}
+ if training_args.bits in [4, 8]:
+ from transformers import BitsAndBytesConfig
+ bnb_model_from_pretrained_args.update(dict(
+ device_map={"": training_args.device},
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ llm_int8_skip_modules=["mm_projector"],
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=compute_dtype,
+ bnb_4bit_use_double_quant=training_args.double_quant,
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
+ )
+ ))
+
+ if model_args.vision_tower is not None:
+ if 'mpt' in model_args.model_name_or_path:
+ config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
+ config.attn_config['attn_impl'] = training_args.mpt_attn_impl
+ model = LlavaMptForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args
+ )
+ else:
+ model = LlavaPhiForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ **bnb_model_from_pretrained_args
+ )
+ else:
+ model = transformers.LlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ attn_implementation=attn_implementation,
+ torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+ **bnb_model_from_pretrained_args
+ )
+ model.config.use_cache = False
+
+ if model_args.freeze_backbone:
+ model.model.requires_grad_(False)
+
+ if training_args.bits in [4, 8]:
+ from peft import prepare_model_for_kbit_training
+ model.config.torch_dtype = (
+ torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
+
+ if training_args.gradient_checkpointing:
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ if training_args.lora_enable:
+ from peft import LoraConfig, get_peft_model
+ lora_config = LoraConfig(
+ r=training_args.lora_r,
+ lora_alpha=training_args.lora_alpha,
+ target_modules=find_all_linear_names(model),
+ lora_dropout=training_args.lora_dropout,
+ bias=training_args.lora_bias,
+ task_type="CAUSAL_LM",
+ )
+ if training_args.bits == 16:
+ if training_args.bf16:
+ model.to(torch.bfloat16)
+ if training_args.fp16:
+ model.to(torch.float16)
+ rank0_print("Adding LoRA adapters...")
+ # model = get_peft_model(model, lora_config)
+ if training_args.lora_path:
+ from peft import PeftModel
+ model = PeftModel.from_pretrained(model, training_args.lora_path)
+ print("load lora weight ok")
+
+
+ for name, param in model.named_parameters():
+ if "lora" in name and not param.requires_grad:
+ param.requires_grad = True
+
+ else:
+ model = get_peft_model(model, lora_config) # Model is defined here!
+
+ if 'mpt' in model_args.model_name_or_path:
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right"
+ )
+ else:
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+
+ if model_args.version == "v0":
+ if tokenizer.pad_token is None:
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(pad_token="[PAD]"),
+ tokenizer=tokenizer,
+ model=model,
+ )
+ elif model_args.version == "v0.5":
+ tokenizer.pad_token = tokenizer.unk_token
+ else:
+ tokenizer.pad_token = tokenizer.unk_token
+ if model_args.version in conversation_lib.conv_templates:
+ conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
+ else:
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
+
+ if model_args.vision_tower is not None:
+ model.get_model().initialize_other_modules(
+ model_args=model_args,
+ fsdp=training_args.fsdp
+ )
+
+ # freeze vision encoder
+ vision_tower = model.get_model().vision_tower
+ vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
+ for p in model.get_model().vision_tower.parameters():
+ p.requires_grad = False
+
+ data_args.is_multimodal = True
+
+ model.config.tokenizer_padding_side = tokenizer.padding_side
+ model.config.tokenizer_model_max_length = tokenizer.model_max_length
+
+ model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
+ model.config.get_pc_tokens_way = training_args.get_pc_tokens_way = model_args.get_pc_tokens_way
+
+
+ if training_args.lora_enable == False:
+ model.requires_grad_(False)
+
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = True
+
+ if model_args.encoder_type == 'pc_encoder' and model_args.get_pc_tokens_way == "OM_Pooling":
+ for p in model.get_model().vision_tower.visual.norm.parameters():
+ p.requires_grad = True
+
+ for p in model.get_model().vision_tower.visual.fc_norm.parameters():
+ p.requires_grad = True
+
+ for p in model.get_model().vision_tower.trans2embed.parameters():
+ p.requires_grad = True
+
+
+ model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
+ if training_args.freeze_mm_mlp_adapter:
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = False
+
+ if training_args.bits in [4, 8]:
+ model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
+
+ model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
+ model.config.mm_projector_lr = training_args.mm_projector_lr
+ training_args.use_im_start_end = model_args.mm_use_im_start_end
+ model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
+ model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
+
+
+ if training_args.bits in [4, 8]:
+ from peft.tuners.lora import LoraLayer
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ if training_args.bf16:
+ module = module.to(torch.bfloat16)
+ if 'norm' in name:
+ module = module.to(torch.float32)
+ if 'lm_head' in name or 'embed_tokens' in name:
+ if hasattr(module, 'weight'):
+ if training_args.bf16 and module.weight.dtype == torch.float32:
+ module = module.to(torch.bfloat16)
+
+ data_module = make_supervised_data_module(tokenizer=tokenizer,
+ data_args=data_args, model_args=model_args)
+ #####
+ for name, param in model.named_parameters():
+ if param.requires_grad: # ๅฆๆๅๆฐ้่ฆๆขฏๅบฆ๏ผ้ฃไนๅฎๅฐ่ขซๆดๆฐ
+ print(f"Parameter {name} will be updated.")
+
+ print()
+ num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print(f"Total Number of trainable parameters: {num_trainable_params}")
+
+ model_params = sum(p.numel() for p in model.parameters())
+ print(f"Total Number of model parameters: {model_params}")
+ #######
+
+
+
+ trainer = LLaVATrainer(model=model,
+ tokenizer=tokenizer,
+ args=training_args,
+ **data_module)
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+
+ trainer.save_state()
+
+ model.config.use_cache = True
+
+ # stage 2 or stage 3
+ if training_args.lora_enable:
+ state_dict = get_peft_state_maybe_zero_3(
+ model.named_parameters(), training_args.lora_bias
+ )
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
+ model.named_parameters()
+ )
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
+ model.config.save_pretrained(training_args.output_dir)
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
+
+ else:
+ safe_save_model_for_hf_trainer(trainer=trainer,
+ output_dir=training_args.output_dir)
+
+
+
+
+if __name__ == "__main__":
+ train()
+
diff --git a/llava/train/train_mem.py b/llava/train/train_mem.py
new file mode 100644
index 0000000000000000000000000000000000000000..a554f69d6461901be6c5fbb8e61fd3ec4f2a94bb
--- /dev/null
+++ b/llava/train/train_mem.py
@@ -0,0 +1,5 @@
+from llava.train.train import train
+
+if __name__ == "__main__":
+ train(attn_implementation="flash_attention_2")
+ print('train end')
diff --git a/llava/train/train_xformers.py b/llava/train/train_xformers.py
new file mode 100644
index 0000000000000000000000000000000000000000..23a59bf4ee0f365de9fbf3838836b170058126d6
--- /dev/null
+++ b/llava/train/train_xformers.py
@@ -0,0 +1,13 @@
+# Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
+
+# Need to call this before importing transformers.
+from llava.train.llama_xformers_attn_monkey_patch import (
+ replace_llama_attn_with_xformers_attn,
+)
+
+replace_llama_attn_with_xformers_attn()
+
+from llava.train.train import train
+
+if __name__ == "__main__":
+ train()
diff --git a/llava/utils.py b/llava/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..925a3f4e82ae12c9e3fcd39e3f68cf8c853e8c5a
--- /dev/null
+++ b/llava/utils.py
@@ -0,0 +1,281 @@
+import datetime
+import logging
+import logging.handlers
+import os
+import sys
+
+import requests
+
+from llava.constants import LOGDIR
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+handler = None
+
+
+def build_logger(logger_name, logger_filename):
+ global handler
+
+ formatter = logging.Formatter(
+ fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ # Set the format of root handlers
+ if not logging.getLogger().handlers:
+ logging.basicConfig(level=logging.INFO)
+ logging.getLogger().handlers[0].setFormatter(formatter)
+
+ # Redirect stdout and stderr to loggers
+ stdout_logger = logging.getLogger("stdout")
+ stdout_logger.setLevel(logging.INFO)
+ sl = StreamToLogger(stdout_logger, logging.INFO)
+ sys.stdout = sl
+
+ stderr_logger = logging.getLogger("stderr")
+ stderr_logger.setLevel(logging.ERROR)
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
+ sys.stderr = sl
+
+ # Get logger
+ logger = logging.getLogger(logger_name)
+ logger.setLevel(logging.INFO)
+
+ # Add a file handler for all loggers
+ if handler is None:
+ os.makedirs(LOGDIR, exist_ok=True)
+ filename = os.path.join(LOGDIR, logger_filename)
+ handler = logging.handlers.TimedRotatingFileHandler(
+ filename, when='D', utc=True, encoding='UTF-8')
+ handler.setFormatter(formatter)
+
+ for name, item in logging.root.manager.loggerDict.items():
+ if isinstance(item, logging.Logger):
+ item.addHandler(handler)
+
+ return logger
+
+
+class StreamToLogger(object):
+ """
+ Fake file-like stream object that redirects writes to a logger instance.
+ """
+ def __init__(self, logger, log_level=logging.INFO):
+ self.terminal = sys.stdout
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ''
+
+ def __getattr__(self, attr):
+ return getattr(self.terminal, attr)
+
+ def write(self, buf):
+ temp_linebuf = self.linebuf + buf
+ self.linebuf = ''
+ for line in temp_linebuf.splitlines(True):
+ # From the io.TextIOWrapper docs:
+ # On output, if newline is None, any '\n' characters written
+ # are translated to the system default line separator.
+ # By default sys.stdout.write() expects '\n' newlines and then
+ # translates them so this is still cross platform.
+ if line[-1] == '\n':
+ self.logger.log(self.log_level, line.rstrip())
+ else:
+ self.linebuf += line
+
+ def flush(self):
+ if self.linebuf != '':
+ self.logger.log(self.log_level, self.linebuf.rstrip())
+ self.linebuf = ''
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def violates_moderation(text):
+ """
+ Check whether the text violates OpenAI moderation API.
+ """
+ url = "https://api.openai.com/v1/moderations"
+ headers = {"Content-Type": "application/json",
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+ text = text.replace("\n", "")
+ data = "{" + '"input": ' + f'"{text}"' + "}"
+ data = data.encode("utf-8")
+ try:
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
+ flagged = ret.json()["results"][0]["flagged"]
+ except requests.exceptions.RequestException as e:
+ flagged = False
+ except KeyError as e:
+ flagged = False
+
+ return flagged
+
+
+def pretty_print_semaphore(semaphore):
+ if semaphore is None:
+ return "None"
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
+
+
+
+
+# Modified from github.com/openai/CLIP
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+import torch
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("ยก"), ord("ยฌ")+1))+list(range(ord("ยฎ"), ord("รฟ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v+'' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + ( token[-1] + '',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token+''
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ')
+ return text
+
+ def __call__(self, texts, context_length=77):
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = self.encoder["<|startoftext|>"]
+ eot_token = self.encoder["<|endoftext|>"]
+ all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ tokens = tokens[:context_length]
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ if len(result) == 1:
+ return result[0]
+ return result
\ No newline at end of file
diff --git a/media/3_stage.png b/media/3_stage.png
new file mode 100644
index 0000000000000000000000000000000000000000..fcd0463c00a80812c1d3515a2c8e044ae1c8db49
--- /dev/null
+++ b/media/3_stage.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d323f2c34d0106d8b715ed25b1311288b1ab51acf8d3b0579b00e801a2d17cc4
+size 661884
diff --git a/media/T3D.png b/media/T3D.png
new file mode 100644
index 0000000000000000000000000000000000000000..2f9b089dfbec077c22469c10ab312e76372fc5dc
--- /dev/null
+++ b/media/T3D.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e69ff9e39350a1c93a1164ed292b0b8f7960fa06d85609d7549459423617e994
+size 631628
diff --git a/pointllm/__init__.py b/pointllm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e43701abfd68f05cd3bf1a85117b96c4ecc58299
--- /dev/null
+++ b/pointllm/__init__.py
@@ -0,0 +1 @@
+# from .model import PointLLMLlamaForCausalLM
diff --git a/pointllm/__pycache__/__init__.cpython-310.pyc b/pointllm/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b4c72fcce9338e08856de7ff5c5c5cd952498ac4
--- /dev/null
+++ b/pointllm/__pycache__/__init__.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a3d96ef565a07125c5bdf62b971b9ee2c2c42c3061ea13e6fdbedb0e028ea3b
+size 145
diff --git a/pointllm/__pycache__/conversation.cpython-310.pyc b/pointllm/__pycache__/conversation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e24895a113703a22cf76b71d6f40b94acd388f5
--- /dev/null
+++ b/pointllm/__pycache__/conversation.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:873eb64a3c976012b8937d829394c0c34be004d23366630ef1ed2b618c491cd9
+size 10941
diff --git a/pointllm/__pycache__/utils.cpython-310.pyc b/pointllm/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ee62d606853021d17e8a23a0aeb173d10a8f5d2a
--- /dev/null
+++ b/pointllm/__pycache__/utils.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aeeae6c9da03772d47df2d0fc687c893abedbefe3fdfaa1e7c9791be5f02eb32
+size 4720
diff --git a/pointllm/conversation.py b/pointllm/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..80dcea89dec62a209d39d50d293959f4fdabb2b8
--- /dev/null
+++ b/pointllm/conversation.py
@@ -0,0 +1,375 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+
+
+class SeparatorStyle(Enum):
+ """Different separator style."""
+ SINGLE = auto()
+ TWO = auto()
+ MPT = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ roles: List[str]
+ messages: List[List[str]]
+ offset: int
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+ sep: str = "###"
+ sep2: str = None
+ version: str = "Unknown"
+
+ skip_next: bool = False
+
+ def reset(self):
+ self.messages = self.messages[:self.offset]
+
+ def get_prompt(self):
+ if self.sep_style == SeparatorStyle.SINGLE:
+ ret = self.system + self.sep
+ for role, message in self.messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + self.sep
+ else:
+ ret += role + ":"
+ return ret
+ elif self.sep_style == SeparatorStyle.TWO:
+ seps = [self.sep, self.sep2]
+ ret = self.system + seps[0]
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + ": " + message + seps[i % 2]
+ else:
+ ret += role + ":"
+ return ret
+ if self.sep_style == SeparatorStyle.MPT:
+ ret = self.system + self.sep
+ for role, message in self.messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ return ret
+ else:
+ raise ValueError(f"Invalid style: {self.sep_style}")
+
+ def append_message(self, role, message):
+ self.messages.append([role, message])
+
+ def pop_last_none_message(self):
+ # * pop the last message if it's None, this is used for multi-round dialogue
+ if self.messages[-1][1] is None:
+ self.messages.pop()
+
+ def get_images(self, return_pil=False):
+ images = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+ from PIL import Image
+ msg, image, image_process_mode = msg
+ if image_process_mode == "Pad":
+ def expand2square(pil_img, background_color=(122, 116, 104)):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+ image = expand2square(image)
+ elif image_process_mode == "Crop":
+ pass
+ elif image_process_mode == "Resize":
+ image = image.resize((224, 224))
+ else:
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ if return_pil:
+ images.append(image)
+ else:
+ buffered = BytesIO()
+ image.save(buffered, format="JPEG")
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+ images.append(img_b64_str)
+ return images
+
+ def to_gradio_chatbot(self):
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
+ if i % 2 == 0:
+ if type(msg) is tuple:
+ import base64
+ from io import BytesIO
+ msg, image, image_process_mode = msg
+ max_hw, min_hw = max(image.size), min(image.size)
+ aspect_ratio = max_hw / min_hw
+ max_len, min_len = 800, 400
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+ longest_edge = int(shortest_edge * aspect_ratio)
+ W, H = image.size
+ if H > W:
+ H, W = longest_edge, shortest_edge
+ else:
+ H, W = shortest_edge, longest_edge
+ image = image.resize((W, H))
+ # image = image.resize((224, 224))
+ buffered = BytesIO()
+ image.save(buffered, format="JPEG")
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+ img_str = f'
'
+ msg = msg.replace('', img_str)
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2)
+
+ def dict(self):
+ if len(self.get_images()) > 0:
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+ return {
+ "system": self.system,
+ "roles": self.roles,
+ "messages": self.messages,
+ "offset": self.offset,
+ "sep": self.sep,
+ "sep2": self.sep2,
+ }
+
+
+conv_v1 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "Give three tips for staying healthy."),
+ ("Assistant",
+ "Sure, here are three tips for staying healthy:\n"
+ "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
+ "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
+ "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
+ "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
+ "activities at least two days per week.\n"
+ "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
+ "vegetables, whole grains, lean proteins, and healthy fats can help support "
+ "your overall health. Try to limit your intake of processed and high-sugar foods, "
+ "and aim to drink plenty of water throughout the day.\n"
+ "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
+ "and mental health. Adults should aim for seven to nine hours of sleep per night. "
+ "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
+ "help improve the quality of your sleep.")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_v1_2 = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+ ("Assistant",
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+ "renewable and non-renewable energy sources:\n"
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+ "energy sources are finite and will eventually run out.\n"
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+ "and other negative effects.\n"
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+ "have lower operational costs than non-renewable sources.\n"
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+ "locations than non-renewable sources.\n"
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_vicuna_v1_1 = Conversation(
+ system="A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+conv_mpt = Conversation(
+ system="""<|im_start|>system
+- You are a helpful language and vision assistant.
+- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+- You should follow the instructions carefully and explain your answers in detail.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_mpt_text = Conversation(
+ system="""<|im_start|>system
+- You are a helpful assistant chatbot trained by MosaicML.
+- You answer questions.
+- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+conv_bair_v1 = Conversation(
+ system="BEGINNING OF CONVERSATION:",
+ roles=("USER", "GPT"),
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+simple_conv = Conversation(
+ system="A chat between a curious human and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "Hi!"),
+ ("Assistant", "Hi there! How can I help you today?")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+simple_conv_multimodal = Conversation(
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "Follow the instructions carefully and explain your answers in detail.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "Hi!"),
+ ("Assistant", "Hi there! How can I help you today?\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+simple_conv_mpt_multimodal = Conversation(
+ system="""<|im_start|>system
+- You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
+- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+- You should follow the instructions carefully and explain your answers in detail.""",
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+ version="mpt",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.MPT,
+ sep="<|im_end|>",
+)
+
+simple_conv_legacy = Conversation(
+ system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
+ "You are designed to assist human with a variety of tasks using natural language."
+ "Follow the instructions carefully.",
+ roles=("Human", "Assistant"),
+ messages=(
+ ("Human", "Hi!\n\n### Response:"),
+ ("Assistant", "Hi there! How can I help you today?\n")
+ ),
+ offset=2,
+ sep_style=SeparatorStyle.SINGLE,
+ sep="###",
+)
+
+conv_llava_v1 = Conversation(
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+ "Follow the instructions carefully and explain your answers in detail.",
+ roles=("USER", "ASSISTANT"),
+ version="v1",
+ messages=(),
+ offset=0,
+ sep_style=SeparatorStyle.TWO,
+ sep=" ",
+ sep2="",
+)
+
+default_conversation = conv_v1_2
+conv_templates = {
+ "default": conv_v1_2,
+ "simple": simple_conv,
+ "simple_legacy": simple_conv_legacy,
+ "multimodal": simple_conv_multimodal,
+ "mpt_multimodal": simple_conv_mpt_multimodal,
+ "llava_v1": conv_llava_v1,
+
+ # fastchat
+ "v1": conv_v1_2,
+ "bair_v1": conv_bair_v1,
+ "vicuna_v1_1": conv_vicuna_v1_1,
+ "mpt": conv_mpt,
+ "mpt_text": conv_mpt_text,
+}
+
+if __name__ == "__main__":
+ print(default_conversation.get_prompt())
\ No newline at end of file
diff --git a/pointllm/data/__init__.py b/pointllm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aedb1579856d54f38f29f8903ac3ee651317e06e
--- /dev/null
+++ b/pointllm/data/__init__.py
@@ -0,0 +1,5 @@
+from .utils import load_objaverse_point_cloud, pc_norm, farthest_point_sample
+from .object_point_dataset import ObjectPointCloudDataset, make_object_point_data_module
+from .modelnet import ModelNet
+
+# from .scanobjectNN import ScanObjectNN
\ No newline at end of file
diff --git a/pointllm/data/__pycache__/__init__.cpython-310.pyc b/pointllm/data/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a4bcf7e2fe809681e8584ded6b538bab7c543e5
--- /dev/null
+++ b/pointllm/data/__pycache__/__init__.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5968aa7aa3df12d9f10e0729817a6da1cb2f47c9c5f43bc341d686165c2b813
+size 405
diff --git a/pointllm/data/__pycache__/modelnet.cpython-310.pyc b/pointllm/data/__pycache__/modelnet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83a5c86acb43299781d401f760b119c9439fbf7c
--- /dev/null
+++ b/pointllm/data/__pycache__/modelnet.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11535325eede16ad0700f46f828c06101e52fa250591cd74f2cbb6f7260e401f
+size 4508
diff --git a/pointllm/data/__pycache__/object_point_dataset.cpython-310.pyc b/pointllm/data/__pycache__/object_point_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6535d4c0f1487434fa8d672a5e542ebac986ae5d
--- /dev/null
+++ b/pointllm/data/__pycache__/object_point_dataset.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe2c95376ec8d60debb53c81d18acc5dfd99719eb0ae55310596d1aa39c35ec0
+size 7034
diff --git a/pointllm/data/__pycache__/scanobjectNN.cpython-310.pyc b/pointllm/data/__pycache__/scanobjectNN.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68928713e3edf5dc0efb1d5b67a07d483347f692
--- /dev/null
+++ b/pointllm/data/__pycache__/scanobjectNN.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0f4130cced07e4741580d40a8fa8107bf4e280be701cee84dfdfcc5e2551a06
+size 4157
diff --git a/pointllm/data/__pycache__/utils.cpython-310.pyc b/pointllm/data/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4cacf4f44d31113fd663524e05b71cc432979e74
--- /dev/null
+++ b/pointllm/data/__pycache__/utils.cpython-310.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1a4cc55f300e4beea0a94778617233405e8997df2e603f6f098c0e933a022e9
+size 7869
diff --git a/pointllm/data/modelnet.py b/pointllm/data/modelnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ebf6b7a08f84a5425e66c51feb25a2b04a8df08
--- /dev/null
+++ b/pointllm/data/modelnet.py
@@ -0,0 +1,149 @@
+import os
+import torch
+import numpy as np
+import pickle
+from torch.utils.data import Dataset
+from pointllm.utils import *
+from pointllm.data.utils import *
+
+class ModelNet(Dataset):
+ def __init__(self, data_path, config_path, split, subset_nums=-1, use_color=False):
+ """
+ Args:
+ data_args:
+ split: train or test
+ """
+ super(ModelNet, self).__init__()
+
+ if config_path is None:
+ # * use the default config file in the same dir
+ config_path = os.path.join(os.path.dirname(__file__), "modelnet_config", "ModelNet40.yaml")
+
+ config = cfg_from_yaml_file(config_path)
+ # * check data path
+ self.root = data_path
+
+ if not os.path.exists(self.root):
+ print(f"Data path {self.root} does not exist. Please check your data path.")
+ exit()
+
+ self.npoints = config.npoints
+ self.num_category = config.NUM_CATEGORY # * should be 40
+ self.random_sample = config.random_sampling
+ self.use_height = config.use_height
+ self.use_normals = config.USE_NORMALS
+ self.subset_nums = subset_nums
+ self.normalize_pc = True
+ self.use_color = use_color
+
+ if self.use_height or self.use_normals:
+ print(f"Warning: Usually we don't use height or normals for shapenet but use_height: {self.use_height} and \
+ use_normals: {self.use_normals}.")
+
+ self.split = split
+ assert (self.split == 'train' or self.split == 'test')
+
+ self.catfile = os.path.join(os.path.dirname(__file__), "modelnet_config", 'modelnet40_shape_names_modified.txt')
+
+ # "tv_stand" -> "tv stand"
+ self.categories = [line.rstrip() for line in open(self.catfile)] # * list of category names
+
+ self.save_path = os.path.join(self.root,
+ 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, self.split, self.npoints))
+
+ print('Load processed data from %s...' % self.save_path)
+ with open(self.save_path, 'rb') as f:
+ self.list_of_points, self.list_of_labels = pickle.load(f) # * ndarray of N, C: (8192, 6) (xyz and normals)
+
+ if self.subset_nums > 0:
+ # * set random seed
+ import random
+ random.seed(0)
+ # * random choose subset_nums
+ idxs = random.sample(range(len(self.list_of_labels)), self.subset_nums)
+ self.list_of_labels = [self.list_of_labels[idx] for idx in idxs]
+ self.list_of_points = [self.list_of_points[idx] for idx in idxs]
+
+ # * print len
+ print(f"Load {len(self.list_of_points)} data from {self.save_path}.")
+
+ def __len__(self):
+ return len(self.list_of_labels)
+
+ def _get_item(self, index):
+ point_set, label = self.list_of_points[index], self.list_of_labels[index]
+
+ if self.npoints < point_set.shape[0]:
+ if self.random_sample:
+ # * random sample
+ point_set = point_set[np.random.choice(point_set.shape[0], self.npoints, replace=False)]
+ else:
+ point_set = farthest_point_sample(point_set, self.npoints)
+
+ point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
+ if not self.use_normals:
+ point_set = point_set[:, 0:3]
+
+ if self.use_height:
+ self.gravity_dim = 1
+ height_array = point_set[:, self.gravity_dim:self.gravity_dim + 1] - point_set[:,
+ self.gravity_dim:self.gravity_dim + 1].min()
+ point_set = np.concatenate((point_set, height_array), axis=1)
+
+ # point_set = np.concatenate((point_set, np.zeros_like(point_set)), axis=-1) if self.use_color else point_set
+ # point_set = np.concatenate((point_set, np.ones_like(point_set)*0.4), axis=-1) if self.use_color else point_set
+ point_set = np.concatenate((point_set, np.zeros_like(point_set)), axis=-1) if self.use_color else point_set
+
+ return point_set, label.item() # * ndarray, int
+
+ def pc_norm(self, pc):
+ """ pc: NxC, return NxC """
+ xyz = pc[:, :3]
+ other_feature = pc[:, 3:]
+
+ centroid = np.mean(xyz, axis=0)
+ xyz = xyz - centroid
+ m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
+ xyz = xyz / m
+
+ pc = np.concatenate((xyz, other_feature), axis=1)
+ return pc
+
+ def __getitem__(self, index):
+ points, label = self._get_item(index)
+ pt_idxs = np.arange(0, points.shape[0]) # 2048
+ if self.split == 'train':
+ np.random.shuffle(pt_idxs)
+ current_points = points[pt_idxs].copy()
+
+ if self.normalize_pc:
+ # * modelnet point cloud is already normalized
+ current_points = self.pc_norm(current_points)
+
+ current_points = torch.from_numpy(current_points).float() # * N, C tensors
+ label_name = self.categories[int(label)]
+
+ data_dict = {
+ "indice": index, # * int
+ "point_clouds": current_points, # * tensor of N, C
+ "labels": label, # * int
+ "label_names": label_name # * str
+ }
+
+ return data_dict
+
+if __name__ == '__main__':
+ import argparse
+
+ parser = argparse.ArgumentParser(description='ModelNet Dataset')
+
+ parser.add_argument("--config_path", type=str, default=None, help="config file path.")
+ parser.add_argument("--split", type=str, default="test", help="train or test.")
+ parser.add_argument("--subset_nums", type=int, default=200)
+
+ args = parser.parse_args()
+
+ dataset = ModelNet(config_path=args.config_path, split=args.split, subset_nums=args.subset_nums)
+
+ # * get the first item
+ print(dataset[0])
\ No newline at end of file
diff --git a/pointllm/data/modelnet_config/ModelNet40.yaml b/pointllm/data/modelnet_config/ModelNet40.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..98cc59f5c4e0777f6caacaafa8ef7b3c18edf367
--- /dev/null
+++ b/pointllm/data/modelnet_config/ModelNet40.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a99380013d910db59646d0b48a969dbdbbb8c13f43f6b7ffff20023bb03a14c
+size 155
diff --git a/pointllm/data/modelnet_config/modelnet40_shape_names_modified.txt b/pointllm/data/modelnet_config/modelnet40_shape_names_modified.txt
new file mode 100644
index 0000000000000000000000000000000000000000..877d86ecf3477b06ea35f04ed56678607d81e0a3
--- /dev/null
+++ b/pointllm/data/modelnet_config/modelnet40_shape_names_modified.txt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25089598c178794aa4fae0a5f18b9927b39b3ba008bb44c326e2c3e50777a228
+size 275
diff --git a/pointllm/data/modelnet_show.py b/pointllm/data/modelnet_show.py
new file mode 100644
index 0000000000000000000000000000000000000000..524c931c30ec22325c5416dd61c92cb675c48e82
--- /dev/null
+++ b/pointllm/data/modelnet_show.py
@@ -0,0 +1,206 @@
+import os
+import torch
+import numpy as np
+import pickle
+from torch.utils.data import Dataset
+from pointllm.utils import *
+from pointllm.data.utils import *
+
+class ModelNet(Dataset):
+ def __init__(self, data_path, config_path, split, subset_nums=-1, use_color=False):
+ """
+ Args:
+ data_args:
+ split: train or test
+ """
+ super(ModelNet, self).__init__()
+
+ if config_path is None:
+ # * use the default config file in the same dir
+ config_path = os.path.join(os.path.dirname(__file__), "modelnet_config", "ModelNet40.yaml")
+
+ config = cfg_from_yaml_file(config_path)
+ # * check data path
+ self.root = data_path
+
+ if not os.path.exists(self.root):
+ print(f"Data path {self.root} does not exist. Please check your data path.")
+ exit()
+
+ self.npoints = config.npoints
+ self.num_category = config.NUM_CATEGORY # * should be 40
+ self.random_sample = config.random_sampling
+ self.use_height = config.use_height
+ self.use_normals = config.USE_NORMALS
+ self.subset_nums = subset_nums
+ self.normalize_pc = True
+ self.use_color = use_color
+
+ if self.use_height or self.use_normals:
+ print(f"Warning: Usually we don't use height or normals for shapenet but use_height: {self.use_height} and \
+ use_normals: {self.use_normals}.")
+
+ self.split = split
+ assert (self.split == 'train' or self.split == 'test')
+
+ self.catfile = os.path.join(os.path.dirname(__file__), "modelnet_config", 'modelnet40_shape_names_modified.txt')
+
+ # "tv_stand" -> "tv stand"
+ self.categories = [line.rstrip() for line in open(self.catfile)] # * list of category names
+
+ self.save_path = os.path.join(self.root,
+ 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, self.split, self.npoints))
+
+ print('Load processed data from %s...' % self.save_path)
+ with open(self.save_path, 'rb') as f:
+ self.list_of_points, self.list_of_labels = pickle.load(f) # * ndarray of N, C: (8192, 6) (xyz and normals)
+
+ if self.subset_nums > 0:
+ # * set random seed
+ import random
+ random.seed(0)
+ # * random choose subset_nums
+ idxs = random.sample(range(len(self.list_of_labels)), self.subset_nums)
+ self.list_of_labels = [self.list_of_labels[idx] for idx in idxs]
+ self.list_of_points = [self.list_of_points[idx] for idx in idxs]
+
+ # * print len
+ print(f"Load {len(self.list_of_points)} data from {self.save_path}.")
+
+ def __len__(self):
+ return len(self.list_of_labels)
+
+ def _get_item(self, index):
+ point_set, label = self.list_of_points[index], self.list_of_labels[index]
+
+ if self.npoints < point_set.shape[0]:
+ if self.random_sample:
+ # * random sample
+ point_set = point_set[np.random.choice(point_set.shape[0], self.npoints, replace=False)]
+ else:
+ point_set = farthest_point_sample(point_set, self.npoints)
+
+ point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
+ if not self.use_normals:
+ point_set = point_set[:, 0:3]
+
+ if self.use_height:
+ self.gravity_dim = 1
+ height_array = point_set[:, self.gravity_dim:self.gravity_dim + 1] - point_set[:,
+ self.gravity_dim:self.gravity_dim + 1].min()
+ point_set = np.concatenate((point_set, height_array), axis=1)
+
+ point_set = np.concatenate((point_set, np.zeros_like(point_set)), axis=-1) if self.use_color else point_set
+
+ return point_set, label.item() # * ndarray, int
+
+ def our_get_item(self, index):
+ point_set, label = self.list_of_points[index], self.list_of_labels[index]
+
+ if self.npoints < point_set.shape[0]:
+ if self.random_sample:
+ # * random sample
+ point_set = point_set[np.random.choice(point_set.shape[0], self.npoints, replace=False)]
+ else:
+ point_set = farthest_point_sample(point_set, self.npoints)
+
+ point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
+ if not self.use_normals:
+ point_set = point_set[:, 0:3]
+
+ if self.use_height:
+ self.gravity_dim = 1
+ height_array = point_set[:, self.gravity_dim:self.gravity_dim + 1] - point_set[:,
+ self.gravity_dim:self.gravity_dim + 1].min()
+ point_set = np.concatenate((point_set, height_array), axis=1)
+
+ point_set = np.concatenate((point_set, np.zeros_like(point_set)), axis=-1) if self.use_color else point_set
+
+ label_name = self.categories[int(label)]
+
+ data_dict = {
+ "indice": index, # * int
+ "point_clouds": point_set, # * ndarray
+ "labels": label.item(), # * int
+ "label_names": label_name # * str
+ }
+
+ return data_dict
+
+
+ def pc_norm(self, pc):
+ """ pc: NxC, return NxC """
+ xyz = pc[:, :3]
+ other_feature = pc[:, 3:]
+
+ centroid = np.mean(xyz, axis=0)
+ xyz = xyz - centroid
+ m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
+ xyz = xyz / m
+
+ pc = np.concatenate((xyz, other_feature), axis=1)
+ return pc
+
+ def __getitem__(self, index):
+ points, label = self._get_item(index)
+ pt_idxs = np.arange(0, points.shape[0]) # 2048
+ if self.split == 'train':
+ np.random.shuffle(pt_idxs)
+ current_points = points[pt_idxs].copy()
+
+ if self.normalize_pc:
+ # * modelnet point cloud is already normalized
+ current_points = self.pc_norm(current_points)
+
+ current_points = torch.from_numpy(current_points).float() # * N, C tensors
+ label_name = self.categories[int(label)]
+
+ data_dict = {
+ "indice": index, # * int
+ "point_clouds": current_points, # * tensor of N, C
+ "labels": label, # * int
+ "label_names": label_name # * str
+ }
+
+ return data_dict
+
+
+
+if __name__ == '__main__':
+ import argparse
+
+ parser = argparse.ArgumentParser(description='ModelNet Dataset')
+
+ parser.add_argument("--config_path", type=str, default=None, help="config file path.")
+ parser.add_argument("--split", type=str, default="test", help="train or test.")
+ parser.add_argument("--subset_nums", type=int, default=-1)
+
+ args = parser.parse_args()
+
+ dataset = ModelNet(data_path='/home/PointLLM/data/modelnet40_data', config_path=args.config_path, split="test", subset_nums=-1)
+
+
+
+
+ # new
+ show_lsit = [1202, 1884,1874, 158, 1104, 462, 137, 1322]
+
+ import numpy as np
+ import plyfile
+ import open3d as o3d
+
+ for i in show_lsit:
+ data_dict = dataset.our_get_item(i)
+
+ # print(data_dict["indice"])
+ # # print(data_dict["point_clouds"])
+ # print(data_dict["labels"])
+ # print(data_dict["label_names"])
+
+ print(data_dict["point_clouds"].shape)
+
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(data_dict["point_clouds"])
+ o3d.io.write_point_cloud("/home/TinyGPT-3D/TinyGPT-3D/show_mn40_pc/"+str(data_dict["indice"])+data_dict["label_names"]+".ply", pcd)
+ pc = data_dict["point_clouds"]
+ np.save("/home/TinyGPT-3D/TinyGPT-3D/show_mn40_pc/"+str(data_dict["indice"])+data_dict["label_names"]+'.npy', pc)
\ No newline at end of file
diff --git a/pointllm/data/object_point_dataset.py b/pointllm/data/object_point_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..521147f8ed419328648b51ed3d29baa660886dde
--- /dev/null
+++ b/pointllm/data/object_point_dataset.py
@@ -0,0 +1,251 @@
+import os
+import json
+import torch
+import numpy as np
+
+import copy
+import transformers
+from torch.utils.data import Dataset
+
+from .utils import *
+
+
+def make_object_point_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
+ """Make dataset and collator for Joint3Ddataset with text and point cloud data."""
+ """Initialize datasets."""
+
+ data_collator = DataCollatorForPointTextDataset(tokenizer=tokenizer)
+ if data_args.split_train_val:
+ print("Loading training datasets.")
+ train_dataset = ObjectPointCloudDataset(
+ split='train',
+ data_path=data_args.data_path,
+ anno_path=data_args.anno_path,
+ pointnum=data_args.pointnum,
+ conversation_types=data_args.conversation_types,
+ tokenizer=tokenizer,
+ use_color=data_args.use_color,
+ data_args=data_args
+ )
+ print("Done!")
+ if data_args.data_debug_num > 0:
+ print('Debug mode, using training set as val set.')
+ val_dataset = train_dataset
+ else:
+ # * make a val dataset
+ print("Loading validation datasets.")
+ val_dataset = ObjectPointCloudDataset(
+ split='val', # * load train split
+ data_path=data_args.data_path,
+ anno_path=data_args.anno_path,
+ pointnum=data_args.pointnum,
+ conversation_types=data_args.conversation_types,
+ tokenizer=tokenizer,
+ use_color=data_args.use_color,
+ data_args=data_args
+ )
+ return dict(train_dataset=train_dataset, eval_dataset=val_dataset, data_collator=data_collator)
+ else:
+ # * use all data as training data
+ train_dataset = ObjectPointCloudDataset(
+ split='train',
+ data_path=data_args.data_path,
+ anno_path=data_args.anno_path,
+ pointnum=data_args.pointnum,
+ conversation_types=data_args.conversation_types,
+ use_color=data_args.use_color,
+ tokenizer=tokenizer,
+ data_args=data_args
+ )
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
+
+class ObjectPointCloudDataset(Dataset):
+ """Dataset utilities for objaverse."""
+ def __init__(self,
+ data_path=None,
+ anno_path=None,
+ tokenizer=None,
+ pointnum=8192,
+ split='train',
+ conversation_types=None, # * default is simple_des, used for stage1 pre-train
+ use_color=True,
+ data_args=None):
+
+ """
+ split: only considered when data_args.split_train_val is True.
+ conversation_types: tuple, used to filter the data, default is ('simple_description'), other types is:
+ "detailed_description", "single_round", "multi_round".
+ tokenizer: load point clouds only if None
+ """
+ super(ObjectPointCloudDataset, self).__init__()
+
+ """Initialize dataset with object point clouds and text"""
+ self.data_path = data_path
+ self.anno_path = anno_path
+ self.tokenizer = tokenizer
+ self.split = split
+ if conversation_types is None:
+ self.conversation_types = ("simple_description",)
+ else:
+ self.conversation_types = conversation_types
+
+ self.data_args = data_args
+ self.normalize_pc = True
+ self.use_color = use_color
+
+ self.pointnum = pointnum
+ self.point_backbone_config = data_args.point_backbone_config if data_args is not None else None
+ self.point_indicator = ''
+
+ # Load the data list from JSON
+ print(f"Loading anno file from {anno_path}.")
+ with open(anno_path, "r") as json_file:
+ self.list_data_dict = json.load(json_file)
+
+ # * print the conversations_type
+ print(f"Using conversation_type: {self.conversation_types}")
+ # * print before filtering
+ print(f"Before filtering, the dataset size is: {len(self.list_data_dict)}.")
+
+ # * iterate the list and filter
+ # * these two ids have corrupted colored point files, so filter them when use_color is True
+ filter_ids = ['6760e543e1d645d5aaacd3803bcae524', 'b91c0711149d460a8004f9c06d3b7f38'] if self.use_color else []
+
+ # Iterate the list, filter those "conversation_type" not in self.conversation_types
+ self.list_data_dict = [
+ data for data in self.list_data_dict
+ if data.get('conversation_type', 'simple_description') in self.conversation_types
+ and data.get('object_id') not in filter_ids
+ ]
+
+ # * print after filtering
+ print(f"After filtering, the dataset size is: {len(self.list_data_dict)}.")
+ # * print the size of different conversation_type
+ for conversation_type in self.conversation_types:
+ print(f"Number of {conversation_type}: {len([data for data in self.list_data_dict if data.get('conversation_type', 'simple_description') == conversation_type])}")
+
+ if self.data_args is not None and self.data_args.data_debug_num > 0:
+ self.list_data_dict = self.list_data_dict[:self.data_args.data_debug_num]
+ # * print all the scan_id in debug mode, not using for loop
+ print('Debug mode, using: ' + ' '.join([data['object_id'] for data in self.list_data_dict]))
+ elif self.data_args is not None and self.data_args.split_train_val:
+ # * split train and val with 9:1 ratios
+ if self.split == 'train':
+ self.list_data_dict = self.list_data_dict[:int(self.data_args.split_ratio * len(self.list_data_dict))]
+ print(f"Train set size: {len(self.list_data_dict)}")
+ else:
+ self.list_data_dict = self.list_data_dict[int(self.data_args.split_ratio * len(self.list_data_dict)):]
+ print(f"Val set size: {len(self.list_data_dict)}")
+
+ print("dataloader")
+ def _load_point_cloud(self, object_id, type='objaverse'):
+ if type == 'objaverse':
+ return self._load_objaverse_point_cloud(object_id)
+
+ def _load_objaverse_point_cloud(self, object_id):
+ filename = f"{object_id}_{self.pointnum}.npy"
+ point_cloud = np.load(os.path.join(self.data_path, filename))
+
+ if not self.use_color:
+ point_cloud = point_cloud[:, :3]
+
+ return point_cloud
+
+ def pc_norm(self, pc):
+ """ pc: NxC, return NxC """
+ xyz = pc[:, :3]
+ other_feature = pc[:, 3:]
+
+ centroid = np.mean(xyz, axis=0)
+ xyz = xyz - centroid
+ m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
+ xyz = xyz / m
+
+ pc = np.concatenate((xyz, other_feature), axis=1)
+ return pc
+
+ def __getitem__(self, index):
+ sources = self.list_data_dict[index]
+ if isinstance(index, int):
+ sources = [sources]
+ assert len(sources) == 1, "sources should be a list"
+ if self.point_indicator in sources[0]['conversations'][0]['value']:
+
+ object_id = self.list_data_dict[index]['object_id']
+
+ # Point cloud representation
+ point_cloud = self._load_point_cloud(object_id) # * N, C
+ if self.normalize_pc:
+ point_cloud = self.pc_norm(point_cloud) # * need to norm since point encoder is norm
+
+ if self.tokenizer is None:
+ data_dict = dict(
+ point_clouds=torch.from_numpy(point_cloud.astype(np.float32)),
+ object_ids=object_id
+ )
+ return data_dict
+
+ sources = preprocess_multimodal_point_cloud(
+ copy.deepcopy([e["conversations"] for e in sources]), self.point_backbone_config, point_indicator=self.point_indicator)
+ else:
+ sources = copy.deepcopy([e["conversations"] for e in sources])
+
+ data_dict = preprocess_v1(
+ sources,
+ self.tokenizer)
+
+ if isinstance(index, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
+ labels=data_dict["labels"][0])
+
+ # point exist in the data
+ if self.point_indicator in self.list_data_dict[index]['conversations'][0]['value']:
+ data_dict['point_clouds'] = torch.from_numpy(point_cloud.astype(np.float32))
+
+ return data_dict
+
+ def __len__(self):
+ """Return number of utterances."""
+ return len(self.list_data_dict)
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--data_path", default="data/objaverse_data", type=str,
+ help="Path to the data directory.")
+ parser.add_argument("--anno_path", default=None, type=str, required=True,
+ help="Path to the annotation file.")
+ parser.add_argument("--split", default='train', type=str,
+ help="Whether to use the train or validation dataset.")
+ parser.add_argument("--pointnum", default=8192, type=int,
+ help="Number of points in the point cloud.")
+ parser.add_argument("--data_debug_num", default=0, type=int,
+ help="Number of data to debug with.")
+ parser.add_argument("--split_train_val", default=False, type=bool,
+ help="Whether to split the dataset into training and validation.")
+ parser.add_argument("--split_ratio", default=0.9, type=float,
+ help="The ratio of training to validation data.")
+ parser.add_argument("--tokenizer_path", default=None, type=str, required=True,
+ help="Path to the tokenizer config file.")
+
+ args = parser.parse_args()
+
+ # Initialize tokenizer
+ tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer_path)
+
+ args.point_backbone_config = None
+
+ # Initialize dataset
+ dataset = ObjectPointCloudDataset(
+ data_path=args.data_path,
+ anno_path=args.anno_path,
+ pointnum=args.pointnum,
+ split=args.split,
+ tokenizer=tokenizer,
+ data_args=args
+ )
+
+ # Example usage
+ print(f'Dataset length: {len(dataset)}')
+
diff --git a/pointllm/data/utils.py b/pointllm/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..272d0af6890af9953253a5898e216468e0aae2c6
--- /dev/null
+++ b/pointllm/data/utils.py
@@ -0,0 +1,311 @@
+from collections import OrderedDict, defaultdict
+
+import transformers
+from pointllm import conversation as conversation_lib
+from dataclasses import dataclass
+from typing import Optional, Dict, Sequence
+import torch
+
+import numpy as np
+import os
+
+IGNORE_INDEX = -100
+
+# * Sample Usage:
+# * from utils import LRUCache
+# * cache = LRUCache(capacity, max_access_count)
+# if self.cache is None:
+# info_data = self.multiview_scannet[info_index]
+# else:
+# info_data = self.cache.get(info_index)
+# if info_data is None or self.cache.get_access_count(info_index) >= self.cache.max_access_count:
+# # If not in cache, or accessed max_access_count times, load it and put it in cache
+# info_data = self.multiview_scannet[info_index]
+# self.cache.put(info_index, info_data)
+# self.cache.reset_access_count(info_index)
+
+class LRUCache:
+ def __init__(self, capacity, max_access_count):
+ self.cache = OrderedDict()
+ self.access_count = defaultdict(int)
+ self.capacity = capacity
+ self.max_access_count = max_access_count
+
+ def get(self, key):
+ if key not in self.cache:
+ return None
+ value = self.cache.pop(key)
+ self.cache[key] = value # Put key as the newest one
+ self.access_count[key] += 1
+ return value
+
+ def put(self, key, value):
+ if key in self.cache: # Update the value and put it as newest
+ self.cache.pop(key)
+ elif len(self.cache) == self.capacity: # If cache is full
+ oldest_key = next(iter(self.cache))
+ self.cache.popitem(last=False) # Remove oldest item
+ del self.access_count[oldest_key] # Remove the corresponding access count
+ self.cache[key] = value
+ self.access_count[key] = 1
+
+ def get_access_count(self, key):
+ return self.access_count.get(key, 0)
+
+ def reset_access_count(self, key):
+ self.access_count[key] = 0
+
+
+def preprocess_v1(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ ############################################################
+ test_str_1 = ['USER:']
+ input_ids_test_str_1 = tokenizer(
+ test_str_1,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ test_str_2 = ['']
+ input_ids_test_str_2 = tokenizer(
+ test_str_2,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ test_str_3 = ['USER:']
+ input_ids_test_str_3 = tokenizer(
+ test_str_3,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ test_str_4_no_space = ['A 3D model of a low poly yellow tree with a brown base.']
+ input_ids_test_str_relay_nospace = tokenizer(
+ test_str_4_no_space,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ test_str_4_with_space = [' A 3D model of a low poly yellow tree with a brown base.']
+ input_ids_test_str_relay_withspace = tokenizer(
+ test_str_4_with_space,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ test_str_5_with_space = [' ASSISTANT: A 3D model of a low poly yellow tree with a brown base.']
+ input_ids_test_str_5_with_space = tokenizer(
+ test_str_5_with_space,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+
+ test_str_5_no_space = ['ASSISTANT: A 3D model of a low poly yellow tree with a brown base.']
+ input_ids_test_str_5_no_space = tokenizer(
+ test_str_5_no_space,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ ###############################################################
+
+ # Tokenize conversations
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1] + ": "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ # cur_len = 1
+ cur_len = 0
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2: # * can handle padded tokens
+ break
+ parts[0] += sep
+
+ #########
+ tt = tokenizer(rou).input_ids
+ ######
+ round_len = len(tokenizer(rou).input_ids)
+
+ #####
+ tt_2 = tokenizer(parts[0]).input_ids
+ ####
+ instruction_len = len(tokenizer(parts[0]).input_ids) -1
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len = cur_len + round_len + 3
+ target[cur_len:] = IGNORE_INDEX # * this is necessary for padded tokens
+
+ # if cur_len < tokenizer.model_max_length:
+ # if cur_len != total_len: # * unk tokens in the dialogue will cause this.
+ # target[:] = IGNORE_INDEX
+ # print(
+ # f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ # f" (ignored)"
+ # )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+def preprocess_multimodal_point_cloud(
+ sources: Sequence[str],
+ point_backbone_config: dict,
+ point_indicator: str = "",
+) -> Dict:
+ point_token_len = point_backbone_config['point_token_len']
+ default_point_patch_token = point_backbone_config['default_point_patch_token']
+
+ for source in sources:
+ for sentence in source:
+ replace_token = default_point_patch_token * point_token_len
+ if point_backbone_config['mm_use_point_start_end']:
+ replace_token = point_backbone_config['default_point_start_token']+ replace_token + point_backbone_config['default_point_end_token']
+ sentence["value"] = sentence["value"].replace(point_indicator, replace_token)
+
+ return sources
+
+def pc_norm(pc):
+ """ pc: NxC, return NxC """
+ xyz = pc[:, :3]
+ other_feature = pc[:, 3:]
+
+ centroid = np.mean(xyz, axis=0)
+ xyz = xyz - centroid
+ m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
+ xyz = xyz / m
+
+ pc = np.concatenate((xyz, other_feature), axis=1)
+ return pc
+
+def load_objaverse_point_cloud(data_path, object_id, pointnum=8192, use_color=False):
+ filename = f"{object_id}_{pointnum}.npy"
+ point_cloud = np.load(os.path.join(data_path, filename))
+
+ # * normalize
+ point_cloud = pc_norm(point_cloud)
+
+ if not use_color:
+ point_cloud = point_cloud[:, :3]
+
+ return point_cloud
+
+@dataclass
+class DataCollatorForPointTextDataset(object):
+ """Collate examples for mixed dataset with text and point cloud data."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple([instance[key] for instance in instances]
+ for key in ("input_ids", "labels"))
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id)
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
+ batch_first=True,
+ padding_value=IGNORE_INDEX)
+ batch = dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
+
+ if 'point_clouds' in instances[0]:
+ point_clouds = [instance['point_clouds'] for instance in instances]
+ if all(x is not None and x.shape == point_clouds[0].shape for x in point_clouds): # * point_clouds have different shapes
+ batch['point_clouds'] = torch.stack(point_clouds)
+ else:
+ batch['point_clouds'] = point_clouds # * return as lists
+
+ return batch
+
+def farthest_point_sample(point, npoint):
+ """
+ Input:
+ xyz: pointcloud data, [N, D]
+ npoint: number of samples
+ Return:
+ centroids: sampled pointcloud index, [npoint, D]
+ """
+ N, D = point.shape
+ xyz = point[:,:3]
+ centroids = np.zeros((npoint,))
+ distance = np.ones((N,)) * 1e10
+ farthest = np.random.randint(0, N)
+ for i in range(npoint):
+ centroids[i] = farthest
+ centroid = xyz[farthest, :]
+ dist = np.sum((xyz - centroid) ** 2, -1)
+ mask = dist < distance
+ distance[mask] = dist[mask]
+ farthest = np.argmax(distance, -1)
+ point = point[centroids.astype(np.int32)]
+ return point
+
+def pc_normalize(pc):
+ """
+ pc: Nx3 array
+ This functions normalizes a point cloud to fit within a unit sphere.
+ It first calculates the centroid of the point cloud and then subtracts
+ it from all points before scaling all points to fit within a unit sphere.
+ """
+ centroid = np.mean(pc, axis=0)
+ pc = pc - centroid
+ m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
+ pc = pc / m
+ return pc
\ No newline at end of file
diff --git a/pointllm/data/utils_backup.py b/pointllm/data/utils_backup.py
new file mode 100644
index 0000000000000000000000000000000000000000..c41aaca765e4e670207ee798807ec64c65730a48
--- /dev/null
+++ b/pointllm/data/utils_backup.py
@@ -0,0 +1,236 @@
+from collections import OrderedDict, defaultdict
+
+import transformers
+from pointllm import conversation as conversation_lib
+from dataclasses import dataclass
+from typing import Optional, Dict, Sequence
+import torch
+
+import numpy as np
+import os
+
+IGNORE_INDEX = -100
+
+# * Sample Usage:
+# * from utils import LRUCache
+# * cache = LRUCache(capacity, max_access_count)
+# if self.cache is None:
+# info_data = self.multiview_scannet[info_index]
+# else:
+# info_data = self.cache.get(info_index)
+# if info_data is None or self.cache.get_access_count(info_index) >= self.cache.max_access_count:
+# # If not in cache, or accessed max_access_count times, load it and put it in cache
+# info_data = self.multiview_scannet[info_index]
+# self.cache.put(info_index, info_data)
+# self.cache.reset_access_count(info_index)
+
+class LRUCache:
+ def __init__(self, capacity, max_access_count):
+ self.cache = OrderedDict()
+ self.access_count = defaultdict(int)
+ self.capacity = capacity
+ self.max_access_count = max_access_count
+
+ def get(self, key):
+ if key not in self.cache:
+ return None
+ value = self.cache.pop(key)
+ self.cache[key] = value # Put key as the newest one
+ self.access_count[key] += 1
+ return value
+
+ def put(self, key, value):
+ if key in self.cache: # Update the value and put it as newest
+ self.cache.pop(key)
+ elif len(self.cache) == self.capacity: # If cache is full
+ oldest_key = next(iter(self.cache))
+ self.cache.popitem(last=False) # Remove oldest item
+ del self.access_count[oldest_key] # Remove the corresponding access count
+ self.cache[key] = value
+ self.access_count[key] = 1
+
+ def get_access_count(self, key):
+ return self.access_count.get(key, 0)
+
+ def reset_access_count(self, key):
+ self.access_count[key] = 0
+
+
+def preprocess_v1(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1] + ": "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2: # * can handle padded tokens
+ break
+ parts[0] += sep
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX # * this is necessary for padded tokens
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len: # * unk tokens in the dialogue will cause this.
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+def preprocess_multimodal_point_cloud(
+ sources: Sequence[str],
+ point_backbone_config: dict,
+ point_indicator: str = "",
+) -> Dict:
+ point_token_len = point_backbone_config['point_token_len']
+ default_point_patch_token = point_backbone_config['default_point_patch_token']
+
+ for source in sources:
+ for sentence in source:
+ replace_token = default_point_patch_token * point_token_len
+ if point_backbone_config['mm_use_point_start_end']:
+ replace_token = point_backbone_config['default_point_start_token']+ replace_token + point_backbone_config['default_point_end_token']
+ sentence["value"] = sentence["value"].replace(point_indicator, replace_token)
+
+ return sources
+
+def pc_norm(pc):
+ """ pc: NxC, return NxC """
+ xyz = pc[:, :3]
+ other_feature = pc[:, 3:]
+
+ centroid = np.mean(xyz, axis=0)
+ xyz = xyz - centroid
+ m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
+ xyz = xyz / m
+
+ pc = np.concatenate((xyz, other_feature), axis=1)
+ return pc
+
+def load_objaverse_point_cloud(data_path, object_id, pointnum=8192, use_color=False):
+ filename = f"{object_id}_{pointnum}.npy"
+ point_cloud = np.load(os.path.join(data_path, filename))
+
+ # * normalize
+ point_cloud = pc_norm(point_cloud)
+
+ if not use_color:
+ point_cloud = point_cloud[:, :3]
+
+ return point_cloud
+
+@dataclass
+class DataCollatorForPointTextDataset(object):
+ """Collate examples for mixed dataset with text and point cloud data."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple([instance[key] for instance in instances]
+ for key in ("input_ids", "labels"))
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id)
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
+ batch_first=True,
+ padding_value=IGNORE_INDEX)
+ batch = dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
+
+ if 'point_clouds' in instances[0]:
+ point_clouds = [instance['point_clouds'] for instance in instances]
+ if all(x is not None and x.shape == point_clouds[0].shape for x in point_clouds): # * point_clouds have different shapes
+ batch['point_clouds'] = torch.stack(point_clouds)
+ else:
+ batch['point_clouds'] = point_clouds # * return as lists
+
+ return batch
+
+def farthest_point_sample(point, npoint):
+ """
+ Input:
+ xyz: pointcloud data, [N, D]
+ npoint: number of samples
+ Return:
+ centroids: sampled pointcloud index, [npoint, D]
+ """
+ N, D = point.shape
+ xyz = point[:,:3]
+ centroids = np.zeros((npoint,))
+ distance = np.ones((N,)) * 1e10
+ farthest = np.random.randint(0, N)
+ for i in range(npoint):
+ centroids[i] = farthest
+ centroid = xyz[farthest, :]
+ dist = np.sum((xyz - centroid) ** 2, -1)
+ mask = dist < distance
+ distance[mask] = dist[mask]
+ farthest = np.argmax(distance, -1)
+ point = point[centroids.astype(np.int32)]
+ return point
+
+def pc_normalize(pc):
+ """
+ pc: Nx3 array
+ This functions normalizes a point cloud to fit within a unit sphere.
+ It first calculates the centroid of the point cloud and then subtracts
+ it from all points before scaling all points to fit within a unit sphere.
+ """
+ centroid = np.mean(pc, axis=0)
+ pc = pc - centroid
+ m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
+ pc = pc / m
+ return pc
\ No newline at end of file
diff --git a/pointllm/eval/eval_modelnet_cls.py b/pointllm/eval/eval_modelnet_cls.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dce0b19e9686a151defe65064e2ff7e1e0dfe9a
--- /dev/null
+++ b/pointllm/eval/eval_modelnet_cls.py
@@ -0,0 +1,235 @@
+import os
+import json
+import argparse
+from torch.utils.data import DataLoader
+from pointllm.data import ModelNet
+from tqdm import tqdm
+import torch
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from llava.conversation import conv_templates
+from llava.model.builder import load_pretrained_model
+from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
+
+
+class MyClass:
+
+ def __init__(self, arg):
+
+ self.vision_tower = None
+ self.pretrain_mm_mlp_adapter = arg.pretrain_mm_mlp_adapter
+ self.encoder_type = 'pc_encoder'
+ self.std=arg.std
+ self.pc_encoder_type = arg.pc_encoder_type
+ self.pc_feat_dim = 192
+ self.embed_dim = 1024
+ self.group_size = 64
+ self.num_group =512
+ self.pc_encoder_dim =512
+ self.patch_dropout = 0.0
+ self.pc_ckpt_path = arg.pc_ckpt_path
+ self.lora_path = arg.lora_path
+ self.model_path=arg.model_path
+ self.get_pc_tokens_way=arg.get_pc_tokens_way
+
+
+def init_model(model_arg_):
+ model_path = "llava-vicuna_phi_3_finetune_weight"
+ model_name = get_model_name_from_path(model_path)
+ model_path = model_arg_.model_path
+ tokenizer, model, context_len = load_pretrained_model(model_path, None, model_name)
+
+ if model_arg_.lora_path:
+ from peft import PeftModel
+
+ model = PeftModel.from_pretrained(model, model_arg_.lora_path)
+ print("load lora weight ok")
+
+ model.get_model().initialize_other_modules(model_arg_)
+ print("load encoder, mlp ok")
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ # ๅฐๆจกๅๅ ่ฝฝๅฐCUDA่ฎพๅค
+ model.to(dtype=torch.bfloat16)
+ model.get_model().vision_tower.to(dtype=torch.float)
+ model.to(device)
+
+ return tokenizer, model
+
+
+
+PROMPT_LISTS = [
+ "What is this?",
+ "This is an object of "
+]
+
+
+def load_dataset(data_path, config_path, split, subset_nums, use_color):
+ print(f"Loading {split} split of ModelNet datasets.")
+ dataset = ModelNet(data_path=data_path, config_path=config_path, split=split, subset_nums=subset_nums, use_color=use_color)
+ print("Done!")
+ return dataset
+
+def get_dataloader(dataset, batch_size, shuffle=False, num_workers=4):
+ assert shuffle is False, "Since we using the index of ModelNet as Object ID when evaluation \
+ so shuffle shoudl be False and should always set random seed."
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+ return dataloader
+
+
+def start_generation(model, tokenizer, dataloader, prompt_index, output_dir, output_file, args):
+ # stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+ qs = PROMPT_LISTS[prompt_index]
+
+ results = {"prompt": qs}
+
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
+ conv_mode = "phi3_instruct"
+ conv = conv_templates[conv_mode].copy()
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ qs = conv.get_prompt()
+
+ input_ids = (
+ tokenizer_image_token(qs, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+ .unsqueeze(0)
+ .cuda()
+ )
+
+ responses = []
+
+ for batch in tqdm(dataloader):
+ point_clouds = batch["point_clouds"].cuda() # * tensor of B, N, C(3)
+ labels = batch["labels"]
+ label_names = batch["label_names"]
+ indice = batch["indice"]
+
+ texts = input_ids.repeat(point_clouds.size()[0], 1)
+
+ images_tensor = point_clouds.to(dtype=torch.bfloat16)
+
+ temperature = args.temperature
+ top_p = args.top_p
+
+ max_new_tokens = args.max_new_tokens
+ min_new_tokens = args.min_new_tokens
+ num_beams = args.num_beams
+
+ with torch.inference_mode():
+ output_ids = model.generate(
+ texts,
+ images=images_tensor,
+ do_sample=True if temperature > 0 and num_beams == 1 else False,
+ temperature=temperature,
+ top_p=top_p,
+ num_beams=num_beams,
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ use_cache=True,
+ )
+
+ answers = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+ outputs = []
+ for answer in answers:
+ answer = answer.strip()
+ answer = answer.replace("<|end|>", "").strip()
+ outputs.append(answer)
+
+ # saving results
+ for index, output, label, label_name in zip(indice, outputs, labels, label_names):
+ responses.append({
+ "object_id": index.item(),
+ "ground_truth": label.item(),
+ "model_output": output,
+ "label_name": label_name
+ })
+
+ results["results"] = responses
+
+ os.makedirs(output_dir, exist_ok=True)
+ # save the results to a JSON file
+ with open(os.path.join(output_dir, output_file), 'w') as fp:
+ json.dump(results, fp, indent=2)
+
+ # * print info
+ print(f"Saved results to {os.path.join(output_dir, output_file)}")
+
+ return results
+
+def main(args):
+ # * ouptut
+ args.output_dir = os.path.join(args.out_path, "evaluation")
+
+ # * output file
+ args.output_file = f"ModelNet_classification_prompt{args.prompt_index}.json"
+ args.output_file_path = os.path.join(args.output_dir, args.output_file)
+
+ # * First inferencing, then evaluate
+ if not os.path.exists(args.output_file_path):
+ # * need to generate results first
+ dataset = load_dataset(data_path=args.data_path, config_path=None, split=args.split, subset_nums=args.subset_nums, use_color=args.use_color) # * defalut config
+ dataloader = get_dataloader(dataset, args.batch_size, args.shuffle, args.num_workers)
+
+ model_arg = MyClass(args)
+ tokenizer, model = init_model(model_arg)
+
+
+ model.eval()
+
+ # * ouptut
+ print(f'[INFO] Start generating results for {args.output_file}.')
+ results = start_generation(model, tokenizer, dataloader, args.prompt_index, args.output_dir, args.output_file, args)
+
+ # * release model and tokenizer, and release cuda memory
+ del model
+
+ torch.cuda.empty_cache()
+ else:
+ # * directly load the results
+ print(f'[INFO] {args.output_file_path} already exists, directly loading...')
+ with open(args.output_file_path, 'r') as fp:
+ results = json.load(fp)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--out_path", type=str, default="./output_json")
+ parser.add_argument("--pretrain_mm_mlp_adapter", type=str, required=True)
+
+
+
+ parser.add_argument("--lora_path", type=str, default=None)
+ parser.add_argument("--model_path", type=str, default='./lava-vicuna_2024_4_Phi-3-mini-4k-instruct')
+
+ parser.add_argument("--std", type=float, default=0.0)
+ parser.add_argument("--pc_ckpt_path", type=str, required=True, default="./pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-small/model.pt")
+ parser.add_argument("--pc_encoder_type", type=str, required=True, default='small')
+ parser.add_argument("--get_pc_tokens_way", type=str, required=True)
+
+ # * dataset type
+ parser.add_argument("--data_path", type=str, default="./dataset/modelnet40_data", help="train or test.")
+ parser.add_argument("--split", type=str, default="test", help="train or test.")
+ parser.add_argument("--use_color", action="store_true", default=True)
+
+ # * data loader, batch_size, shuffle, num_workers
+ parser.add_argument("--batch_size", type=int, default=10)
+ parser.add_argument("--shuffle", type=bool, default=False)
+ parser.add_argument("--num_workers", type=int, default=20)
+ parser.add_argument("--subset_nums", type=int, default=-1) # * only use "subset_nums" of samples, mainly for debug
+
+ # * evaluation setting
+ parser.add_argument("--prompt_index", type=int, required=True, help="0 or 1")
+
+ ############## new add
+ parser.add_argument("--max_new_tokens", type=int, default=110, help="max number of generated tokens")
+ parser.add_argument("--min_new_tokens", type=int, default=0, help="min number of generated tokens")
+ parser.add_argument("--num_beams", type=int, default=1)
+ parser.add_argument("--temperature", type=float, default=0.1)
+ parser.add_argument("--top_k", type=int, default=1)
+ parser.add_argument("--top_p", type=float, default=0.7)
+ ############## new add
+
+ args = parser.parse_args()
+
+ main(args)
diff --git a/pointllm/eval/eval_objaverse.py b/pointllm/eval/eval_objaverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..43eee803988e21322b182e10840e41c743ceb2f6
--- /dev/null
+++ b/pointllm/eval/eval_objaverse.py
@@ -0,0 +1,282 @@
+import os
+import argparse
+import torch
+import json
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from pointllm.data import ObjectPointCloudDataset
+
+
+PROMPT_LISTS = [
+ "What is this?",
+ "This is an object of ",
+ "Caption this 3D model in detail.",
+]
+
+
+from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from llava.conversation import conv_templates
+from llava.model.builder import load_pretrained_model
+from llava.mm_utils import tokenizer_image_token, get_model_name_from_path
+
+
+class MyClass:
+
+ def __init__(self, arg):
+
+ self.vision_tower = None
+ self.pretrain_mm_mlp_adapter = arg.pretrain_mm_mlp_adapter
+
+ self.encoder_type = 'pc_encoder' # text_encoder, pc_encoder
+ self.std=arg.std
+
+ self.pc_encoder_type = arg.pc_encoder_type
+ self.pc_feat_dim = 192 # ไธๅ็pc encoder ไธๅ
+ self.embed_dim = 1024
+ self.group_size = 64
+ self.num_group =512
+ self.pc_encoder_dim =512
+ self.patch_dropout = 0.0
+ self.pc_ckpt_path = arg.pc_ckpt_path
+ self.lora_path = arg.lora_path
+ self.model_path=arg.model_path
+ self.get_pc_tokens_way=arg.get_pc_tokens_way
+
+
+def init_model(model_arg_):
+ model_path = "llava-vicuna_phi_3_finetune_weight"
+ model_name = get_model_name_from_path(model_path)
+ model_path = model_arg_.model_path
+ tokenizer, model, context_len = load_pretrained_model(model_path, None, model_name)
+
+ if model_arg_.lora_path:
+ from peft import PeftModel
+
+ model = PeftModel.from_pretrained(model, model_arg_.lora_path)
+ print("load lora weight ok")
+
+ model.get_model().initialize_other_modules(model_arg_)
+ print("load encoder, mlp ok")
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ model.to(dtype=torch.bfloat16)
+ model.get_model().vision_tower.to(dtype=torch.float)
+ model.to(device)
+
+ return tokenizer, model
+
+
+
+def load_dataset(data_path, anno_path, pointnum, conversation_types, use_color):
+ print("Loading validation datasets.")
+ dataset = ObjectPointCloudDataset(
+ data_path=data_path,
+ anno_path=anno_path,
+ pointnum=pointnum,
+ conversation_types=conversation_types,
+ use_color=use_color,
+ tokenizer=None # * load point cloud only
+ )
+ print("Done!")
+ return dataset
+
+
+def get_dataloader(dataset, batch_size, shuffle=False, num_workers=4):
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+ return dataloader
+
+
+def start_generation(model, dataloader, annos, prompt_index, output_dir, output_file, tokenizer, args):
+ qs = PROMPT_LISTS[prompt_index]
+
+ results = {"prompt": qs}
+
+
+ qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
+ conv_mode = 'phi3_instruct'
+ conv = conv_templates[conv_mode].copy()
+ conv.append_message(conv.roles[0], qs)
+ conv.append_message(conv.roles[1], None)
+ qs = conv.get_prompt()
+
+ print("qs:",qs)
+
+
+ input_ids = (
+ tokenizer_image_token(qs, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
+ .unsqueeze(0)
+ .cuda()
+ )
+
+
+ responses = []
+
+ for batch in tqdm(dataloader):
+ point_clouds = batch["point_clouds"].cuda()
+ object_ids = batch["object_ids"] # * list of string
+
+ texts = input_ids.repeat(point_clouds.size()[0], 1)
+
+ images_tensor = point_clouds.to(dtype=torch.bfloat16) # torch.Size([20, 8192, 6]
+
+
+ temperature = args.temperature
+ top_p = args.top_p
+
+ max_new_tokens = args.max_new_tokens
+ min_new_tokens = args.min_new_tokens
+ num_beams = args.num_beams
+ repetition_penalty=args.repetition_penalty
+
+
+ with torch.inference_mode():
+ output_ids = model.generate(
+ texts,
+ images=images_tensor,
+ do_sample=True if temperature > 0 and num_beams == 1 else False,
+ temperature=temperature,
+ top_p=top_p,
+ num_beams=num_beams,
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ use_cache=True,
+ repetition_penalty=repetition_penalty,
+ )
+
+
+
+ answers = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+ outputs = []
+ for answer in answers:
+ answer = answer.strip()
+ answer = answer.replace("<|end|>", "").strip()
+ outputs.append(answer)
+
+ # saving results
+ for obj_id, output in zip(object_ids, outputs):
+ responses.append({
+ "object_id": obj_id,
+ "ground_truth": annos[obj_id],
+ "model_output": output
+ })
+
+ results["results"] = responses
+
+ os.makedirs(output_dir, exist_ok=True)
+ # save the results to a JSON file
+ with open(os.path.join(output_dir, output_file), 'w') as fp:
+ json.dump(results, fp, indent=2)
+
+ # * print info
+ print(f"Saved results to {os.path.join(output_dir, output_file)}")
+
+ return results
+
+
+def main(args):
+ # * ouptut
+ args.output_dir = os.path.join(args.out_path, "evaluation")
+
+ # * output file
+ anno_file = os.path.splitext(os.path.basename(args.anno_path))[0]
+ args.output_file = f"{anno_file}_Objaverse_{args.task_type}_prompt{args.prompt_index}.json"
+ args.output_file_path = os.path.join(args.output_dir, args.output_file)
+
+ # * First inferencing, then evaluate
+ if not os.path.exists(args.output_file_path):
+ # * need inferencing
+ # * load annotation files
+ with open(args.anno_path, 'r') as fp:
+ annos = json.load(fp)
+
+ dataset = load_dataset(args.data_path, args.anno_path, args.pointnum, ("simple_description",), args.use_color)
+ dataloader = get_dataloader(dataset, args.batch_size, args.shuffle, args.num_workers)
+
+ model_arg = MyClass(args)
+ tokenizer, model = init_model(model_arg)
+ model.eval()
+
+ # * convert annos file from [{"object_id": }] to {"object_id": }
+ annos = {anno["object_id"]: anno["conversations"][1]['value'] for anno in annos}
+
+ print(f'[INFO] Start generating results for {args.output_file}.')
+ results = start_generation(model, dataloader, annos, args.prompt_index, args.output_dir, args.output_file, tokenizer, args)
+
+ # * release model and release cuda memory
+ del model
+
+ torch.cuda.empty_cache()
+ else:
+ # * directly load the results
+ print(f'[INFO] {args.output_file_path} already exists, directly loading...')
+ with open(args.output_file_path, 'r') as fp:
+ results = json.load(fp)
+
+
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--out_path", type=str, default="./output_json")
+ parser.add_argument("--pretrain_mm_mlp_adapter", type=str, required=True)
+
+ parser.add_argument("--lora_path", type=str, default=None)
+ parser.add_argument("--model_path", type=str, default='./lava-vicuna_2024_4_Phi-3-mini-4k-instruct')
+
+ parser.add_argument("--std", type=float, default=0.0)
+ parser.add_argument("--pc_ckpt_path", type=str, required=True, default="./pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-small/model.pt")
+ parser.add_argument("--pc_encoder_type", type=str, required=True, default='small')
+ parser.add_argument("--get_pc_tokens_way", type=str, required=True)
+
+ # * dataset type
+ parser.add_argument("--data_path", type=str, default="./dataset/Objaverse/8192_npy", required=False)
+
+ parser.add_argument("--anno_path", type=str,
+ default="./dataset/Objaverse/PointLLM_brief_description_val_200_GT.json",
+ required=False)
+ parser.add_argument("--pointnum", type=int, default=8192)
+ parser.add_argument("--use_color", action="store_true", default=True)
+
+ # * data loader, batch_size, shuffle, num_workers
+ parser.add_argument("--batch_size", type=int, default=10)
+ parser.add_argument("--shuffle", type=bool, default=False)
+ parser.add_argument("--num_workers", type=int, default=10)
+
+ # * evaluation setting
+ parser.add_argument("--prompt_index", type=int, default=0)
+
+ parser.add_argument("--task_type", type=str, default="classification", choices=["captioning", "classification"],
+ help="Type of the task to evaluate.")
+
+
+ ############## new add
+ parser.add_argument("--max_new_tokens", type=int, default=150, help="max number of generated tokens")
+ parser.add_argument("--min_new_tokens", type=int, default=0, help="min number of generated tokens")
+ parser.add_argument("--num_beams", type=int, default=1)
+ parser.add_argument("--temperature", type=float, default=0.1)
+ parser.add_argument("--top_k", type=int, default=1) # ๆๆถๆฒก่ตทไฝ็จ
+ parser.add_argument("--top_p", type=float, default=0.7)
+ parser.add_argument("--repetition_penalty", type=float, default=1 )
+ ############## new add
+
+ args = parser.parse_args()
+
+ # * check prompt index
+ # * * classification: 0, 1 and captioning: 2. Raise Warning otherwise.
+ if args.task_type == "classification":
+ if args.prompt_index != 0 and args.prompt_index != 1:
+ print("[Warning] For classification task, prompt_index should be 0 or 1.")
+ elif args.task_type == "captioning":
+ pass
+ if args.prompt_index != 2:
+ print("[Warning] For captioning task, prompt_index should be 2.")
+ else:
+ raise NotImplementedError
+
+ main(args)
+
+
diff --git a/pointllm/eval/evaluator_opensource_llm_QwenAPI.py b/pointllm/eval/evaluator_opensource_llm_QwenAPI.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8e3fdec360872795a1acb16f569695985f0f6ff
--- /dev/null
+++ b/pointllm/eval/evaluator_opensource_llm_QwenAPI.py
@@ -0,0 +1,970 @@
+import argparse
+import json
+import time
+from http import HTTPStatus
+
+from dashscope import Generation
+from tqdm import tqdm
+from multiprocessing import Pool
+import random
+from gradio_client import Client
+from concurrent.futures import ThreadPoolExecutor
+
+# from openai import OpenAI
+random.seed(0)
+import re
+
+import os
+
+
+MY_client = None
+
+open_free_from_cls_prompt = """Analyze two sentences and determine if they're referring to the same general object or concept, only focusing on the type of object and category, not attributes such as color, size, or shape. Ignore 3D model-related adjectives such as "cartoon-style", "toy". Respond with 'T' if they refer to the same big category and 'F' if not. Also, provide a brief rationale (no more than 20 words) for your judgment.
+Example:
+Input: 1. Spiral staircase that goes from a ground floor. 2. This is a 3D model of wooden stairs in light brown
+Output: T#Both refer to a staircase.
+
+Input: 1. A white and red van. 2. This is a 3D model of a toy cartoon-style truck
+Output: T# Both refer to a car, they are in the same big category.
+
+Now, analyze the following:
+Input: 1. {ground_truth} 2. {model_output}
+Output: """ # * about 230 input tokens
+
+close_set_cls_prompt = """Given the following free-form description of a 3D object, please determine the most probable class index from the following 40 available categories, even if the description doesn't clearly refer to any one of them. Make your best-educated guess based on the information provided. If the description already contains a valid index, then the index should be selected. If it contains more than one valid index, then randomly select one index (specify your reason). If there is no valid index and it cannot be inferred from the information, return '-1#NA#Cannot infer'.
+Categories:
+{candidate_lists}
+Reply with the format of 'index#class#short reason (no more than 10 words)'.
+
+Examples:
+Input: This is a 3D object model of a cartoon white truck.
+Output: 7#car#Closest match to 'car' in categories.
+
+Input: A green leaf in a flower pot.
+Output: 26#plant#The primary subject 'leaf' directly indicates a plant.
+
+Input: It's difficult to determine the exact type of this object due to insufficient details. But it seems to be like a piece of furniture.
+Output: 33#table#Randomly select one kind of furniture from the list.
+
+Input: I cannot determine the specific type of the object without additional information or context.
+Output: -1#NA#Cannot infer.
+
+Now analyze the following:
+Input: """
+
+object_captioning_prompt = """Evaluate a model-generated caption against a ground-truth caption for a 3D model. Identify the aspects mentioned in theground-truth caption and calculate the percentage of these aspects correctly mentioned or partially matched in the model caption. Score from 0 to 100, where each aspect contributes equally to the score. Consider similar concepts for partial score.
+
+Provide your score (0-100) and a short justification (less than 15 words) in the format of 'score#reason'
+
+Example:
+Ground Truth:: A white brown skeleton
+Model: This is a 3D model of a small, cartoon-like robot. It has a spherical body and is covered in a layer of white dust.
+Output: 50#mention white; skeleton and robot have similar appearence.
+
+Now score the following:
+Ground Truth: {ground_truth}
+Model: {model_output}
+Output: """
+
+
+LLM_object_captioning_prompt = object_captioning_prompt
+LLM_open_free_from_cls_prompt = open_free_from_cls_prompt
+LLM_close_set_cls_prompt = close_set_cls_prompt
+
+GPT_PRICES = {
+ # * check https://openai.com/pricing for updated price
+ "gpt-3.5-turbo-0125": {
+ "price_1k_prompt_tokens": 0.0005,
+ "price_1k_completion_tokens": 0.0015
+ },
+ "gpt-3.5-turbo-0613": {
+ "price_1k_prompt_tokens": 0.0015,
+ "price_1k_completion_tokens": 0.002
+ },
+ "gpt-3.5-turbo-1106": {
+ "price_1k_prompt_tokens": 0.0010,
+ "price_1k_completion_tokens": 0.002
+ },
+ "gpt-4-0613": {
+ "price_1k_prompt_tokens": 0.03,
+ "price_1k_completion_tokens": 0.06
+ },
+ "gpt-4-1106-preview": {
+ "price_1k_prompt_tokens": 0.01,
+ "price_1k_completion_tokens": 0.03
+ },
+ "HF": {
+ "price_1k_prompt_tokens": 0,
+ "price_1k_completion_tokens": 0
+ },
+}
+
+
+class OpenAIOpenFreeFormClsEvaluator():
+ def __init__(self, inputs, output_dir, output_file, model_type="Qwen/Qwen2-72B-Instruct", client=None):
+ """
+ Args:
+ inputs: A dictionary containing the results of the evaluation. It contains two keys: "results" and "prompt".
+ "prompt": str
+ "results": [
+ {
+ "object_id": str,
+ "model_output": str,
+ "ground_truth": str
+ }
+ ]
+ """
+ print("-" * 80)
+ print("Initializing OpenAIEvaluator...")
+ self.results = inputs['results'] # * contains two keys: "results" and "prompt"
+ self.inference_prompt = inputs['prompt'] # * used to prompt PointLLM
+ self.correct_predictions = 0
+ self.total_predictions = 0
+ self.invalid_responses = 0
+ self.response_data = [] # to save all the response data by openaigpt
+ self.model_type = model_type
+ self.check_model_type()
+
+ self.client = client
+
+ self.prompt_tokens = 0
+ self.completion_tokens = 0
+
+ self.default_chat_parameters = {
+ "model": model_type,
+ "temperature": 1,
+ "top_p": 1,
+ "max_tokens": 2048
+ }
+
+ # * price
+ self.price_1k_prompt_tokens = GPT_PRICES["HF"]["price_1k_prompt_tokens"]
+ self.price_1k_completion_tokens = GPT_PRICES["HF"]["price_1k_completion_tokens"]
+
+ print(f"OpenAIGPT config: ")
+ print(self.default_chat_parameters)
+
+ # self.openaigpt = OpenAIGPT(**self.default_chat_parameters)
+ self.gpt_prompt = LLM_open_free_from_cls_prompt
+ self.output_dir = output_dir
+ self.output_file = output_file
+ self.temp_output_file = self.output_file.replace(".json", "_processed_temp.json")
+
+ def get_relpy_from_llm(self, input_sentence):
+
+ query_input = input_sentence.replace("the rocket", "this rocket")
+ query_input = query_input.replace("The rocket", "This rocket")
+
+ time.sleep(0.2)
+
+ # try:
+ messages = [{'role': 'system', 'content': 'You are a helpful assistant.'},
+ {'role': 'user', 'content': query_input}]
+ response = Generation.call(model=self.model_type,
+ messages=messages,
+ # ่ฎพ็ฝฎ้ๆบๆฐ็งๅญseed๏ผๅฆๆๆฒกๆ่ฎพ็ฝฎ๏ผๅ้ๆบๆฐ็งๅญ้ป่ฎคไธบ1234
+ seed=1234,
+ temperature=0,
+ top_p=0.8,
+ top_k=50,
+ # ๅฐ่พๅบ่ฎพ็ฝฎไธบ"message"ๆ ผๅผ
+ result_format='message')
+ # if response.status_code == HTTPStatus.OK:
+ # print(response.output.choices[0].message.content)
+
+ nested_json_str = response.output.choices[0].message.content
+ # except (Exception, KeyboardInterrupt) as e:
+ # print(e)
+ # print("response:",response)
+
+
+ return nested_json_str
+
+ def check_model_type(self):
+ # # * warning if not using gpt-4, recommend using gpt-4 for this task
+ # if "gpt-4" not in self.model_type:
+ # print(f"[WARNING] You are using {self.model_type} for evaluation. We recommend using gpt-4 for this task.")
+ pass
+
+ def resume_processing(self):
+ processed_results_path = os.path.join(self.output_dir, self.temp_output_file)
+ if os.path.exists(processed_results_path):
+ print("-" * 80)
+ # * print resuming
+ print(f"Resuming processing...")
+ print(f"Loading processed results from {processed_results_path}...")
+ with open(processed_results_path, "r") as f:
+ saved_results = json.load(f)
+ self.correct_predictions = saved_results["correct_predictions"]
+ self.total_predictions = saved_results["total_predictions"]
+ self.invalid_responses = saved_results["invalid_responses"]
+ self.response_data = saved_results["results"]
+ self.prompt_tokens = saved_results["prompt_tokens"]
+ self.completion_tokens = saved_results["completion_tokens"]
+
+ print(f"Processed results: {len(self.response_data)}")
+ # * print the length of all the data
+ print(f"Total results: {len(self.results)}")
+
+ # * remove processed data
+ processed_ids = [d['object_id'] for d in self.response_data]
+ self.results = [r for r in self.results if r['object_id'] not in processed_ids]
+
+ print(f"Remaining results: {len(self.results)}")
+
+ def remove_temp_file(self):
+ processed_results_path = os.path.join(self.output_dir, self.temp_output_file)
+ if os.path.exists(processed_results_path):
+ os.remove(processed_results_path)
+ print("-" * 80)
+ print(f"Removed Temporary file {processed_results_path}")
+
+ def parse_gpt_response_evaluate(self, gpt_response):
+ gpt_response = gpt_response.strip()
+
+ cls_result = gpt_response[0].upper()
+ reason = gpt_response[2:] if len(gpt_response) > 2 else ""
+
+ if cls_result not in ['T', 'F']:
+ self.invalid_responses += 1
+ return 0, "INVALID", gpt_response
+
+ accuracy = 1 if cls_result == 'T' else 0
+
+ return accuracy, cls_result, reason
+
+ def evaluate_result(self, result):
+
+ object_id = result['object_id']
+ ground_truth = result['ground_truth']
+ model_output = result['model_output']
+
+ messages = self.gpt_prompt.format(ground_truth=ground_truth, model_output=model_output)
+
+ gpt_response = self.get_relpy_from_llm(messages)
+
+ prompt_tokens = 0
+ completion_tokens = 0
+
+ accuracy, cls_result, reason = self.parse_gpt_response_evaluate(
+ gpt_response) # return 0, "INVALID", gpt_response if not valid
+
+ return object_id, model_output, ground_truth, accuracy, cls_result, reason, prompt_tokens, completion_tokens
+
+ def evaluate(self):
+
+ self.resume_processing()
+
+ print('-' * 80)
+ print("Starting single-thread evaluation...")
+ results = self.results
+
+ try:
+ for result in tqdm(results):
+ object_id, model_output, ground_truth, accuracy, cls_result, reason, prompt_tokens, completion_tokens = self.evaluate_result(
+ result)
+
+ self.correct_predictions += accuracy
+ self.total_predictions += 1
+ self.prompt_tokens += prompt_tokens
+ self.completion_tokens += completion_tokens
+
+ # save the object_id, model_output, ground_truth, gpt_cls_result and gpt_reason for each result
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'model_output': model_output,
+ 'gpt_cls_result': cls_result,
+ 'gpt_reason': reason
+ })
+
+ print("Evaluation finished.")
+
+ self.save_results()
+ self.print_results()
+ self.remove_temp_file()
+ except (Exception, KeyboardInterrupt) as e:
+ print(f"Error {e} occurred during parallel evaluation. Saving processed results to temporary file...")
+ self.save_results(is_temp=True)
+ exit()
+
+ def parallel_evaluate(self, num_workers=20):
+
+ self.resume_processing()
+
+ print('-' * 80)
+ print("Starting parallel evaluation...")
+ results = self.results
+
+ try:
+ # ไฝฟ็จThreadPoolExecutorๅๅปบ็บฟ็จๆฑ
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ # ๅๅปบ่ฟๅบฆๆก
+ with tqdm(total=len(results)) as pbar:
+ # ๆไบคไปปๅกๅนถๆถ้Futureๅฏน่ฑก
+ futures = {executor.submit(self.evaluate_result, result): result for result in results}
+
+ # ้ๅๅทฒๅฎๆ็Futureๅฏน่ฑก
+ for future in futures:
+
+ # ่ทๅFuture็็ปๆ
+ object_id, model_output, ground_truth, accuracy, cls_result, reason, prompt_tokens, completion_tokens = future.result()
+
+ # ๆดๆฐ็ป่ฎกไฟกๆฏ
+ self.correct_predictions += accuracy
+ self.total_predictions += 1
+ self.prompt_tokens += prompt_tokens
+ self.completion_tokens += completion_tokens
+
+ if cls_result == 'INVALID':
+ self.invalid_responses += 1
+
+ # ไฟๅญ็ปๆๆฐๆฎ
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'model_output': model_output,
+ 'gpt_cls_result': cls_result,
+ 'gpt_reason': reason
+ })
+
+ # ๆดๆฐ่ฟๅบฆๆก
+ pbar.update()
+
+
+ print("Parallel evaluation finished.")
+
+ self.save_results()
+ self.print_results()
+ self.remove_temp_file()
+
+ except (Exception, KeyboardInterrupt) as e:
+ print(f"Error {e} occurred during parallel evaluation. Saving processed results to temporary file...")
+ self.save_results(is_temp=True)
+ exit()
+
+ def save_results(self, is_temp=False):
+ if is_temp:
+ output_path = os.path.join(self.output_dir, self.temp_output_file)
+ else:
+ output_path = os.path.join(self.output_dir, self.output_file)
+ if self.total_predictions - self.invalid_responses == 0:
+ accuracy = 0 # * no results and get error
+ else:
+ accuracy = self.correct_predictions / (self.total_predictions - self.invalid_responses) * 100
+ with open(output_path, 'w') as f:
+ results_to_save = {
+ 'inference_prompt': self.inference_prompt,
+ 'prompt': self.gpt_prompt,
+ 'accuracy': f"{accuracy:.2f}%",
+ 'total_predictions': self.total_predictions,
+ 'correct_predictions': self.correct_predictions,
+ 'invalid_responses': self.invalid_responses,
+ 'prompt_tokens': self.prompt_tokens,
+ 'completion_tokens': self.completion_tokens,
+ 'GPT_cost': self.get_costs(),
+ 'results': self.response_data,
+ }
+ json.dump(results_to_save, f, indent=2)
+
+ print(f"Results saved to {output_path}")
+ # * print the length of saved results
+ print(f"Saved {len(self.response_data)} results in total.")
+
+ def print_results(self):
+ print('-' * 80)
+ if self.total_predictions - self.invalid_responses == 0:
+ accuracy = 0 # * no results and get error
+ else:
+ accuracy = self.correct_predictions / (self.total_predictions - self.invalid_responses) * 100
+ print("Results:")
+ print(f"Accuracy: {accuracy:.2f}%")
+ print(f"Total Predictions: {self.total_predictions}")
+ print(f"Correct Predictions: {self.correct_predictions}")
+ print(f"Invalid Responses: {self.invalid_responses}")
+ self.print_costs()
+
+ def print_costs(self):
+ print(f"Prompt Tokens Price: {self.prompt_tokens * self.price_1k_prompt_tokens / 1000:.2f} USD")
+ print(f"Completion Tokens Price: {self.completion_tokens * self.price_1k_completion_tokens / 1000:.2f} USD")
+
+ def get_costs(self):
+ return self.prompt_tokens * self.price_1k_prompt_tokens / 1000 + self.completion_tokens * self.price_1k_completion_tokens / 1000
+
+
+class OpenAICloseSetClsEvaluator(OpenAIOpenFreeFormClsEvaluator):
+ def __init__(self, inputs, output_dir, output_file, model_type="gpt-3.5-turbo-0613", client=None):
+ super().__init__(inputs, output_dir, output_file, model_type, client=client)
+ self.gpt_prompt = LLM_close_set_cls_prompt
+ self.invalid_correct_predictions = 0 # * random choice and correct coincidently
+
+ # * import category names
+ try:
+ # # * load a txt files of category names
+ catfile = os.path.join(os.path.dirname(__file__),
+ '../data/modelnet_config/modelnet40_shape_names_modified.txt') # * i.e. pointllm/data/modelnet_config/modelnet40_shape_names_modified.txt
+
+
+
+ self.candidate_lists_names = [line.strip() for line in open(catfile)] # * list of category names
+ except:
+ print(f"Current categories file is {catfile}. Need to move the category file to pointllm/eval/configs/.")
+
+ # * make the prompt
+ candidate_lists = [f'{i}: {cat}' for i, cat in enumerate(self.candidate_lists_names)]
+ self.num_categories = len(candidate_lists)
+ self.candidate_lists = '\n'.join(candidate_lists)
+ self.gpt_prompt = self.gpt_prompt.format(num_categories=self.num_categories,
+ candidate_lists=self.candidate_lists) + "{model_output}\nOutput: "
+
+ def check_model_type(self):
+ # * no need to check for this task
+ return
+
+ def resume_processing(self):
+ processed_results_path = os.path.join(self.output_dir, self.temp_output_file)
+ if os.path.exists(processed_results_path):
+ print("-" * 80)
+ # * print resuming
+ print(f"Resuming processing...")
+ print(f"Loading processed results from {processed_results_path}...")
+ with open(processed_results_path, "r") as f:
+ saved_results = json.load(f)
+ self.correct_predictions = saved_results["correct_predictions"]
+ self.total_predictions = saved_results["total_predictions"]
+ self.invalid_responses = saved_results["invalid_responses"]
+ self.invalid_correct_predictions = saved_results["invalid_correct_predictions"]
+ self.response_data = saved_results["results"]
+ self.prompt_tokens = saved_results["prompt_tokens"]
+ self.completion_tokens = saved_results["completion_tokens"]
+
+ print(f"Processed results: {len(self.response_data)}")
+ # * print the length of all the data
+ print(f"Total results: {len(self.results)}")
+
+ # * remove processed data
+ processed_ids = [d['object_id'] for d in self.response_data]
+ self.results = [r for r in self.results if r['object_id'] not in processed_ids]
+
+ print(f"Remaining results: {len(self.results)}")
+
+ def parse_gpt_response_evaluate(self, gpt_response, ground_truth):
+ """
+ Argument:
+ gpt_response: str, index#label#short_reason
+ groud_truth: int
+ """
+
+ # * use regular expression to extract
+ pattern = r'(\d+#[^#]*#.*$)'
+ match = re.search(pattern, gpt_response)
+
+ gpt_response = match.group(1) if match else gpt_response
+
+ gpt_response = gpt_response.strip()
+ gpt_response_list = gpt_response.split('#')
+
+ cls_result = gpt_response_list[0]
+ cls_label = gpt_response_list[1] if len(gpt_response_list) > 1 else ""
+ reason = gpt_response_list[2] if len(gpt_response_list) > 2 else ""
+
+ try:
+ # * convert to int
+ cls_result = int(cls_result)
+ if cls_result not in range(self.num_categories) or cls_label == "NA":
+ # * not valid range
+ cls_result = -1
+ except ValueError:
+ print(f"Error: unale to parse {gpt_response}.")
+ cls_result = -1
+
+ if cls_result == -1:
+ # * random choose one index from 0 to self.num_categories
+ cls_result = random.choice(range(self.num_categories))
+ cls_label = "INVALID"
+ reason = gpt_response
+
+ self.invalid_responses += 1
+
+ accuracy = 1 if cls_result == ground_truth else 0
+
+ return accuracy, cls_result, cls_label, reason
+
+ def evaluate_result(self, result):
+
+
+ object_id = result.get('object_id', -1)
+ ground_truth = result['ground_truth']
+ ground_truth_label = result['label_name']
+ model_output = result['model_output']
+
+ messages = self.gpt_prompt.format(model_output=model_output)
+
+ gpt_response = self.get_relpy_from_llm(messages)
+
+ prompt_tokens =0
+ completion_tokens = 0
+
+ gpt_response = gpt_response
+
+ accuracy, cls_result, cls_label, reason = self.parse_gpt_response_evaluate(gpt_response,
+ ground_truth) # return 0, "INVALID", gpt_response if not valid
+
+ return object_id, model_output, ground_truth, accuracy, cls_result, cls_label, reason, ground_truth_label, prompt_tokens, completion_tokens
+
+ def evaluate(self):
+
+ self.resume_processing()
+
+ print('-' * 80)
+ print("Starting single-thread evaluation...")
+ results = self.results
+
+ try:
+ for result in tqdm(results):
+
+ object_id, model_output, ground_truth, accuracy, cls_result, cls_label, reason, ground_truth_label, prompt_tokens, completion_tokens = self.evaluate_result(
+ result)
+ self.correct_predictions += accuracy
+ self.total_predictions += 1
+
+ if cls_label == "INVALID":
+ self.invalid_correct_predictions += accuracy
+ self.invalid_responses += 1
+
+ self.prompt_tokens += prompt_tokens
+ self.completion_tokens += completion_tokens
+
+ # save the object_id, model_output, ground_truth, gpt_cls_result and gpt_reason for each result
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'gpt_cls_result': cls_result,
+ 'ground_truth_label': ground_truth_label,
+ 'gpt_cls_label': cls_label,
+ 'model_output': model_output,
+ 'gpt_reason': reason,
+ 'prompt_tokens': prompt_tokens,
+ 'completion_tokens': completion_tokens
+ })
+
+ print("Evaluation finished.")
+
+ self.save_results()
+ self.print_results()
+ self.remove_temp_file()
+ except (Exception, KeyboardInterrupt) as e:
+ print(f"Error {e} occurred during parallel evaluation. Saving processed results to temporary file...")
+ print(f"Current sample is {result}.")
+ self.save_results(is_temp=True)
+ exit()
+
+ def parallel_evaluate(self, num_workers=20):
+
+ self.resume_processing()
+
+ print('-' * 80)
+ print("Starting parallel evaluation...")
+ results = self.results
+
+ try:
+ # ไฝฟ็จThreadPoolExecutorๅๅปบ็บฟ็จๆฑ
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ # ๅๅปบ่ฟๅบฆๆก
+ with tqdm(total=len(results)) as pbar:
+ # ๆไบคไปปๅกๅนถๆถ้Futureๅฏน่ฑก
+ futures = {executor.submit(self.evaluate_result, result): result for result in results}
+
+ # ้ๅๅทฒๅฎๆ็Futureๅฏน่ฑก
+ for future in futures:
+ # ่ทๅFuture็็ปๆ
+ object_id, model_output, ground_truth, accuracy, cls_result, cls_label, reason, ground_truth_label, prompt_tokens, completion_tokens = future.result()
+
+ self.correct_predictions += accuracy
+ self.total_predictions += 1
+
+ self.prompt_tokens += prompt_tokens
+ self.completion_tokens += completion_tokens
+
+ if cls_label == "INVALID":
+ self.invalid_correct_predictions += accuracy
+ self.invalid_responses += 1
+
+ # save the object_id, model_output, ground_truth, gpt_cls_result and gpt_reason for each result
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'gpt_cls_result': cls_result,
+ 'ground_truth_label': ground_truth_label,
+ 'gpt_cls_label': cls_label,
+ 'model_output': model_output,
+ 'gpt_reason': reason,
+ 'prompt_tokens': prompt_tokens,
+ 'completion_tokens': completion_tokens
+ })
+
+ pbar.update() # update the progress bar
+
+ print("Parallel evaluation finished.")
+
+ self.save_results()
+ self.print_results()
+ self.remove_temp_file()
+
+ except (Exception, KeyboardInterrupt) as e:
+ print(f"Error {e} occurred during parallel evaluation. Saving processed results to temporary file...")
+ self.save_results(is_temp=True)
+ exit()
+
+ def save_results(self, is_temp=False):
+ if is_temp:
+ output_path = os.path.join(self.output_dir, self.temp_output_file)
+ else:
+ output_path = os.path.join(self.output_dir, self.output_file)
+ if self.total_predictions - self.invalid_responses == 0:
+ accuracy = 0 # * no results and get error
+ clean_accuracy = 0
+ else:
+ accuracy = self.correct_predictions / self.total_predictions * 100
+ clean_accuracy = (self.correct_predictions - self.invalid_correct_predictions) / (
+ self.total_predictions - self.invalid_responses) * 100
+ with open(output_path, 'w') as f:
+ results_to_save = {
+ 'inference_prompt': self.inference_prompt,
+ 'prompt': self.gpt_prompt,
+ 'accuracy': f"{accuracy:.2f}%",
+ 'clean_accuracy': f"{clean_accuracy:.2f}%",
+ 'total_predictions': self.total_predictions,
+ 'correct_predictions': self.correct_predictions,
+ 'invalid_correct_predictions': self.invalid_correct_predictions,
+ 'invalid_responses': self.invalid_responses,
+ 'prompt_tokens': self.prompt_tokens,
+ 'completion_tokens': self.completion_tokens,
+ 'GPT_cost': self.get_costs(),
+ 'results': self.response_data,
+ }
+ json.dump(results_to_save, f, indent=2)
+
+ print(f"Results saved to {output_path}")
+ # * print the length of saved results
+ print(f"Saved {len(self.response_data)} results in total.")
+
+ def print_results(self):
+ print('-' * 80)
+ if self.total_predictions - self.invalid_responses == 0:
+ accuracy = 0 # * no results and get error
+ else:
+ accuracy = self.correct_predictions / self.total_predictions * 100
+ clean_accuracy = (self.correct_predictions - self.invalid_correct_predictions) / (
+ self.total_predictions - self.invalid_responses) * 100
+ accuracy = self.correct_predictions / self.total_predictions * 100
+ print("Results:")
+ print(f"Accuracy: {accuracy:.2f}%")
+ print(f"Clean Accuracy: {clean_accuracy:.2f}%", )
+ print(f"Total Predictions: {self.total_predictions}")
+ print(f"Correct Predictions: {self.correct_predictions}")
+ print(f"Invalid Correct Predictions: {self.invalid_correct_predictions}")
+ print(f"Invalid Responses: {self.invalid_responses}")
+ print(f"Prompt Tokens: {self.prompt_tokens}")
+ print(f"Completion Tokens: {self.completion_tokens}")
+
+ self.print_costs()
+
+
+class OpenAIObjectCaptioningEvaluator(OpenAIOpenFreeFormClsEvaluator):
+ def __init__(self, inputs, output_dir, output_file, model_type="gpt-4-0613", client=None):
+ super().__init__(inputs, output_dir, output_file, model_type, client=client)
+ self.gpt_prompt = LLM_object_captioning_prompt
+
+ self.total_scores = 0
+
+ def resume_processing(self):
+ processed_results_path = os.path.join(self.output_dir, self.temp_output_file)
+ if os.path.exists(processed_results_path):
+ print("-" * 80)
+ # * print resuming
+ print(f"Resuming processing...")
+ print(f"Loading processed results from {processed_results_path}...")
+ with open(processed_results_path, "r") as f:
+ saved_results = json.load(f)
+ self.total_scores = float(saved_results["total_score"])
+
+ self.total_predictions = saved_results["total_predictions"]
+ self.invalid_responses = saved_results["invalid_responses"]
+ self.response_data = saved_results["results"]
+ self.prompt_tokens = saved_results["prompt_tokens"]
+ self.completion_tokens = saved_results["completion_tokens"]
+
+ print(f"Processed results: {len(self.response_data)}")
+ # * print the length of all the data
+ print(f"Total results: {len(self.results)}")
+
+ # * remove processed data
+ processed_ids = [d['object_id'] for d in self.response_data]
+ self.results = [r for r in self.results if r['object_id'] not in processed_ids]
+
+ print(f"Remaining results: {len(self.results)}")
+
+ def parse_gpt_response_evaluate(self, gpt_response, ground_truth):
+ """
+ Argument:
+ gpt_response: str, index#label#short_reason
+ groud_truth: int
+ """
+
+ # * use regular expression to extract
+ pattern = r'(\d*#.*)'
+ match = re.search(pattern, gpt_response)
+
+ gpt_response = match.group(1) if match else gpt_response
+
+ gpt_response = gpt_response.strip()
+ gpt_response_list = gpt_response.split('#')
+
+ gpt_score = gpt_response_list[0]
+ reason = gpt_response_list[1] if len(gpt_response_list) > 1 else ""
+
+ try:
+ # * convert to int
+ gpt_score = int(gpt_score)
+ if gpt_score not in range(101): # * in 0-100
+ # * not valid range
+ gpt_score = -1
+ except ValueError:
+ print(f"Error: unale to parse {gpt_response}.")
+ gpt_score = -1
+
+ if gpt_score == -1:
+ reason = gpt_response
+
+ return gpt_score, reason
+
+ def evaluate_result(self, result):
+
+ object_id = result.get('object_id', -1)
+ ground_truth = result['ground_truth']
+ model_output = result['model_output']
+
+ messages = self.gpt_prompt.format(ground_truth=ground_truth, model_output=model_output)
+
+ gpt_response = self.get_relpy_from_llm(messages)
+
+ prompt_tokens = 0
+ completion_tokens = 0
+
+ gpt_response = gpt_response
+
+ gpt_score, reason = self.parse_gpt_response_evaluate(gpt_response,
+ ground_truth) # return 0, "INVALID", gpt_response if not valid
+
+ return object_id, model_output, ground_truth, gpt_score, reason, prompt_tokens, completion_tokens
+
+ def evaluate(self):
+
+ self.resume_processing()
+
+ print('-' * 80)
+ print("Starting single-thread evaluation...")
+ results = self.results
+
+ try:
+ for result in tqdm(results):
+ object_id, model_output, ground_truth, gpt_score, reason, prompt_tokens, completion_tokens = self.evaluate_result(
+ result)
+
+ self.total_scores += gpt_score if gpt_score != -1 else 0
+ self.total_predictions += 1
+ self.prompt_tokens += prompt_tokens
+ self.completion_tokens += completion_tokens
+
+ if gpt_score == -1:
+ self.invalid_responses += 1
+
+ # save the object_id, model_output, ground_truth, gpt_cls_result and gpt_reason for each result
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'model_output': model_output,
+ "gpt_score": gpt_score,
+ 'gpt_reason': reason
+ })
+
+ print("Evaluation finished.")
+
+ self.save_results()
+ self.print_results()
+ self.remove_temp_file()
+ except (Exception, KeyboardInterrupt) as e:
+ print(f"Error {e} occurred during parallel evaluation. Saving processed results to temporary file...")
+ self.save_results(is_temp=True)
+ exit()
+
+ def parallel_evaluate(self, num_workers=20):
+
+ self.resume_processing()
+
+ print('-' * 80)
+ print("Starting parallel evaluation...")
+ results = self.results
+
+ try:
+ # ไฝฟ็จThreadPoolExecutorๅๅปบ็บฟ็จๆฑ
+ with ThreadPoolExecutor(max_workers=num_workers) as executor:
+ # ๅๅปบ่ฟๅบฆๆก
+ with tqdm(total=len(results)) as pbar:
+ # ๆไบคไปปๅกๅนถๆถ้Futureๅฏน่ฑก
+ futures = {executor.submit(self.evaluate_result, result): result for result in results}
+
+ # ้ๅๅทฒๅฎๆ็Futureๅฏน่ฑก
+ for future in futures:
+
+ # ่ทๅFuture็็ปๆ
+ object_id, model_output, ground_truth, gpt_score, reason, prompt_tokens, completion_tokens = future.result()
+
+ self.total_scores += gpt_score if gpt_score != -1 else 0
+ self.total_predictions += 1
+ self.prompt_tokens += prompt_tokens
+ self.completion_tokens += completion_tokens
+
+ if gpt_score == -1:
+ self.invalid_responses += 1
+
+ # save the object_id, model_output, ground_truth, gpt_cls_result and gpt_reason for each result
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'model_output': model_output,
+ "gpt_score": gpt_score,
+ 'gpt_reason': reason
+ })
+
+ pbar.update() # update the progress bar
+
+ print("Parallel evaluation finished.")
+
+ self.save_results()
+ self.print_results()
+ self.remove_temp_file()
+
+ except (Exception, KeyboardInterrupt) as e:
+ print(f"Error {e} occurred during parallel evaluation. Saving processed results to temporary file...")
+ self.save_results(is_temp=True)
+ exit()
+
+ def save_results(self, is_temp=False):
+ if is_temp:
+ output_path = os.path.join(self.output_dir, self.temp_output_file)
+ else:
+ output_path = os.path.join(self.output_dir, self.output_file)
+ if self.total_predictions - self.invalid_responses == 0:
+ average_score = 0 # * no results and get error
+ else:
+ average_score = self.total_scores / (self.total_predictions - self.invalid_responses)
+ with open(output_path, 'w') as f:
+ results_to_save = {
+ 'inference_prompt': self.inference_prompt,
+ 'gpt_prompt': self.gpt_prompt,
+ 'average_score': f"{average_score:.2f}",
+ 'total_score': f"{self.total_scores:.2f}",
+ 'total_predictions': self.total_predictions,
+ 'invalid_responses': self.invalid_responses,
+ 'prompt_tokens': self.prompt_tokens,
+ 'completion_tokens': self.completion_tokens,
+ 'GPT_cost': self.get_costs(),
+ 'results': self.response_data,
+ }
+ json.dump(results_to_save, f, indent=2)
+
+ print(f"Results saved to {output_path}")
+ # * print the length of saved results
+ print(f"Saved {len(self.response_data)} results in total.")
+
+ def print_results(self):
+ print('-' * 80)
+ if self.total_predictions - self.invalid_responses == 0:
+ average_score = 0 # * no results and get error
+ else:
+ average_score = self.total_scores / (self.total_predictions - self.invalid_responses)
+ print("Results:")
+ print(f"Average Score: {average_score:.2f}")
+ print(f"Total Predictions: {self.total_predictions}")
+ print(f"Invalid Responses: {self.invalid_responses}")
+ print(f"Prompt Tokens: {self.prompt_tokens}")
+ print(f"Completion Tokens: {self.completion_tokens}")
+
+ self.print_costs()
+
+
+def convert_model_name_to_spaces_url(model_name: str) -> str:
+ # ๆฟๆขๆๆ ไธบ็ญๆจช็บฟ๏ผๅนถๅฐๆๆๅญ็ฌฆ่ฝฌไธบๅฐๅ
+ formatted_name = model_name.replace('/', '-').lower()
+ # ๆผๆฅๆๅฎๆด็URL
+ spaces_url = f"https://{formatted_name}.hf.space"
+ return spaces_url
+
+
+def start_evaluation(results, output_dir, output_file, eval_type="open-free-form-classification",
+ model_type="gpt-3.5-turbo-0613",
+ parallel=True, num_workers=20):
+ """
+ Args:
+ results: dict or file path to the json file containing the dict
+ output_file: the path the final evaluation results to be saved.
+ """
+ if isinstance(results, str):
+ with open(results, 'r') as fp:
+ results = json.load(fp)
+
+
+ # MY_client = Client(convert_model_name_to_spaces_url(model_type))
+ # MY_client = Client("https://s5k.cn/api/v1/studio/qwen/Qwen2-72B-Instruct-demo/gradio/")
+ MY_client = None
+
+ print("eval_type:",eval_type)
+ if eval_type == "open-free-form-classification":
+ evaluator = OpenAIOpenFreeFormClsEvaluator(results, output_dir, output_file, model_type=model_type, client=MY_client)
+ elif eval_type == "modelnet-close-set-classification":
+ evaluator = OpenAICloseSetClsEvaluator(results, output_dir, output_file, model_type=model_type, client=MY_client)
+ elif eval_type == "object-captioning":
+ evaluator = OpenAIObjectCaptioningEvaluator(results, output_dir, output_file, model_type=model_type, client=MY_client)
+ else:
+ raise NotImplementedError(f"eval_type {eval_type} not supported.")
+
+ if parallel:
+ evaluator.parallel_evaluate(num_workers=num_workers)
+ else:
+ evaluator.evaluate()
+
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--results_path", type=str, \
+ default="", help="Path to the results file.")
+ parser.add_argument("--output_dir", type=str, default=None, help="Path to the output directory.")
+ parser.add_argument("--model_type", type=str, default="Qwen/Qwen2-72B-Instruct",
+ help="Type of the model in hugging face used to evaluate.")
+ parser.add_argument("--parallel", default=True, action="store_true", help="Whether to use parallel evaluation.")
+ parser.add_argument("--num_workers", type=int, default=15, help="Number of workers to use for parallel evaluation.")
+ parser.add_argument("--eval_type", type=str,
+ choices=["modelnet-close-set-classification", "open-free-form-classification",
+ "object-captioning"], default="object-captioning")
+
+ args = parser.parse_args()
+
+
+ if args.output_dir is None:
+ args.output_dir = os.path.dirname(args.results_path)
+
+ output_file = os.path.basename(args.results_path).replace(".json",
+ f"_evaluated_{(args.model_type).split('/')[-1]}.json")
+
+ # if exists, then exit
+ if os.path.exists(os.path.join(args.output_dir, output_file)):
+ print(f"[INFO] Evaulated results already exists in {os.path.join(args.output_dir, output_file)}.")
+ exit()
+
+ start_evaluation(results=args.results_path, output_dir=args.output_dir, output_file=output_file,
+ eval_type=args.eval_type, model_type=args.model_type,
+ parallel=args.parallel, num_workers=args.num_workers)
diff --git a/pointllm/eval/traditional_evaluator.py b/pointllm/eval/traditional_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f877b357a8c7984b09b93e8b52a193140bc1099f
--- /dev/null
+++ b/pointllm/eval/traditional_evaluator.py
@@ -0,0 +1,189 @@
+import os
+# os.environ["CUDA_VISIBLE_DEVICES"]="0"
+
+import argparse
+import json
+import os
+import random
+random.seed(0)
+
+import nltk
+# nltk.download('wordnet')
+from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+from nltk.translate.meteor_score import meteor_score
+from rouge import Rouge
+from sentence_transformers import SentenceTransformer, util
+from scipy.spatial.distance import cosine
+from transformers import AutoModel, AutoTokenizer
+import torch
+
+
+import numpy as np
+from tqdm import tqdm
+
+class TraditionalMetricEvaluator():
+ def __init__(self, inputs, output_dir, output_file):
+ self.results = inputs['results']
+ self.inference_prompt = inputs['prompt']
+ self.output_dir = output_dir
+ self.output_file = output_file
+ self.rouge = Rouge()
+ self.response_data = []
+
+ self.ground_truths = []
+ self.generated_captions = []
+
+ # self.sbert_model = SentenceTransformer('all-mpnet-base-v2')
+ self.sbert_model = SentenceTransformer("./pretrained_weight/eval_model_weight/all-mpnet-base-v2")
+
+ self.simcse_tokenizer = AutoTokenizer.from_pretrained("./pretrained_weight/eval_model_weight/sup-simcse-roberta-large")
+ # self.simcse_tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-roberta-large")
+ self.simcse_model = AutoModel.from_pretrained("./pretrained_weight/eval_model_weight/sup-simcse-roberta-large")
+ # self.simcse_model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-roberta-large")
+
+ self.scores = {
+ 'bleu-1': [],
+ 'bleu-2': [],
+ 'bleu-3': [],
+ 'bleu-4': [],
+ 'rouge-1': [],
+ 'rouge-2': [],
+ 'rouge-l': [],
+ 'meteor': [],
+ 'sbert_similarity': [],
+ 'simcse_similarity': []
+ }
+
+ def evaluate_result(self, result):
+ object_id = result['object_id']
+ ground_truth = result['ground_truth']
+ model_output = result['model_output']
+
+ if model_output == "":
+ # * all score should be 0
+ model_output = "##"
+
+ # create a SmoothingFunction object
+ smoothing_function = SmoothingFunction().method1 # * used to deal with non-overlap n-gram
+
+ # calculate BLEU-1 score with smoothing function
+ bleu_1_score = sentence_bleu([ground_truth.split()], model_output.split(), weights=(1, 0, 0, 0), smoothing_function=smoothing_function)
+
+ # calculate BLEU-2, BLEU-3, and BLEU-4 scores
+ bleu_2_score = sentence_bleu([ground_truth.split()], model_output.split(), weights=(0.5, 0.5, 0, 0), smoothing_function=smoothing_function)
+ bleu_3_score = sentence_bleu([ground_truth.split()], model_output.split(), weights=(0.33, 0.33, 0.33, 0), smoothing_function=smoothing_function)
+ bleu_4_score = sentence_bleu([ground_truth.split()], model_output.split(), weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothing_function)
+
+ # calculate ROUGE-L score
+ rouge_scores_l = self.rouge.get_scores(model_output, ground_truth)[0]['rouge-l']
+ rouge_scores_1 = self.rouge.get_scores(model_output, ground_truth)[0]['rouge-1']
+ rouge_scores_2 = self.rouge.get_scores(model_output, ground_truth)[0]['rouge-2']
+
+ # calculate METEOR score
+ meteor_scores = meteor_score([ground_truth.split()], model_output.split())
+
+ # Calculate SBERT similarity
+ embeddings = self.sbert_model.encode([ground_truth, model_output])
+ sbert_similarity = util.cos_sim(embeddings[0], embeddings[1])[0][0].item()
+
+ # calculate SimCSE similarity
+ # Tokenize input texts
+ inputs = self.simcse_tokenizer([ground_truth, model_output], padding=True, truncation=True, return_tensors="pt")
+
+ # Get the embeddings
+ with torch.no_grad():
+ embeddings = self.simcse_model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
+
+ # Calculate cosine similarity
+ simcse_similarity = 1 - cosine(embeddings[0], embeddings[1]) # * consine actually calculates consine distance, which is 1 - consine similarity
+
+ scores = {
+ 'bleu-1': bleu_1_score * 100,
+ 'bleu-2': bleu_2_score * 100,
+ 'bleu-3': bleu_3_score * 100,
+ 'bleu-4': bleu_4_score * 100,
+ 'rouge-l': rouge_scores_l['f'] * 100,
+ 'rouge-1': rouge_scores_1['f'] * 100,
+ 'rouge-2': rouge_scores_2['f'] * 100,
+ 'meteor': meteor_scores * 100,
+ 'sbert_similarity': sbert_similarity * 100,
+ 'simcse_similarity': simcse_similarity * 100
+ }
+
+ return object_id, model_output, ground_truth, scores
+
+ def evaluate(self):
+ print("Starting evaluation...")
+
+ for result in tqdm(self.results, desc="Evaluating"):
+ object_id, model_output, ground_truth, scores = self.evaluate_result(result)
+
+ # save the object_id, model_output, ground_truth, and scores for each result
+ self.response_data.append({
+ 'object_id': object_id,
+ 'ground_truth': ground_truth,
+ 'model_output': model_output,
+ 'scores': scores,
+ })
+
+ # save the scores for overall results
+ for metric, score in scores.items():
+ self.scores[metric].append(score)
+
+ print("Evaluation finished.")
+ self.save_results()
+ self.print_results()
+
+ def save_results(self):
+ output_path = os.path.join(self.output_dir, self.output_file)
+
+ with open(output_path, 'w') as f:
+ results_to_save = {
+ 'inference_prompt': self.inference_prompt,
+ 'overall_scores': {metric: f"{np.mean(scores):.4f}" for metric, scores in self.scores.items()},
+ 'results': self.response_data,
+ }
+ json.dump(results_to_save, f, indent=2)
+
+ print(f"Results saved to {output_path}")
+
+ def print_results(self):
+ print('-' * 80)
+ print("Results:")
+ for metric, scores in self.scores.items():
+ print(f"Average {metric.upper()} Score: {np.mean(scores):.4f}")
+
+def start_evaluation(results, output_dir, output_file,
+ parallel=True, num_workers=20):
+ """
+ Args:
+ results: dict or file path to the json file containing the dict
+ output_file: the path the final evaluation results to be saved.
+ """
+ if isinstance(results, str):
+ with open(results, 'r') as fp:
+ results = json.load(fp)
+
+ evaluator = TraditionalMetricEvaluator(results, output_dir, output_file)
+ evaluator.evaluate()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--results_path", type=str, \
+ default="", help="Path to the results file.")
+ parser.add_argument("--output_dir", type=str, default=None, help="Path to the output directory.")
+
+ args = parser.parse_args()
+
+ if args.output_dir is None:
+ args.output_dir = os.path.dirname(args.results_path)
+
+ output_file = os.path.basename(args.results_path).replace(".json", f"_evaluated_traditional.json")
+
+ output_path = os.path.join(args.output_dir, output_file)
+ if not os.path.exists(output_path):
+ start_evaluation(results=args.results_path, output_dir=args.output_dir, output_file=output_file)
+ else:
+ print(f'[INFO] {output_file} already exists, directly loading...')
+
\ No newline at end of file
diff --git a/pointllm/eval/utils.py b/pointllm/eval/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3d4419acb61f45189dc06da0b0a8f83f92fe337
--- /dev/null
+++ b/pointllm/eval/utils.py
@@ -0,0 +1,75 @@
+import openai
+
+openai.api_base = "https://api.chatanywhere.tech/v1"
+# openai.api_base = "https://api.chatanywhere.com.cn/v1"
+
+# openai.api_base = "https://api.v36.cm/v1"
+
+import time
+import random
+import os
+
+def retry_with_exponential_backoff(
+ func,
+ initial_delay: float = 1,
+ exponential_base: float = 2,
+ jitter: bool = True,
+ max_retries: int = 40,
+ max_delay: int = 30,
+ errors: tuple = (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.Timeout),
+):
+ """Retry a function with exponential backoff."""
+ def wrapper(*args, **kwargs):
+ num_retries = 0
+ delay = initial_delay
+
+ while True:
+ try:
+ return func(*args, **kwargs)
+ except errors as e:
+ # * print the error info
+ num_retries += 1
+ if num_retries > max_retries:
+ print(f"[OPENAI] Encounter error: {e}.")
+ raise Exception(
+ f"[OPENAI] Maximum number of retries ({max_retries}) exceeded."
+ )
+ delay *= exponential_base * (1 + jitter * random.random())
+ time.sleep(min(delay, max_delay))
+ except Exception as e:
+ raise e
+ return wrapper
+
+class OpenAIGPT():
+ def __init__(self, model="gpt-3.5-turbo-0613", temperature=1, top_p=1, max_tokens=2048, **kwargs) -> None:
+ setup_openai(model)
+ self.default_chat_parameters = {
+ "model": model,
+ "temperature": temperature,
+ "top_p": top_p,
+ "max_tokens": max_tokens,
+ **kwargs
+ }
+
+ @retry_with_exponential_backoff
+ def safe_chat_complete(self, messages, content_only=True, **kwargs):
+ chat_parameters = self.default_chat_parameters.copy()
+ if len(kwargs) > 0:
+ chat_parameters.update(**kwargs)
+
+ response = openai.ChatCompletion.create(
+ messages=messages,
+ **chat_parameters
+ )
+
+ if content_only:
+ response = response['choices'][0]["message"]['content']
+
+ return response
+
+def setup_openai(model_name):
+ # Setup OpenAI API Key
+ print("[OPENAI] Setting OpenAI api_key...")
+ openai.api_key = os.getenv('OPENAI_API_KEY')
+ print(f"[OPENAI] OpenAI organization: {openai.organization}")
+ print(f"[OPENAI] Using MODEL: {model_name}")
\ No newline at end of file
diff --git a/pointllm/model/__init__.py b/pointllm/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8b1696d3a0db31ec6f6f22374a831f3326cab8d
--- /dev/null
+++ b/pointllm/model/__init__.py
@@ -0,0 +1,2 @@
+# from .pointllm import PointLLMLlamaForCausalLM, PointLLMConfig
+# from .pointbert.point_encoder import PointTransformer
\ No newline at end of file
diff --git a/pointllm/model/pointllm.py b/pointllm/model/pointllm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d33002aa0131eea89bf63bed62b3fc54405fb658
--- /dev/null
+++ b/pointllm/model/pointllm.py
@@ -0,0 +1,365 @@
+# Copyright 2023 Runsen Xu
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.nn import CrossEntropyLoss
+from .utils import *
+from pointllm.utils import *
+
+from contextlib import nullcontext
+from transformers import AutoConfig, AutoModelForCausalLM, \
+ LlamaConfig, LlamaModel, LlamaForCausalLM
+
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+
+import os
+
+# * add logger
+import logging
+logger = logging.getLogger(__name__)
+
+class PointLLMConfig(LlamaConfig):
+ model_type = "pointllm"
+
+class PointLLMLlamaModel(LlamaModel):
+ config_class = PointLLMConfig
+
+ def __init__(self, config: LlamaConfig):
+ super(PointLLMLlamaModel, self).__init__(config)
+
+ self.point_backbone_type = config.point_backbone
+ logger.info(f"Using {self.point_backbone_type}.")
+
+ if self.point_backbone_type == "PointBERT":
+ from pointllm.model import PointTransformer
+ # address of config file, in the same dir of this file
+ point_bert_config_name = getattr(config, "point_backbone_config_name", "PointTransformer_base_8192point") # * default for v1.1, v1.2 uses PointTransformer_8192point_2layer.yaml
+ point_bert_config_addr = os.path.join(os.path.dirname(__file__), "pointbert", f"{point_bert_config_name}.yaml")
+ print(f"Loading PointBERT config from {point_bert_config_addr}.")
+ point_bert_config = cfg_from_yaml_file(point_bert_config_addr)
+ if getattr(config, "use_color", False):
+ point_bert_config.model.point_dims = 6
+ use_max_pool = getattr(point_bert_config.model, "use_max_pool", False) # * default is false
+
+ self.point_backbone = PointTransformer(point_bert_config.model, use_max_pool=use_max_pool)
+ logger.info(f"Using {self.point_backbone.point_dims} dim of points.")
+
+ self.point_backbone_config = {
+ "point_cloud_dim": point_bert_config.model.point_dims,
+ "backbone_output_dim": point_bert_config.model.trans_dim if not use_max_pool else point_bert_config.model.trans_dim * 2,
+ "project_output_dim": self.config.hidden_size,
+ "point_token_len": point_bert_config.model.num_group + 1 if not use_max_pool else 1, # * number of output features, with cls token
+ "mm_use_point_start_end": self.config.mm_use_point_start_end,
+ "projection_hidden_layer": point_bert_config.model.get('projection_hidden_layer', 0),
+ "use_max_pool": use_max_pool
+ }
+ if point_bert_config.model.get('projection_hidden_layer', 0) > 0:
+ self.point_backbone_config["projection_hidden_dim"] = point_bert_config.model.projection_hidden_dim # a list
+
+ logger.info(f"Use max pool is {use_max_pool}. Number of point token is {self.point_backbone_config['point_token_len']}.")
+
+ # * print relevant info with projection layers
+ backbone_output_dim = self.point_backbone_config["backbone_output_dim"]
+ logger.info(f"Point backbone output dim: {backbone_output_dim}.")
+ logger.info(f"Use {self.point_backbone_config['projection_hidden_layer']} projection hiddent layers.")
+ if self.point_backbone_config['projection_hidden_layer'] > 0:
+ # Add projection layer with linear layers and GELU activation
+ projection_layers = []
+ last_dim = backbone_output_dim
+ for i in range(point_bert_config.model.projection_hidden_layer):
+ projection_layers.append(nn.Linear(last_dim, self.point_backbone_config["projection_hidden_dim"][i]))
+ projection_layers.append(nn.GELU())
+ last_dim = self.point_backbone_config["projection_hidden_dim"][i]
+
+ projection_layers.append(nn.Linear(last_dim, self.point_backbone_config["project_output_dim"]))
+ self.point_proj = nn.Sequential(*projection_layers)
+ logger.info(f"Each layer with {point_bert_config.model.projection_hidden_dim} hidden units.")
+ else:
+ # Single layer
+ self.point_proj = nn.Linear(backbone_output_dim, self.point_backbone_config['project_output_dim'])
+ logger.info(f"Point projector output dim: {self.point_backbone_config['project_output_dim']}.")
+
+ self.fix_pointnet = False
+ self.fix_llm = False
+
+ def load_point_backbone_checkpoint(self, checkpoint_path=None):
+ pass
+ # self.point_backbone.load_checkpoint(self.config.point_backbone_ckpt if checkpoint_path is None else checkpoint_path)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ point_clouds: Optional[torch.FloatTensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+ # HACK: replace back original embeddings for LLaVA pretraining
+ orig_embeds_params = getattr(self, 'orig_embeds_params', None)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ point_backbone = getattr(self, 'point_backbone', None)
+ point_backbone_config = getattr(self, 'point_backbone_config', None)
+
+ if point_backbone is not None and (input_ids.shape[1] != 1 or self.training) and point_clouds is not None:
+ # * enter when training or the first generation step of inference
+ with torch.no_grad() if self.fix_pointnet else nullcontext():
+ if self.fix_pointnet:
+ self.point_backbone.eval()
+ if type(point_clouds) is list:
+ # * variable numbers of points
+ point_features = []
+ for point_cloud in point_clouds: # * iterate over batch
+ point_feature = self.point_backbone(point_cloud.unsqueeze(0))[0]
+ point_features.append(point_feature)
+ else:
+ point_features = self.point_backbone(point_clouds)
+
+ if type(point_clouds) is list:
+ point_features = [self.point_proj(point_feature) for point_feature in point_features]
+ else:
+ point_features = self.point_proj(point_features)
+
+ dummy_point_features = torch.zeros(point_backbone_config['point_token_len'], point_backbone_config['backbone_output_dim'], device=inputs_embeds.device, dtype=inputs_embeds.dtype)
+ dummy_point_features = self.point_proj(dummy_point_features)
+
+ new_input_embeds = []
+ cur_point_idx = 0
+ for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds): # * input_ids: B, L; input_embeds: B, L, C
+ if (cur_input_ids == point_backbone_config['point_patch_token']).sum() == 0:
+ # multimodal LLM, but the current sample is not multimodal
+ cur_input_embeds = cur_input_embeds + (0. * dummy_point_features).sum() # * seems doing nothing
+ new_input_embeds.append(cur_input_embeds)
+ cur_point_idx += 1
+ continue
+ cur_point_features = point_features[cur_point_idx].to(device=cur_input_embeds.device)
+ num_patches = cur_point_features.shape[0] # * number of point tokens
+ if point_backbone_config['mm_use_point_start_end']:
+ if (cur_input_ids == point_backbone_config["point_start_token"]).sum() != (cur_input_ids == point_backbone_config["point_end_token"]).sum():
+ raise ValueError("The number of point start tokens and point end tokens should be the same.")
+ point_start_tokens = torch.where(cur_input_ids == point_backbone_config["point_start_token"])[0]
+ for point_start_token_pos in point_start_tokens:
+ if cur_input_ids[point_start_token_pos + num_patches + 1] != point_backbone_config["point_end_token"]:
+ raise ValueError("The point end token should follow the image start token.")
+ if orig_embeds_params is not None: # * will not update the original embeddings except for IMAGE_START_TOKEN and IMAGE_END_TOKEN
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:point_start_token_pos].detach(), cur_input_embeds[point_start_token_pos:point_start_token_pos+1], cur_point_features, cur_input_embeds[point_start_token_pos + num_patches + 1:point_start_token_pos + num_patches + 2], cur_input_embeds[point_start_token_pos + num_patches + 2:].detach()), dim=0)
+ else:
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:point_start_token_pos+1], cur_point_features, cur_input_embeds[point_start_token_pos + num_patches + 1:]), dim=0)
+ cur_point_idx += 1
+ new_input_embeds.append(cur_new_input_embeds)
+ else:
+ if (cur_input_ids == point_backbone_config["point_patch_token"]).sum() != num_patches:
+ raise ValueError("The number of point patch tokens should be the same as the number of point patches.")
+ masked_indices = torch.where(cur_input_ids == point_backbone_config["point_patch_token"])[0]
+ mask_index_start = masked_indices[0]
+ if (masked_indices != torch.arange(mask_index_start, mask_index_start+num_patches, device=masked_indices.device, dtype=masked_indices.dtype)).any():
+ raise ValueError("The image patch tokens should be consecutive.")
+ if orig_embeds_params is not None:
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start].detach(), cur_point_features, cur_input_embeds[mask_index_start+num_patches:].detach()), dim=0)
+ else:
+ cur_new_input_embeds = torch.cat((cur_input_embeds[:mask_index_start], cur_point_features, cur_input_embeds[mask_index_start+num_patches:]), dim=0)
+ new_input_embeds.append(cur_new_input_embeds)
+ cur_point_idx += 1
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
+
+ return super(PointLLMLlamaModel, self).forward(
+ input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds, use_cache=use_cache,
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
+ return_dict=return_dict
+ )
+
+
+class PointLLMLlamaForCausalLM(LlamaForCausalLM):
+ config_class = PointLLMConfig
+
+ def __init__(self, config):
+ super(LlamaForCausalLM, self).__init__(config)
+ self.model = PointLLMLlamaModel(config)
+
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_model(self):
+ return self.model
+
+ def maybe_autocast(self, dtype=torch.float16):
+ # if on cpu, don't use autocast
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
+ enable_autocast = self.device != torch.device("cpu")
+
+ if enable_autocast:
+ return torch.cuda.amp.autocast(dtype=dtype)
+ else:
+ return contextlib.nullcontext()
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None, # * control whether to return past_key_values
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ point_clouds: Optional[torch.FloatTensor] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ point_clouds=point_clouds
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous() # * B, L, V(32003)
+ shift_labels = labels[..., 1:].contiguous() # * B, L
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model/pipeline parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "point_clouds": kwargs.get("point_clouds", None),
+ }
+ )
+ return model_inputs
+
+ def initialize_tokenizer_point_backbone_config_wo_embedding(self, tokenizer):
+ # * called when stage2 or inference or inference without pre-training, assume tokenizer has point tokens
+ config = self.config
+ point_backbone_config = self.get_model().point_backbone_config
+ mm_use_point_start_end = point_backbone_config['mm_use_point_start_end'] = config.mm_use_point_start_end
+
+ default_point_patch_token = config.DEFAULT_POINT_PATCH_TOKEN
+
+ tokenizer.add_tokens([default_point_patch_token], special_tokens=True)
+
+ # * assert tokenizer has the default_point_patch_token
+ point_backbone_config['default_point_patch_token'] = default_point_patch_token
+ point_backbone_config['point_patch_token'] = tokenizer.convert_tokens_to_ids([default_point_patch_token])[0]
+
+ if mm_use_point_start_end:
+ default_point_start_token = config.DEFAULT_POINT_START_TOKEN
+ default_point_end_token = config.DEFAULT_POINT_END_TOKEN
+ tokenizer.add_tokens([default_point_start_token, default_point_end_token], special_tokens=True)
+
+ point_backbone_config['default_point_start_token'] = default_point_start_token
+ point_backbone_config['default_point_end_token'] = default_point_end_token
+
+ point_backbone_config["point_start_token"] = tokenizer.convert_tokens_to_ids([default_point_start_token])[0]
+ point_backbone_config["point_end_token"] = tokenizer.convert_tokens_to_ids([default_point_end_token])[0]
+
+ def initialize_tokenizer_point_backbone_config(self, tokenizer, device, fix_llm=True):
+
+ config = self.config
+ point_backbone_config = self.get_model().point_backbone_config
+ mm_use_point_start_end = point_backbone_config['mm_use_point_start_end'] = config.mm_use_point_start_end
+
+ default_point_patch_token = config.DEFAULT_POINT_PATCH_TOKEN
+ point_backbone_config['default_point_patch_token'] = default_point_patch_token
+ tokenizer.add_tokens([default_point_patch_token], special_tokens=True) # * no need to update embed since it will be replaced
+ self.resize_token_embeddings(len(tokenizer)) # ! resize_token_embeddings will make the tokens trainable again
+ point_backbone_config['point_patch_token'] = tokenizer.convert_tokens_to_ids([default_point_patch_token])[0]
+
+ if mm_use_point_start_end:
+ default_point_start_token = config.DEFAULT_POINT_START_TOKEN
+ default_point_end_token = config.DEFAULT_POINT_END_TOKEN
+ point_backbone_config['default_point_start_token'] = default_point_start_token
+ point_backbone_config['default_point_end_token'] = default_point_end_token
+
+ num_new_tokens = tokenizer.add_tokens([default_point_start_token, default_point_end_token], special_tokens=True)
+ self.resize_token_embeddings(len(tokenizer))
+ point_backbone_config["point_start_token"] = tokenizer.convert_tokens_to_ids([default_point_start_token])[0]
+ point_backbone_config["point_end_token"] = tokenizer.convert_tokens_to_ids([default_point_end_token])[0]
+
+ if num_new_tokens > 0:
+ input_embeddings = self.get_input_embeddings().weight.data
+ output_embeddings = self.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+ # need to update the input embeding, but no need to update the output embedding
+ for p in self.get_input_embeddings().parameters():
+ p.requires_grad = True
+ if fix_llm:
+ self.get_model().orig_embeds_params = [self.get_input_embeddings().weight.data.clone().to(device=device)] # * only tuning the new embeddings
+ for p in self.get_output_embeddings().parameters(): # * the llm head
+ p.requires_grad = False
+ print(f"Setting output embeddings fixed and {num_new_tokens} new tokens' input embeddings trainable.")
+ else:
+ self.get_model().orig_embeds_params = None
+ for p in self.get_output_embeddings().parameters():
+ p.requires_grad = True
+ print("Setting output embeddings and all input embeddings trainable.")
+
+AutoConfig.register("pointllm", PointLLMConfig)
+AutoModelForCausalLM.register(PointLLMConfig, PointLLMLlamaForCausalLM)
diff --git a/pointllm/model/utils.py b/pointllm/model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b78741ca050c66d3c3891a236715f30652130c97
--- /dev/null
+++ b/pointllm/model/utils.py
@@ -0,0 +1,24 @@
+import torch
+from transformers import StoppingCriteria
+
+class KeywordsStoppingCriteria(StoppingCriteria):
+ def __init__(self, keywords, tokenizer, input_ids):
+ self.keywords = keywords
+ self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
+ self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
+ self.tokenizer = tokenizer
+ self.start_len = None
+ self.input_ids = input_ids
+
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ if self.start_len is None:
+ self.start_len = self.input_ids.shape[1]
+ else:
+ for keyword_id in self.keyword_ids:
+ if output_ids[0, -1] == keyword_id:
+ return True
+ outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
+ for keyword in self.keywords:
+ if keyword in outputs:
+ return True
+ return False
diff --git a/pointllm/train/llama_flash_attn_monkey_patch.py b/pointllm/train/llama_flash_attn_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcd3ba7f9361649b5ba0e5a9db312e002c1cac44
--- /dev/null
+++ b/pointllm/train/llama_flash_attn_monkey_patch.py
@@ -0,0 +1,107 @@
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+from typing import List, Optional, Tuple
+from cv2 import exp
+
+import torch
+from torch import nn
+
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
+from einops import rearrange
+
+# * some version is changed to flash_attn_varlen_qkvpacked_func, so need to check
+try:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+except:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+ Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel
+
+ attention_mask: [bsz, q_len]
+ """
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states).view(
+ bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(
+ bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(
+ bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ # [bsz, q_len, nh, hd]
+ # [bsz, nh, q_len, hd]
+
+ kv_seq_len = key_states.shape[-2]
+ offset = 0
+ if past_key_value is not None:
+ offset = past_key_value[0].shape[-2]
+ kv_seq_len += offset
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states,
+ key_states,
+ cos,
+ sin,
+ offset=offset)
+ # [bsz, nh, t, hd]
+ assert not output_attentions, "output_attentions is not supported"
+ assert not use_cache, "use_cache is not supported"
+ assert past_key_value is None, "past_key_value is not supported"
+
+ # Flash attention codes from
+ # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
+
+ # transform the data into the format required by flash attention
+ qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd]
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
+ # We have disabled _prepare_decoder_attention_mask in LlamaModel
+ # the attention_mask should be the same as the key_padding_mask
+ key_padding_mask = attention_mask
+
+
+ if key_padding_mask is None:
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+ max_s = q_len
+ cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32,
+ device=qkv.device)
+ output = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0,
+ softmax_scale=None, causal=True
+ )
+ output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
+ else:
+ nheads = qkv.shape[-2]
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
+ x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
+ x_unpad, cu_q_lens, max_s, 0.0,
+ softmax_scale=None, causal=True
+ )
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
+ indices, bsz, q_len),
+ 'b s (h d) -> b s h d', h=nheads)
+ return self.o_proj(rearrange(output,
+ 'b s h d -> b s (h d)')), None, None
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
+ inputs_embeds, past_key_values_length):
+ # [bsz, seq_len]
+ return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
diff --git a/pointllm/train/pointllm_trainer.py b/pointllm/train/pointllm_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..36b82ef994aae1da2a1e5d64231928ce1a12eeee
--- /dev/null
+++ b/pointllm/train/pointllm_trainer.py
@@ -0,0 +1,51 @@
+import os
+import torch
+import torch.nn as nn
+
+from transformers import Trainer
+from typing import Optional
+
+
+def unwrap_model(model: nn.Module) -> nn.Module:
+ """
+ Recursively unwraps a model from potential containers (as used in distributed training).
+
+ Args:
+ model (`torch.nn.Module`): The model to unwrap.
+ """
+ # since there could be multiple levels of wrapping, unwrap recursively
+ if hasattr(model, "module"):
+ return unwrap_model(model.module)
+ else:
+ return model
+
+
+class PointLLMTrainer(Trainer):
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ print("no save!!!!!!1")
+ pass
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
+ # Save the model
+ _state_dict = state_dict
+ if _state_dict is None:
+ # Only save the model itself if we are using distributed training
+ model_to_save = unwrap_model(self.model)
+ _state_dict = model_to_save.state_dict()
+
+ weight_to_save = {}
+ keys_to_match = ['point_proj', 'embed_tokens', 'embed_in']
+ for k, v in _state_dict.items():
+ if any(key_match in k for key_match in keys_to_match):
+ weight_to_save[k] = v
+
+ current_folder = output_dir.split('/')[-1]
+ parent_folder = os.path.dirname(output_dir)
+ if current_folder.startswith('checkpoint-'):
+ mm_projector_folder = os.path.join(parent_folder, "point_proj")
+ os.makedirs(mm_projector_folder, exist_ok=True)
+ torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
+ else:
+ torch.save(weight_to_save, os.path.join(output_dir, f'point_proj.bin'))
+
+ super(PointLLMTrainer, self)._save(output_dir, state_dict)
diff --git a/pointllm/train/train.py b/pointllm/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b9ec190a3ea714e4e5f1daa2b3bb6b536172b1d
--- /dev/null
+++ b/pointllm/train/train.py
@@ -0,0 +1,336 @@
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# import os
+# os.environ["CUDA_VISIBLE_DEVICES"]="2"
+from dataclasses import dataclass, field
+import pathlib
+from typing import Optional, List
+
+import torch
+import transformers
+from pointllm.train.pointllm_trainer import PointLLMTrainer
+
+from pointllm import conversation as conversation_lib
+from pointllm.model import *
+from pointllm.data import make_object_point_data_module
+
+# * logger
+from pointllm.utils import build_logger
+
+IGNORE_INDEX = -100
+
+DEFAULT_PAD_TOKEN = "[PAD]"
+DEFAULT_EOS_TOKEN = ""
+DEFAULT_BOS_TOKEN = ""
+DEFAULT_UNK_TOKEN = ""
+
+
+@dataclass
+class ModelArguments:
+ # /home/pointllm_weight_2/PointLLM_7B_v1.2
+ model_name_or_path: Optional[str] = field(default="/home/TinyGPT-V/pretrain_weight/phi-new")
+ # model_name_or_path: Optional[str] = field(default="/home/pointllm_weight_2/PointLLM_7B_v1.2")
+ version: Optional[str] = field(default="v1")
+
+@dataclass
+class DataArguments:
+ data_path: str = field(default="/home/PointLLM/data/objaverse_data", metadata={"help": "Path to the training data."})
+ anno_path: str = field(default='/home/PointLLM/data/anno_data/PointLLM_complex_instruction_70K.json', metadata={"help": "Path to the utterance data. If None, will use referit3d by defautl."})
+ use_color: bool = field(default=True, metadata={"help": "Whether to use color."})
+ data_debug_num: int = field(default=0, metadata={"help": "Number of data to use in debug mode. If larger than 0, use debug mode, else use the whole data"})
+ split_train_val: bool = field(default=False, metadata={"help": "Whether to split train and val."})
+ split_ratio: float = field(default=0.9, metadata={"help": "Ratio of train and val."})
+ pointnum: int = field(default=8192, metadata={"help": "Number of points."})
+
+ # conversation_types: List[str] = field(default_factory=lambda: ["simple_description"], metadata={"help": "Conversation types to use."})
+ conversation_types: List[str] = field(default_factory=lambda: ["detailed_description", "single_round", "multi_round"],
+ metadata={"help": "Conversation types to use."})
+ is_multimodal: bool = True
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ # * can refer to https://huggingface.co/docs/transformers/v4.28.1/en/main_classes/trainer#transformers.TrainingArgument
+ cache_dir: Optional[str] = field(default='/home/PointLLM/trash')
+ output_dir: Optional[str] = field(default='/home/PointLLM/trash')
+
+ save_strategy: Optional[str] = field(default='no')
+
+ save_steps: int = field(default=2400)
+ optim: str = field(default="adamw_torch")
+ dataloader_num_workers: int = field(default=24)
+
+ model_max_length: int = field(
+ default=2048,
+ metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+ )
+ per_device_train_batch_size: int = field(
+ default=6, metadata={"help": "Batch size per GPU/TPU/MPS/NPU core/CPU for training."}
+ )
+ model_debug: bool = field(default=False, metadata={"help": "Whether to use small model."}) # * whether to load checkpoints at the mo
+ fix_llm: bool = field(default=True, metadata={"help": "Whether to fix the LLM."})
+ fix_pointnet: bool = field(default=True, metadata={"help": "Whether to fix the PointNet."})
+
+ remove_unused_columns: bool = field(default=False)
+ force_fsdp: bool = field(default=False)
+ bf16: bool = field(default=True)
+ # * for two stage training
+ tune_mm_mlp_adapter: bool = field(default=True) # * set True when pre-training, and false when fine-tuning
+ stage_2: bool = field(default=False) # * set True when fine-tuning
+ pretrained_mm_mlp_adapter: Optional[str] = field(default=None) # * path to the pre-trained projector & output_embed & input_embed
+ detatch_point_token: bool = field(default=False) # * deprecated
+ # * point backbone ckpt path
+ # point_backbone_ckpt: str = field(default='/home/pointllm_weight_2/PointLLM_7B_v1.2')
+
+
+ # point_backbone_ckpt: str = field(default="/home/R_Decoder/pretrain_weight/point_mae/pretrain.pth")
+ point_backbone_ckpt: str = field(default="/home/pointllm_weight_2/point_model/point_model.pth")
+ # point_backbone_ckpt: str = field(default="")
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
+ output_dir: str):
+ """Collects the state dict and dump to disk."""
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {
+ key: value.cpu()
+ for key, value in state_dict.items()
+ }
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def train():
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ training_args.log_level = "info" # * default is passive(warning)
+ # training_args.bf16 = True
+ # * build logger
+
+ training_args.output_dir = '/home/PointLLM/trash'
+ logger = build_logger(__name__, training_args.output_dir + '/train.log')
+
+ if training_args.model_debug:
+ # * do not load checkpoint, load from config
+ config = transformers.AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ torch_dtype=torch.float32
+ )
+ model = PointLLMLlamaForCausalLM._from_config(config)
+ else:
+ model = PointLLMLlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ torch_dtype=torch.float16
+ )
+
+ model.config.use_cache = False
+
+ if training_args.fix_llm:
+ # * This will fix all the parameters
+ logger.info("LLM is fixed. Fix_llm flag is set to True")
+ # * fix llama, lm_head, pointnet, projection layer here
+ model.requires_grad_(False)
+ model.get_model().fix_llm = True
+ model.get_model().point_proj.requires_grad_(True)
+ model.get_model().point_backbone.requires_grad_(True) # * set as True for fsdp, use fix_pointnet flag to control
+ else:
+ model.get_model().fix_llm = False
+ logger.warning("LLM is trainable. Fix_llm flag is set to False")
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=True,
+ )
+
+ # tokenizer = transformers.AutoTokenizer.from_pretrained(
+ # model_args.model_name_or_path,
+ # padding_side="right",
+ # use_fast=False,
+ # )
+
+ if model_args.version == "v0" or "v0" in model_args.model_name_or_path:
+ raise ValueError("v0 is deprecated.")
+ else:
+ # tokenizer.pad_token = tokenizer.unk_token
+ tokenizer.pad_token = tokenizer.eos_token
+ conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1_1"]
+
+ if not training_args.fix_pointnet:
+ # * not fix pointnet
+ logger.info("Point backbone is trainable. Fix_pointnet flag is set to False, pointnet grad will be recorded.")
+ model.get_model().fix_pointnet = False
+ else:
+ logger.info("Point backbone is fixed. Fix_pointnet flag is set to True, pointnet grad will not be recorded.")
+ model.get_model().fix_pointnet = True # * use with torch.inference_mode to control, not requires_grad for fsdp for second stage
+ if not training_args.stage_2:
+ logger.info("Set requires_grad of point backbone to False")
+ model.get_model().point_backbone.requires_grad_(False) # * fix pointnet for first stage, need for fsdp in stage2
+
+ if training_args.tune_mm_mlp_adapter:
+ # * not fix the projection layer
+ # * may need to set the embed_tokens to require_grad = True if added new tokens
+ # * this is done in initialize_tokenizer_point_backbone_config
+ logger.info("Point projection layer is trainable.")
+ else:
+ model.get_model().point_proj.requires_grad_(False)
+ logger.info("Point prejcetion layer is fixed.")
+
+ if not training_args.stage_2:
+ # * we assume in stage2, llm, point_backbone, and projection layer can be loaded from the model checkpoint
+ print(f"Default point_backbone_ckpt is {training_args.point_backbone_ckpt}.")
+ model.get_model().load_point_backbone_checkpoint(training_args.point_backbone_ckpt)
+ model.initialize_tokenizer_point_backbone_config(tokenizer=tokenizer, device=training_args.device, fix_llm=training_args.fix_llm)
+ else:
+ # * stage2
+ model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer=tokenizer)
+
+ point_backbone_config = model.get_model().point_backbone_config
+
+ data_args.point_token_len = point_backbone_config['point_token_len']
+ data_args.mm_use_point_start_end = point_backbone_config['mm_use_point_start_end']
+ data_args.point_backbone_config = point_backbone_config
+
+ params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
+ if len(params_no_grad) > 0:
+ if training_args.fsdp is not None and len(training_args.fsdp) > 0:
+ if len(params_no_grad) < 10:
+ print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad))
+ else:
+ print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10])))
+ print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.")
+ print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
+
+ from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
+ def patch_FSDP_use_orig_params(func):
+ def wrap_func(*args, **kwargs):
+ use_orig_params = kwargs.pop('use_orig_params', True)
+ return func(*args, **kwargs, use_orig_params=use_orig_params)
+ return wrap_func
+
+ FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
+
+ data_module = make_object_point_data_module(tokenizer=tokenizer,
+ data_args=data_args)
+
+
+ # for name, param in model.model.layers.named_parameters():
+ # param.requires_grad = False
+ #
+ # for name, param in model.model.point_backbone.named_parameters():
+ # param.requires_grad = False
+ for name, param in model.model.named_parameters():
+ param.requires_grad = False
+
+ for name, param in model.model.named_parameters():
+ if 'point_proj' in name:
+ param.requires_grad = True
+
+ for name, param in model.model.named_parameters():
+ if 'q_layernorm' in name:
+ param.requires_grad = True
+
+ if 'k_layernorm' in name:
+ param.requires_grad = True
+
+ if 'post_layernorm' in name:
+ param.requires_grad = True
+
+ if 'input_layernorm' in name:
+ param.requires_grad = True
+
+ # if 'input_layernorm' in name:
+ # param.requires_grad = True
+
+ if 'final_layernorm' in name:
+ param.requires_grad = True
+ ############################
+ #
+ #
+ # for name, param in model.model.layers.named_parameters():
+ # param.requires_grad = False
+ #
+ # # for i, layer in enumerate(llama_model.model.layers):
+ # # # ๅฆๆๅฑ็็ดขๅผๅฐไบ5๏ผๅๅฐ่ฏฅๅฑ็ๅๆฐ่ฎพ็ฝฎไธบๅฏ่ฎญ็ป
+ # # if i < 5:
+ # # for param in layer.parameters():
+ # # param.requires_grad = True
+ # # # ๅฐ่ฟไบๅฑ็ๅๆฐ่ฝฌๆขไธบFP32
+ # # layer.to(torch.float32)
+ for i, layer in enumerate(model.model.layers):
+ # layer.register_forward_hook(print_layer_output)
+ # set trainable to True for the input_layernorm layer
+ layer.self_attn.q_layernorm.weight.requires_grad = True
+ layer.self_attn.k_layernorm.weight.requires_grad = True
+ layer.post_layernorm.weight.requires_grad = True
+ layer.input_layernorm.weight.requires_grad = True
+
+ layer.self_attn.q_layernorm.weight.data = layer.self_attn.q_layernorm.weight.data.float()
+ layer.self_attn.k_layernorm.weight.data = layer.self_attn.k_layernorm.weight.data.float()
+ layer.post_layernorm.weight.data = layer.post_layernorm.weight.data.float()
+ layer.input_layernorm.weight.data = layer.input_layernorm.weight.data.float()
+
+ # ๅฏนๅ็ฝฎ้กน่ฟ่ก็ฑปไผผๆไฝ
+ if layer.self_attn.q_layernorm.bias is not None:
+ layer.self_attn.q_layernorm.bias.data = layer.self_attn.q_layernorm.bias.data.float()
+ if layer.self_attn.k_layernorm.bias is not None:
+ layer.self_attn.k_layernorm.bias.data = layer.self_attn.k_layernorm.bias.data.float()
+ if layer.input_layernorm.bias is not None:
+ layer.input_layernorm.bias.data = layer.input_layernorm.bias.data.float()
+
+ model.model.final_layernorm.weight.requires_grad = True
+ model.model.final_layernorm.weight.data = model.model.final_layernorm.weight.data.float()
+ if model.model.final_layernorm.bias is not None:
+ model.model.final_layernorm.bias.data = model.model.final_layernorm.bias.float()
+
+
+
+
+
+ ###################################
+
+
+ for name, param in model.model.named_parameters():
+ if param.requires_grad: # ๅฆๆๅๆฐ้่ฆๆขฏๅบฆ๏ผ้ฃไนๅฎๅฐ่ขซๆดๆฐ
+
+ logger.info(f"Parameter {name} will be updated.")
+
+ # import os
+ # torch.save({
+ # 'base_model': model.model.point_backbone.state_dict(), }, os.path.join('/home/pointllm_weight_2/point_model/point_model.pth'))
+
+ # model = model.half()
+ trainer = PointLLMTrainer(model=model,
+ tokenizer=tokenizer,
+ args=training_args,
+ **data_module)
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+ trainer.save_state()
+ safe_save_model_for_hf_trainer(trainer=trainer,
+ output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/pointllm/train/train_mem.py b/pointllm/train/train_mem.py
new file mode 100644
index 0000000000000000000000000000000000000000..6588d4940a8423780aedc2d71ae5868c22935d84
--- /dev/null
+++ b/pointllm/train/train_mem.py
@@ -0,0 +1,16 @@
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
+
+# Need to call this before importing transformers.
+# from pointllm.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+#
+# replace_llama_attn_with_flash_attn()
+
+import os
+os.environ["CUDA_VISIBLE_DEVICES"]="2"
+
+from pointllm.train.train import train
+
+if __name__ == "__main__":
+ train()
diff --git a/pointllm/utils.py b/pointllm/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7da3faaa8ddb9698d0cadabc93e6ff40ce1277d7
--- /dev/null
+++ b/pointllm/utils.py
@@ -0,0 +1,155 @@
+import logging
+import logging.handlers
+import os
+import sys
+
+import requests
+
+import yaml
+from easydict import EasyDict
+
+server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+
+handler = None
+
+
+def merge_new_config(config, new_config):
+ for key, val in new_config.items():
+ if not isinstance(val, dict):
+ if key == '_base_':
+ with open(new_config['_base_'], 'r') as f:
+ try:
+ val = yaml.load(f, Loader=yaml.FullLoader)
+ except:
+ val = yaml.load(f)
+ config[key] = EasyDict()
+ merge_new_config(config[key], val)
+ else:
+ config[key] = val
+ continue
+ if key not in config:
+ config[key] = EasyDict()
+ merge_new_config(config[key], val)
+ return config
+
+def cfg_from_yaml_file(cfg_file):
+ config = EasyDict()
+ with open(cfg_file, 'r') as f:
+ new_config = yaml.load(f, Loader=yaml.FullLoader)
+ merge_new_config(config=config, new_config=new_config)
+ return config
+
+
+def build_logger(logger_name, logger_filepath):
+ global handler
+
+ formatter = logging.Formatter(
+ fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ )
+
+ # Set the format of root handlers
+ if not logging.getLogger().handlers:
+ logging.basicConfig(level=logging.INFO)
+ logging.getLogger().handlers[0].setFormatter(formatter)
+
+ # Redirect stdout and stderr to loggers
+ stdout_logger = logging.getLogger("stdout")
+ stdout_logger.setLevel(logging.INFO)
+ sl = StreamToLogger(stdout_logger, logging.INFO)
+ sys.stdout = sl
+
+ stderr_logger = logging.getLogger("stderr")
+ stderr_logger.setLevel(logging.ERROR)
+ sl = StreamToLogger(stderr_logger, logging.ERROR)
+ sys.stderr = sl
+
+ # Get logger
+ logger = logging.getLogger(logger_name)
+ logger.setLevel(logging.INFO)
+
+ # Add a file handler for all loggers
+ if handler is None:
+ # * get the logger_file's directory, and create it if not exist
+ logger_filedir = os.path.dirname(logger_filepath)
+ os.makedirs(logger_filedir, exist_ok=True)
+ handler = logging.handlers.TimedRotatingFileHandler(
+ logger_filepath, when='D', utc=True)
+ handler.setFormatter(formatter)
+
+ for name, item in logging.root.manager.loggerDict.items():
+ if isinstance(item, logging.Logger):
+ item.addHandler(handler)
+
+ return logger
+
+
+class StreamToLogger(object):
+ """
+ Fake file-like stream object that redirects writes to a logger instance.
+ """
+ def __init__(self, logger, log_level=logging.INFO):
+ self.terminal = sys.stdout
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ''
+
+ def __getattr__(self, attr):
+ return getattr(self.terminal, attr)
+
+ def write(self, buf):
+ temp_linebuf = self.linebuf + buf
+ self.linebuf = ''
+ for line in temp_linebuf.splitlines(True):
+ # From the io.TextIOWrapper docs:
+ # On output, if newline is None, any '\n' characters written
+ # are translated to the system default line separator.
+ # By default sys.stdout.write() expects '\n' newlines and then
+ # translates them so this is still cross platform.
+ if line[-1] == '\n':
+ self.logger.log(self.log_level, line.rstrip())
+ else:
+ self.linebuf += line
+
+ def flush(self):
+ if self.linebuf != '':
+ self.logger.log(self.log_level, self.linebuf.rstrip())
+ self.linebuf = ''
+
+
+def disable_torch_init():
+ """
+ Disable the redundant torch default initialization to accelerate model creation.
+ """
+ import torch
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
+
+
+def violates_moderation(text):
+ """
+ Check whether the text violates OpenAI moderation API.
+ """
+ url = "https://api.openai.com/v1/moderations"
+ # url = "https://api.chatanywhere.tech/v1/moderations"
+ headers = {"Content-Type": "application/json",
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+ text = text.replace("\n", "")
+ data = "{" + '"input": ' + f'"{text}"' + "}"
+ data = data.encode("utf-8")
+ try:
+ ret = requests.post(url, headers=headers, data=data, timeout=5)
+ flagged = ret.json()["results"][0]["flagged"]
+ except requests.exceptions.RequestException as e:
+ flagged = False
+ except KeyError as e:
+ flagged = False
+
+ return flagged
+
+
+def pretty_print_semaphore(semaphore):
+ if semaphore is None:
+ return "None"
+ return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
diff --git a/pointnet++/Pointnet2_PyTorch/.pre-commit-config.yaml b/pointnet++/Pointnet2_PyTorch/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5599b7939a536090964170afb54d3c98b7ff417a
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/.pre-commit-config.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a67b582096cb7cf43d6a7b50a3dca95bbd43a03aee98f154d00e47c5a019f899
+size 911
diff --git a/pointnet++/Pointnet2_PyTorch/.travis.yml b/pointnet++/Pointnet2_PyTorch/.travis.yml
new file mode 100644
index 0000000000000000000000000000000000000000..be9128267b9ed016a7691d853aa857cb39edf504
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/.travis.yml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e502f1212a68bfb32d52671af430e0e6a277f253cd7ec96c2e118728cae3b48
+size 260
diff --git a/pointnet++/Pointnet2_PyTorch/MANIFEST.in b/pointnet++/Pointnet2_PyTorch/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..57fb41fb03d65eeaa6d6b3b39ce8125afc4b5a96
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/MANIFEST.in
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29368275bc245acc81145ed464567466593c72a4bbc5b79befd7e03322aabe68
+size 24
diff --git a/pointnet++/Pointnet2_PyTorch/README.rst b/pointnet++/Pointnet2_PyTorch/README.rst
new file mode 100644
index 0000000000000000000000000000000000000000..65f47496618289167e5668213979b28d09598a6d
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/README.rst
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e011f2344e4743f2eae813d51e2e129df9f62177c49d40a116136fb479f99047
+size 3136
diff --git a/pointnet++/Pointnet2_PyTorch/UNLICENSE b/pointnet++/Pointnet2_PyTorch/UNLICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..03ff25bd3c37ec15281290da4f2805a2a0575456
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/UNLICENSE
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e12e5df4bae12cb21581ba157ced20e1986a0508dd10d0e8a4ab9a4cf94e85c
+size 1211
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/__init__.py b/pointnet++/Pointnet2_PyTorch/pointnet2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc06de2f04ee2d31611ff44848f568c3b894443
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/__init__.py
@@ -0,0 +1,2 @@
+from pointnet2 import data, models, utils
+from pointnet2._version import __version__
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/_version.py b/pointnet++/Pointnet2_PyTorch/pointnet2/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..528787cfc8ad81ed41822a8104b60b4896632906
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/_version.py
@@ -0,0 +1 @@
+__version__ = "3.0.0"
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/config/config.yaml b/pointnet++/Pointnet2_PyTorch/pointnet2/config/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3dfb64c6158510fc559b10b678c56509b43293a1
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/config/config.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e45624d553fcf01d82437136e30864e4daebc14b838d17f3f22194e1bae498d
+size 213
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/config/model/msg.yaml b/pointnet++/Pointnet2_PyTorch/pointnet2/config/model/msg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d2af0b7a88fdf59051af62f04688e08a9e23f6b5
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/config/model/msg.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41a4260af50f8fd00f0972c84f4d5bc25429c4118432677dc46a282b55b203b0
+size 25
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/config/model/ssg.yaml b/pointnet++/Pointnet2_PyTorch/pointnet2/config/model/ssg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d2af0b7a88fdf59051af62f04688e08a9e23f6b5
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/config/model/ssg.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:41a4260af50f8fd00f0972c84f4d5bc25429c4118432677dc46a282b55b203b0
+size 25
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/config/task/cls.yaml b/pointnet++/Pointnet2_PyTorch/pointnet2/config/task/cls.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b2fd95883283a07a3c43ce4ff1343e1062882610
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/config/task/cls.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac02b1c5fb55df752291a9ea018b9b6766af1338066f6952abffeaa9458d14a9
+size 170
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/config/task/semseg.yaml b/pointnet++/Pointnet2_PyTorch/pointnet2/config/task/semseg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a86655a978e40c5eb428114888ae6de5ff936d0f
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/config/task/semseg.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5ae3598ae1baddd410454923d92a3e6eaf60611bc4ba5ac6a5401428bab257f6
+size 168
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/config/task_model/cls-msg.yaml b/pointnet++/Pointnet2_PyTorch/pointnet2/config/task_model/cls-msg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..85733dcfa38a235e0d44fec117858e5fde1a52ad
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/config/task_model/cls-msg.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c73495683667b7707aae2cb0eff60f4038f49157f8cd4f23b77f099f40296214
+size 85
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/config/task_model/cls-ssg.yaml b/pointnet++/Pointnet2_PyTorch/pointnet2/config/task_model/cls-ssg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3bb63c54a6d12e2c77b263f44d3670df82f24863
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/config/task_model/cls-ssg.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:35b37a9ff63a3b70b34d3cda34a854c9913a1f6f433d9c964304f2438cf76184
+size 85
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/config/task_model/semseg-msg.yaml b/pointnet++/Pointnet2_PyTorch/pointnet2/config/task_model/semseg-msg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7d6979274fec4b38211954b15b067647bcfa883b
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/config/task_model/semseg-msg.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7e740a77b7057d3a793f9f0043c558ab2f95bf44e4097c47e0d6cfbda05b2047
+size 77
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/config/task_model/semseg-ssg.yaml b/pointnet++/Pointnet2_PyTorch/pointnet2/config/task_model/semseg-ssg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c7821bf13fc2ebfb8cd181f2adcebc57c477268e
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/config/task_model/semseg-ssg.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:71021d5805c00a130e5feac65cd90d7b0c5543ee88effc605cdf9aca729354de
+size 77
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/data/.gitignore b/pointnet++/Pointnet2_PyTorch/pointnet2/data/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..d6cd25d7e43b5673e1bb2e80c0b0dc652cdcc5f9
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/data/.gitignore
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e43e5b226e0f1b36792df5b6c3621a66103de6dc1bed1aaf5ef7793d22584cc
+size 52
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/data/Indoor3DSemSegLoader.py b/pointnet++/Pointnet2_PyTorch/pointnet2/data/Indoor3DSemSegLoader.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a449a732ba81f995221b38250145fb2a3f0789c
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/data/Indoor3DSemSegLoader.py
@@ -0,0 +1,105 @@
+import os
+import shlex
+import subprocess
+
+import h5py
+import numpy as np
+import torch
+import torch.utils.data as data
+
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+def _get_data_files(list_filename):
+ with open(list_filename) as f:
+ return [line.rstrip() for line in f]
+
+
+def _load_data_file(name):
+ f = h5py.File(name, "r")
+ data = f["data"][:]
+ label = f["label"][:]
+ return data, label
+
+
+class Indoor3DSemSeg(data.Dataset):
+ def __init__(self, num_points, train=True, download=True, data_precent=1.0):
+ super().__init__()
+ self.data_precent = data_precent
+ self.folder = "indoor3d_sem_seg_hdf5_data"
+ self.data_dir = os.path.join(BASE_DIR, self.folder)
+ self.url = (
+ "https://shapenet.cs.stanford.edu/media/indoor3d_sem_seg_hdf5_data.zip"
+ )
+
+ if download and not os.path.exists(self.data_dir):
+ zipfile = os.path.join(BASE_DIR, os.path.basename(self.url))
+ subprocess.check_call(
+ shlex.split("curl {} -o {}".format(self.url, zipfile))
+ )
+
+ subprocess.check_call(
+ shlex.split("unzip {} -d {}".format(zipfile, BASE_DIR))
+ )
+
+ subprocess.check_call(shlex.split("rm {}".format(zipfile)))
+
+ self.train, self.num_points = train, num_points
+
+ all_files = _get_data_files(os.path.join(self.data_dir, "all_files.txt"))
+ room_filelist = _get_data_files(
+ os.path.join(self.data_dir, "room_filelist.txt")
+ )
+
+ data_batchlist, label_batchlist = [], []
+ for f in all_files:
+ data, label = _load_data_file(os.path.join(BASE_DIR, f))
+ data_batchlist.append(data)
+ label_batchlist.append(label)
+
+ data_batches = np.concatenate(data_batchlist, 0)
+ labels_batches = np.concatenate(label_batchlist, 0)
+
+ test_area = "Area_5"
+ train_idxs, test_idxs = [], []
+ for i, room_name in enumerate(room_filelist):
+ if test_area in room_name:
+ test_idxs.append(i)
+ else:
+ train_idxs.append(i)
+
+ if self.train:
+ self.points = data_batches[train_idxs, ...]
+ self.labels = labels_batches[train_idxs, ...]
+ else:
+ self.points = data_batches[test_idxs, ...]
+ self.labels = labels_batches[test_idxs, ...]
+
+ def __getitem__(self, idx):
+ pt_idxs = np.arange(0, self.num_points)
+ np.random.shuffle(pt_idxs)
+
+ current_points = torch.from_numpy(self.points[idx, pt_idxs].copy()).float()
+ current_labels = torch.from_numpy(self.labels[idx, pt_idxs].copy()).long()
+
+ return current_points, current_labels
+
+ def __len__(self):
+ return int(self.points.shape[0] * self.data_precent)
+
+ def set_num_points(self, pts):
+ self.num_points = pts
+
+ def randomize(self):
+ pass
+
+
+if __name__ == "__main__":
+ dset = Indoor3DSemSeg(16, "./", train=True)
+ print(dset[0])
+ print(len(dset))
+ dloader = torch.utils.data.DataLoader(dset, batch_size=32, shuffle=True)
+ for i, data in enumerate(dloader, 0):
+ inputs, labels = data
+ if i == len(dloader) - 1:
+ print(inputs.size())
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/data/ModelNet40Loader.py b/pointnet++/Pointnet2_PyTorch/pointnet2/data/ModelNet40Loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..7afafd5eeb3aca10b1a548d45c03fd7f31234a1e
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/data/ModelNet40Loader.py
@@ -0,0 +1,161 @@
+import os
+import os.path as osp
+import shlex
+import shutil
+import subprocess
+
+import lmdb
+import msgpack_numpy
+import numpy as np
+import torch
+import torch.utils.data as data
+import tqdm
+
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+def pc_normalize(pc):
+ l = pc.shape[0]
+ centroid = np.mean(pc, axis=0)
+ pc = pc - centroid
+ m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
+ pc = pc / m
+ return pc
+
+
+class ModelNet40Cls(data.Dataset):
+ def __init__(self, num_points, transforms=None, train=True, download=True):
+ super().__init__()
+
+ self.transforms = transforms
+
+ self.set_num_points(num_points)
+ self._cache = os.path.join(BASE_DIR, "modelnet40_normal_resampled_cache")
+
+ if not osp.exists(self._cache):
+ self.folder = "modelnet40_normal_resampled"
+ self.data_dir = os.path.join(BASE_DIR, self.folder)
+ self.url = (
+ "https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip"
+ )
+
+ if download and not os.path.exists(self.data_dir):
+ zipfile = os.path.join(BASE_DIR, os.path.basename(self.url))
+ subprocess.check_call(
+ shlex.split("curl {} -o {}".format(self.url, zipfile))
+ )
+
+ subprocess.check_call(
+ shlex.split("unzip {} -d {}".format(zipfile, BASE_DIR))
+ )
+
+ subprocess.check_call(shlex.split("rm {}".format(zipfile)))
+
+ self.train = train
+ self.set_num_points(num_points)
+
+ self.catfile = os.path.join(self.data_dir, "modelnet40_shape_names.txt")
+ self.cat = [line.rstrip() for line in open(self.catfile)]
+ self.classes = dict(zip(self.cat, range(len(self.cat))))
+
+ os.makedirs(self._cache)
+
+ print("Converted to LMDB for faster dataloading while training")
+ for split in ["train", "test"]:
+ if split == "train":
+ shape_ids = [
+ line.rstrip()
+ for line in open(
+ os.path.join(self.data_dir, "modelnet40_train.txt")
+ )
+ ]
+ else:
+ shape_ids = [
+ line.rstrip()
+ for line in open(
+ os.path.join(self.data_dir, "modelnet40_test.txt")
+ )
+ ]
+
+ shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids]
+ # list of (shape_name, shape_txt_file_path) tuple
+ self.datapath = [
+ (
+ shape_names[i],
+ os.path.join(self.data_dir, shape_names[i], shape_ids[i])
+ + ".txt",
+ )
+ for i in range(len(shape_ids))
+ ]
+
+ with lmdb.open(
+ osp.join(self._cache, split), map_size=1 << 36
+ ) as lmdb_env, lmdb_env.begin(write=True) as txn:
+ for i in tqdm.trange(len(self.datapath)):
+ fn = self.datapath[i]
+ point_set = np.loadtxt(fn[1], delimiter=",").astype(np.float32)
+ cls = self.classes[self.datapath[i][0]]
+ cls = int(cls)
+
+ txn.put(
+ str(i).encode(),
+ msgpack_numpy.packb(
+ dict(pc=point_set, lbl=cls), use_bin_type=True
+ ),
+ )
+
+ shutil.rmtree(self.data_dir)
+
+ self._lmdb_file = osp.join(self._cache, "train" if train else "test")
+ with lmdb.open(self._lmdb_file, map_size=1 << 36) as lmdb_env:
+ self._len = lmdb_env.stat()["entries"]
+
+ self._lmdb_env = None
+
+ def __getitem__(self, idx):
+ if self._lmdb_env is None:
+ self._lmdb_env = lmdb.open(
+ self._lmdb_file, map_size=1 << 36, readonly=True, lock=False
+ )
+
+ with self._lmdb_env.begin(buffers=True) as txn:
+ ele = msgpack_numpy.unpackb(txn.get(str(idx).encode()), raw=False)
+
+ point_set = ele["pc"]
+
+ pt_idxs = np.arange(0, self.num_points)
+ np.random.shuffle(pt_idxs)
+
+ point_set = point_set[pt_idxs, :]
+ point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
+
+ if self.transforms is not None:
+ point_set = self.transforms(point_set)
+
+ return point_set, ele["lbl"]
+
+ def __len__(self):
+ return self._len
+
+ def set_num_points(self, pts):
+ self.num_points = min(int(1e4), pts)
+
+
+if __name__ == "__main__":
+ from torchvision import transforms
+ import data_utils as d_utils
+
+ transforms = transforms.Compose(
+ [
+ d_utils.PointcloudToTensor(),
+ d_utils.PointcloudRotate(axis=np.array([1, 0, 0])),
+ d_utils.PointcloudScale(),
+ d_utils.PointcloudTranslate(),
+ d_utils.PointcloudJitter(),
+ ]
+ )
+ dset = ModelNet40Cls(16, train=True, transforms=transforms)
+ print(dset[0][0])
+ print(dset[0][1])
+ print(len(dset))
+ dloader = torch.utils.data.DataLoader(dset, batch_size=32, shuffle=True)
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/data/__init__.py b/pointnet++/Pointnet2_PyTorch/pointnet2/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff9bc73fdd487e7368a3ebdbe4ebc0dc049ab769
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/data/__init__.py
@@ -0,0 +1,2 @@
+from .Indoor3DSemSegLoader import Indoor3DSemSeg
+from .ModelNet40Loader import ModelNet40Cls
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/data/data_utils.py b/pointnet++/Pointnet2_PyTorch/pointnet2/data/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..778444320d84893b3b4a5b01ed00f65287078fc5
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/data/data_utils.py
@@ -0,0 +1,141 @@
+import numpy as np
+import torch
+
+
+def angle_axis(angle, axis):
+ # type: (float, np.ndarray) -> float
+ r"""Returns a 4x4 rotation matrix that performs a rotation around axis by angle
+
+ Parameters
+ ----------
+ angle : float
+ Angle to rotate by
+ axis: np.ndarray
+ Axis to rotate about
+
+ Returns
+ -------
+ torch.Tensor
+ 3x3 rotation matrix
+ """
+ u = axis / np.linalg.norm(axis)
+ cosval, sinval = np.cos(angle), np.sin(angle)
+
+ # yapf: disable
+ cross_prod_mat = np.array([[0.0, -u[2], u[1]],
+ [u[2], 0.0, -u[0]],
+ [-u[1], u[0], 0.0]])
+
+ R = torch.from_numpy(
+ cosval * np.eye(3)
+ + sinval * cross_prod_mat
+ + (1.0 - cosval) * np.outer(u, u)
+ )
+ # yapf: enable
+ return R.float()
+
+
+class PointcloudScale(object):
+ def __init__(self, lo=0.8, hi=1.25):
+ self.lo, self.hi = lo, hi
+
+ def __call__(self, points):
+ scaler = np.random.uniform(self.lo, self.hi)
+ points[:, 0:3] *= scaler
+ return points
+
+
+class PointcloudRotate(object):
+ def __init__(self, axis=np.array([0.0, 1.0, 0.0])):
+ self.axis = axis
+
+ def __call__(self, points):
+ rotation_angle = np.random.uniform() * 2 * np.pi
+ rotation_matrix = angle_axis(rotation_angle, self.axis)
+
+ normals = points.size(1) > 3
+ if not normals:
+ return torch.matmul(points, rotation_matrix.t())
+ else:
+ pc_xyz = points[:, 0:3]
+ pc_normals = points[:, 3:]
+ points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t())
+ points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t())
+
+ return points
+
+
+class PointcloudRotatePerturbation(object):
+ def __init__(self, angle_sigma=0.06, angle_clip=0.18):
+ self.angle_sigma, self.angle_clip = angle_sigma, angle_clip
+
+ def _get_angles(self):
+ angles = np.clip(
+ self.angle_sigma * np.random.randn(3), -self.angle_clip, self.angle_clip
+ )
+
+ return angles
+
+ def __call__(self, points):
+ angles = self._get_angles()
+ Rx = angle_axis(angles[0], np.array([1.0, 0.0, 0.0]))
+ Ry = angle_axis(angles[1], np.array([0.0, 1.0, 0.0]))
+ Rz = angle_axis(angles[2], np.array([0.0, 0.0, 1.0]))
+
+ rotation_matrix = torch.matmul(torch.matmul(Rz, Ry), Rx)
+
+ normals = points.size(1) > 3
+ if not normals:
+ return torch.matmul(points, rotation_matrix.t())
+ else:
+ pc_xyz = points[:, 0:3]
+ pc_normals = points[:, 3:]
+ points[:, 0:3] = torch.matmul(pc_xyz, rotation_matrix.t())
+ points[:, 3:] = torch.matmul(pc_normals, rotation_matrix.t())
+
+ return points
+
+
+class PointcloudJitter(object):
+ def __init__(self, std=0.01, clip=0.05):
+ self.std, self.clip = std, clip
+
+ def __call__(self, points):
+ jittered_data = (
+ points.new(points.size(0), 3)
+ .normal_(mean=0.0, std=self.std)
+ .clamp_(-self.clip, self.clip)
+ )
+ points[:, 0:3] += jittered_data
+ return points
+
+
+class PointcloudTranslate(object):
+ def __init__(self, translate_range=0.1):
+ self.translate_range = translate_range
+
+ def __call__(self, points):
+ translation = np.random.uniform(-self.translate_range, self.translate_range)
+ points[:, 0:3] += translation
+ return points
+
+
+class PointcloudToTensor(object):
+ def __call__(self, points):
+ return torch.from_numpy(points).float()
+
+
+class PointcloudRandomInputDropout(object):
+ def __init__(self, max_dropout_ratio=0.875):
+ assert max_dropout_ratio >= 0 and max_dropout_ratio < 1
+ self.max_dropout_ratio = max_dropout_ratio
+
+ def __call__(self, points):
+ pc = points.numpy()
+
+ dropout_ratio = np.random.random() * self.max_dropout_ratio # 0~0.875
+ drop_idx = np.where(np.random.random((pc.shape[0])) <= dropout_ratio)[0]
+ if len(drop_idx) > 0:
+ pc[drop_idx] = pc[0] # set to the first point
+
+ return torch.from_numpy(pc).float()
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/models/__init__.py b/pointnet++/Pointnet2_PyTorch/pointnet2/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3ca7984afc208a33ec19d729f81340217f7841a
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/models/__init__.py
@@ -0,0 +1,4 @@
+from pointnet2.models.pointnet2_msg_cls import PointNet2ClassificationMSG
+from pointnet2.models.pointnet2_msg_sem import PointNet2SemSegMSG
+from pointnet2.models.pointnet2_ssg_cls import PointNet2ClassificationSSG
+from pointnet2.models.pointnet2_ssg_sem import PointNet2SemSegSSG
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/models/pointnet2_msg_cls.py b/pointnet++/Pointnet2_PyTorch/pointnet2/models/pointnet2_msg_cls.py
new file mode 100644
index 0000000000000000000000000000000000000000..892c0a1c9dc22ee76463f14da51a801b66cf34a8
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/models/pointnet2_msg_cls.py
@@ -0,0 +1,44 @@
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pointnet2_ops.pointnet2_modules import PointnetSAModule, PointnetSAModuleMSG
+
+from pointnet2.models.pointnet2_ssg_cls import PointNet2ClassificationSSG
+
+
+class PointNet2ClassificationMSG(PointNet2ClassificationSSG):
+ def _build_model(self):
+ super()._build_model()
+
+ self.SA_modules = nn.ModuleList()
+ self.SA_modules.append(
+ PointnetSAModuleMSG(
+ npoint=512,
+ radii=[0.1, 0.2, 0.4],
+ nsamples=[16, 32, 128],
+ mlps=[[3, 32, 32, 64], [3, 64, 64, 128], [3, 64, 96, 128]],
+ use_xyz=self.hparams["model.use_xyz"],
+ )
+ )
+
+ input_channels = 64 + 128 + 128
+ self.SA_modules.append(
+ PointnetSAModuleMSG(
+ npoint=128,
+ radii=[0.2, 0.4, 0.8],
+ nsamples=[32, 64, 128],
+ mlps=[
+ [input_channels, 64, 64, 128],
+ [input_channels, 128, 128, 256],
+ [input_channels, 128, 128, 256],
+ ],
+ use_xyz=self.hparams["model.use_xyz"],
+ )
+ )
+ self.SA_modules.append(
+ PointnetSAModule(
+ mlp=[128 + 256 + 256, 256, 512, 1024],
+ use_xyz=self.hparams["model.use_xyz"],
+ )
+ )
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/models/pointnet2_msg_sem.py b/pointnet++/Pointnet2_PyTorch/pointnet2/models/pointnet2_msg_sem.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1b02a7397a7a6ef51a2ecfa21c201ea4ed713b2
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/models/pointnet2_msg_sem.py
@@ -0,0 +1,74 @@
+from collections import namedtuple
+
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModuleMSG
+
+from pointnet2.models.pointnet2_ssg_sem import PointNet2SemSegSSG
+
+
+class PointNet2SemSegMSG(PointNet2SemSegSSG):
+ def _build_model(self):
+ self.SA_modules = nn.ModuleList()
+ c_in = 6
+ self.SA_modules.append(
+ PointnetSAModuleMSG(
+ npoint=1024,
+ radii=[0.05, 0.1],
+ nsamples=[16, 32],
+ mlps=[[c_in, 16, 16, 32], [c_in, 32, 32, 64]],
+ use_xyz=self.hparams["model.use_xyz"],
+ )
+ )
+ c_out_0 = 32 + 64
+
+ c_in = c_out_0
+ self.SA_modules.append(
+ PointnetSAModuleMSG(
+ npoint=256,
+ radii=[0.1, 0.2],
+ nsamples=[16, 32],
+ mlps=[[c_in, 64, 64, 128], [c_in, 64, 96, 128]],
+ use_xyz=self.hparams["model.use_xyz"],
+ )
+ )
+ c_out_1 = 128 + 128
+
+ c_in = c_out_1
+ self.SA_modules.append(
+ PointnetSAModuleMSG(
+ npoint=64,
+ radii=[0.2, 0.4],
+ nsamples=[16, 32],
+ mlps=[[c_in, 128, 196, 256], [c_in, 128, 196, 256]],
+ use_xyz=self.hparams["model.use_xyz"],
+ )
+ )
+ c_out_2 = 256 + 256
+
+ c_in = c_out_2
+ self.SA_modules.append(
+ PointnetSAModuleMSG(
+ npoint=16,
+ radii=[0.4, 0.8],
+ nsamples=[16, 32],
+ mlps=[[c_in, 256, 256, 512], [c_in, 256, 384, 512]],
+ use_xyz=self.hparams["model.use_xyz"],
+ )
+ )
+ c_out_3 = 512 + 512
+
+ self.FP_modules = nn.ModuleList()
+ self.FP_modules.append(PointnetFPModule(mlp=[256 + 6, 128, 128]))
+ self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_0, 256, 256]))
+ self.FP_modules.append(PointnetFPModule(mlp=[512 + c_out_1, 512, 512]))
+ self.FP_modules.append(PointnetFPModule(mlp=[c_out_3 + c_out_2, 512, 512]))
+
+ self.fc_lyaer = nn.Sequential(
+ nn.Conv1d(128, 128, kernel_size=1, bias=False),
+ nn.BatchNorm1d(128),
+ nn.ReLU(True),
+ nn.Dropout(0.5),
+ nn.Conv1d(128, 13, kernel_size=1),
+ )
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/models/pointnet2_ssg_cls.py b/pointnet++/Pointnet2_PyTorch/pointnet2/models/pointnet2_ssg_cls.py
new file mode 100644
index 0000000000000000000000000000000000000000..282fac405cc702e35452bb25d2f448c742eb2a51
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/models/pointnet2_ssg_cls.py
@@ -0,0 +1,230 @@
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim.lr_scheduler as lr_sched
+from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule
+from torch.utils.data import DataLoader, DistributedSampler
+from torchvision import transforms
+
+import pointnet2.data.data_utils as d_utils
+from pointnet2.data.ModelNet40Loader import ModelNet40Cls
+
+
+def set_bn_momentum_default(bn_momentum):
+ def fn(m):
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
+ m.momentum = bn_momentum
+
+ return fn
+
+
+class BNMomentumScheduler(lr_sched.LambdaLR):
+ def __init__(self, model, bn_lambda, last_epoch=-1, setter=set_bn_momentum_default):
+ if not isinstance(model, nn.Module):
+ raise RuntimeError(
+ "Class '{}' is not a PyTorch nn Module".format(type(model)._name_)
+ )
+
+ self.model = model
+ self.setter = setter
+ self.lmbd = bn_lambda
+
+ self.step(last_epoch + 1)
+ self.last_epoch = last_epoch
+
+ def step(self, epoch=None):
+ if epoch is None:
+ epoch = self.last_epoch + 1
+
+ self.last_epoch = epoch
+ self.model.apply(self.setter(self.lmbd(epoch)))
+
+ def state_dict(self):
+ return dict(last_epoch=self.last_epoch)
+
+ def load_state_dict(self, state):
+ self.last_epoch = state["last_epoch"]
+ self.step(self.last_epoch)
+
+
+lr_clip = 1e-5
+bnm_clip = 1e-2
+
+
+class PointNet2ClassificationSSG(pl.LightningModule):
+ def __init__(self, hparams):
+ super().__init__()
+
+ self.hparams = hparams
+
+ self._build_model()
+
+ def _build_model(self):
+ self.SA_modules = nn.ModuleList()
+ self.SA_modules.append(
+ PointnetSAModule(
+ npoint=512,
+ radius=0.2,
+ nsample=64,
+ mlp=[3, 64, 64, 128],
+ use_xyz=self.hparams["model.use_xyz"],
+ )
+ )
+ self.SA_modules.append(
+ PointnetSAModule(
+ npoint=128,
+ radius=0.4,
+ nsample=64,
+ mlp=[128, 128, 128, 256],
+ use_xyz=self.hparams["model.use_xyz"],
+ )
+ )
+ self.SA_modules.append(
+ PointnetSAModule(
+ mlp=[256, 256, 512, 1024], use_xyz=self.hparams["model.use_xyz"]
+ )
+ )
+
+ self.fc_layer = nn.Sequential(
+ nn.Linear(1024, 512, bias=False),
+ nn.BatchNorm1d(512),
+ nn.ReLU(True),
+ nn.Linear(512, 256, bias=False),
+ nn.BatchNorm1d(256),
+ nn.ReLU(True),
+ nn.Dropout(0.5),
+ nn.Linear(256, 40),
+ )
+
+ def _break_up_pc(self, pc):
+ xyz = pc[..., 0:3].contiguous()
+ features = pc[..., 3:].transpose(1, 2).contiguous() if pc.size(-1) > 3 else None
+
+ return xyz, features
+
+ def forward(self, pointcloud):
+ r"""
+ Forward pass of the network
+
+ Parameters
+ ----------
+ pointcloud: Variable(torch.cuda.FloatTensor)
+ (B, N, 3 + input_channels) tensor
+ Point cloud to run predicts on
+ Each point in the point-cloud MUST
+ be formated as (x, y, z, features...)
+ """
+ xyz, features = self._break_up_pc(pointcloud)
+
+ for module in self.SA_modules:
+ xyz, features = module(xyz, features)
+
+ return self.fc_layer(features.squeeze(-1))
+
+ def training_step(self, batch, batch_idx):
+ pc, labels = batch
+
+ logits = self.forward(pc)
+ loss = F.cross_entropy(logits, labels)
+ with torch.no_grad():
+ acc = (torch.argmax(logits, dim=1) == labels).float().mean()
+
+ log = dict(train_loss=loss, train_acc=acc)
+
+ return dict(loss=loss, log=log, progress_bar=dict(train_acc=acc))
+
+ def validation_step(self, batch, batch_idx):
+ pc, labels = batch
+
+ logits = self.forward(pc)
+ loss = F.cross_entropy(logits, labels)
+ acc = (torch.argmax(logits, dim=1) == labels).float().mean()
+
+ return dict(val_loss=loss, val_acc=acc)
+
+ def validation_end(self, outputs):
+ reduced_outputs = {}
+ for k in outputs[0]:
+ for o in outputs:
+ reduced_outputs[k] = reduced_outputs.get(k, []) + [o[k]]
+
+ for k in reduced_outputs:
+ reduced_outputs[k] = torch.stack(reduced_outputs[k]).mean()
+
+ reduced_outputs.update(
+ dict(log=reduced_outputs.copy(), progress_bar=reduced_outputs.copy())
+ )
+
+ return reduced_outputs
+
+ def configure_optimizers(self):
+ lr_lbmd = lambda _: max(
+ self.hparams["optimizer.lr_decay"]
+ ** (
+ int(
+ self.global_step
+ * self.hparams["batch_size"]
+ / self.hparams["optimizer.decay_step"]
+ )
+ ),
+ lr_clip / self.hparams["optimizer.lr"],
+ )
+ bn_lbmd = lambda _: max(
+ self.hparams["optimizer.bn_momentum"]
+ * self.hparams["optimizer.bnm_decay"]
+ ** (
+ int(
+ self.global_step
+ * self.hparams["batch_size"]
+ / self.hparams["optimizer.decay_step"]
+ )
+ ),
+ bnm_clip,
+ )
+
+ optimizer = torch.optim.Adam(
+ self.parameters(),
+ lr=self.hparams["optimizer.lr"],
+ weight_decay=self.hparams["optimizer.weight_decay"],
+ )
+ lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lambda=lr_lbmd)
+ bnm_scheduler = BNMomentumScheduler(self, bn_lambda=bn_lbmd)
+
+ return [optimizer], [lr_scheduler, bnm_scheduler]
+
+ def prepare_data(self):
+ train_transforms = transforms.Compose(
+ [
+ d_utils.PointcloudToTensor(),
+ d_utils.PointcloudScale(),
+ d_utils.PointcloudRotate(),
+ d_utils.PointcloudRotatePerturbation(),
+ d_utils.PointcloudTranslate(),
+ d_utils.PointcloudJitter(),
+ d_utils.PointcloudRandomInputDropout(),
+ ]
+ )
+
+ self.train_dset = ModelNet40Cls(
+ self.hparams["num_points"], transforms=train_transforms, train=True
+ )
+ self.val_dset = ModelNet40Cls(
+ self.hparams["num_points"], transforms=None, train=False
+ )
+
+ def _build_dataloader(self, dset, mode):
+ return DataLoader(
+ dset,
+ batch_size=self.hparams["batch_size"],
+ shuffle=mode == "train",
+ num_workers=4,
+ pin_memory=True,
+ drop_last=mode == "train",
+ )
+
+ def train_dataloader(self):
+ return self._build_dataloader(self.train_dset, mode="train")
+
+ def val_dataloader(self):
+ return self._build_dataloader(self.val_dset, mode="val")
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/models/pointnet2_ssg_sem.py b/pointnet++/Pointnet2_PyTorch/pointnet2/models/pointnet2_ssg_sem.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3f70bc0bfb5bad4e8cdca5bd3f58a8d8c8118d8
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/models/pointnet2_ssg_sem.py
@@ -0,0 +1,94 @@
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule
+from torch.utils.data import DataLoader
+
+from pointnet2.data import Indoor3DSemSeg
+from pointnet2.models.pointnet2_ssg_cls import PointNet2ClassificationSSG
+
+
+class PointNet2SemSegSSG(PointNet2ClassificationSSG):
+ def _build_model(self):
+ self.SA_modules = nn.ModuleList()
+ self.SA_modules.append(
+ PointnetSAModule(
+ npoint=1024,
+ radius=0.1,
+ nsample=32,
+ mlp=[6, 32, 32, 64],
+ use_xyz=self.hparams["model.use_xyz"],
+ )
+ )
+ self.SA_modules.append(
+ PointnetSAModule(
+ npoint=256,
+ radius=0.2,
+ nsample=32,
+ mlp=[64, 64, 64, 128],
+ use_xyz=self.hparams["model.use_xyz"],
+ )
+ )
+ self.SA_modules.append(
+ PointnetSAModule(
+ npoint=64,
+ radius=0.4,
+ nsample=32,
+ mlp=[128, 128, 128, 256],
+ use_xyz=self.hparams["model.use_xyz"],
+ )
+ )
+ self.SA_modules.append(
+ PointnetSAModule(
+ npoint=16,
+ radius=0.8,
+ nsample=32,
+ mlp=[256, 256, 256, 512],
+ use_xyz=self.hparams["model.use_xyz"],
+ )
+ )
+
+ self.FP_modules = nn.ModuleList()
+ self.FP_modules.append(PointnetFPModule(mlp=[128 + 6, 128, 128, 128]))
+ self.FP_modules.append(PointnetFPModule(mlp=[256 + 64, 256, 128]))
+ self.FP_modules.append(PointnetFPModule(mlp=[256 + 128, 256, 256]))
+ self.FP_modules.append(PointnetFPModule(mlp=[512 + 256, 256, 256]))
+
+ self.fc_lyaer = nn.Sequential(
+ nn.Conv1d(128, 128, kernel_size=1, bias=False),
+ nn.BatchNorm1d(128),
+ nn.ReLU(True),
+ nn.Dropout(0.5),
+ nn.Conv1d(128, 13, kernel_size=1),
+ )
+
+ def forward(self, pointcloud):
+ r"""
+ Forward pass of the network
+
+ Parameters
+ ----------
+ pointcloud: Variable(torch.cuda.FloatTensor)
+ (B, N, 3 + input_channels) tensor
+ Point cloud to run predicts on
+ Each point in the point-cloud MUST
+ be formated as (x, y, z, features...)
+ """
+ xyz, features = self._break_up_pc(pointcloud)
+
+ l_xyz, l_features = [xyz], [features]
+ for i in range(len(self.SA_modules)):
+ li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i])
+ l_xyz.append(li_xyz)
+ l_features.append(li_features)
+
+ for i in range(-1, -(len(self.FP_modules) + 1), -1):
+ l_features[i - 1] = self.FP_modules[i](
+ l_xyz[i - 1], l_xyz[i], l_features[i - 1], l_features[i]
+ )
+
+ return self.fc_lyaer(l_features[0])
+
+ def prepare_data(self):
+ self.train_dset = Indoor3DSemSeg(self.hparams["num_points"], train=True)
+ self.val_dset = Indoor3DSemSeg(self.hparams["num_points"], train=False)
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/train.py b/pointnet++/Pointnet2_PyTorch/pointnet2/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..c980c1f4e19657a9652e4523ebf72ab8277b8b32
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/train.py
@@ -0,0 +1,55 @@
+import os
+
+import hydra
+import omegaconf
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.loggers import TensorBoardLogger
+
+torch.backends.cudnn.enabled = True
+torch.backends.cudnn.benchmark = True
+
+
+def hydra_params_to_dotdict(hparams):
+ def _to_dot_dict(cfg):
+ res = {}
+ for k, v in cfg.items():
+ if isinstance(v, omegaconf.DictConfig):
+ res.update(
+ {k + "." + subk: subv for subk, subv in _to_dot_dict(v).items()}
+ )
+ elif isinstance(v, (str, int, float, bool)):
+ res[k] = v
+
+ return res
+
+ return _to_dot_dict(hparams)
+
+
+@hydra.main("config/config.yaml")
+def main(cfg):
+ model = hydra.utils.instantiate(cfg.task_model, hydra_params_to_dotdict(cfg))
+
+ early_stop_callback = pl.callbacks.EarlyStopping(patience=5)
+ checkpoint_callback = pl.callbacks.ModelCheckpoint(
+ monitor="val_acc",
+ mode="max",
+ save_top_k=2,
+ filepath=os.path.join(
+ cfg.task_model.name, "{epoch}-{val_loss:.2f}-{val_acc:.3f}"
+ ),
+ verbose=True,
+ )
+ trainer = pl.Trainer(
+ gpus=list(cfg.gpus),
+ max_epochs=cfg.epochs,
+ early_stop_callback=early_stop_callback,
+ checkpoint_callback=checkpoint_callback,
+ distributed_backend=cfg.distrib_backend,
+ )
+
+ trainer.fit(model)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2/utils/.gitignore b/pointnet++/Pointnet2_PyTorch/pointnet2/utils/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..1274d1d12917bd4fa1dd8b07e49eb4c40de4e214
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2/utils/.gitignore
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fbf8a5f5427a2eea75ddd7810f43f41cc7277d4943ed4141bd12a9ca040d793f
+size 11
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/MANIFEST.in b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..082bc839a6d802bd5ba1c7372659a4b2fa3be116
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/MANIFEST.in
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cdbbbfb701c38455f913a405af616d94d60f0064b89d3fc40034506901d9eed3
+size 29
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/__init__.py b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fd361f9abbacc218f7699b8c439902b9d1bf745
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/__init__.py
@@ -0,0 +1,3 @@
+import pointnet2_ops.pointnet2_modules
+import pointnet2_ops.pointnet2_utils
+from pointnet2_ops._version import __version__
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/ball_query.h b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/ball_query.h
new file mode 100644
index 0000000000000000000000000000000000000000..c178accd7ee197d3c68d943af02b849bb5a2caf0
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/ball_query.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:595e651f7ac88e14248199cba61e9d0ddd3e18f8535e1ce38d471b30a246427a
+size 163
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/cuda_utils.h b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/cuda_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..864e2bd95dceb3498a613ccdddaa2cc862cb7a5d
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/cuda_utils.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de26c10e6748e61da7f6516863f2a20d923ccd78e539efdfe18db8a0aeb0a288
+size 1303
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/group_points.h b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/group_points.h
new file mode 100644
index 0000000000000000000000000000000000000000..5b56f4bd6ade4fcdf00bfe2a7dc17412e52f0257
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/group_points.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e702934ff68b3030b052af38b469f00786ea7b3620e017ef58bcc4cf5f0d6325
+size 183
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/interpolate.h b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/interpolate.h
new file mode 100644
index 0000000000000000000000000000000000000000..0178c2d76760423fbf64764634444f67edf1b478
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/interpolate.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d96cc311411e9e95cc03e47ea6d63269a1b888342d54d074348b967c96d07c8
+size 386
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/sampling.h b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/sampling.h
new file mode 100644
index 0000000000000000000000000000000000000000..14a57d2330f3dab172b079388f4f3fbec85f91dc
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/sampling.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1b919a8bacf4c077dff6d86d2226df4eeb041e5d135be2a2d829c5de933d068
+size 260
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/utils.h b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..e5c32ab2ff2102c7ad545a843069d1182e1e07f4
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/include/utils.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0f94d69288a2010d12aa35494f127b7e3dccd154d9aebdecef84f05a6fc6fbe
+size 983
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/ball_query.cpp b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/ball_query.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4b1c841022ef5046019c7055a2cd38a36381b6e5
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/ball_query.cpp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff4b6076cfc3299faed905927391144511fdbfe807b14ae2970f26c1756f6aaa
+size 1037
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/ball_query_gpu.cu b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/ball_query_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ebff1506053fe6660987faa9e2e2fc13d80a2a07
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/ball_query_gpu.cu
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7253c10da9db4b02d25e43b7a713ee956dae27b4ec25220f1b2abbaccc79530f
+size 1784
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/bindings.cpp b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/bindings.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..22f678d56e2fe7c42aecf1ef8b0071fa883dad47
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/bindings.cpp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:542241ad191246365594c3861e93b1e1d625fb0fd1fd95eec113fed98c66d8b5
+size 570
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/group_points.cpp b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/group_points.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8cb6b4abedf5e8ed74c4e3475a674760f3f19808
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/group_points.cpp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c8079fe09e1a35a6142545f3277c71c483b54e19653930378e2f2f1d3b9d7971
+size 1952
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/group_points_gpu.cu b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/group_points_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..bc281cb478d1af1f2d89ac3dc7fcbadf24cb4ff2
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/group_points_gpu.cu
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf0bc3c20f709ed5be113c5d2618d3d61dcc49fb2449d53323596c5d1d4857fb
+size 2885
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/interpolate.cpp b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/interpolate.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..922f3fe9aa2085d7c58acefaf0ff9555a9eb219c
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/interpolate.cpp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0fb9ca45eb12d9998618b69d42272f6a58814437a1a2803941e6b496944411a1
+size 3304
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/interpolate_gpu.cu b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/interpolate_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..af4dabfb5c87f76086b9a5fb21334316219aad07
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/interpolate_gpu.cu
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e4a327f15d49bcefe62bd72adc407157f97039c3b34037f8fc767b9a15936173
+size 5141
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/sampling.cpp b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/sampling.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a264c69d7b9ae80d81e11b111a5e2c0a770a7fd7
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/sampling.cpp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb800bdb22cc5b2a3742eae1c125cd6883969c329a3e36c8c4d3d9de98246312
+size 2894
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/sampling_gpu.cu b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/sampling_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b50ac8f3fb305603d1a5791cc34ba0b97c8287b6
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/sampling_gpu.cu
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1ec92aabdb6e4ab3ad0b982ebb28fa79b1b1fc95e3accbfaa511ef9bc317fe4
+size 7019
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext.cpython-310-x86_64-linux-gnu.so b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext.cpython-310-x86_64-linux-gnu.so
new file mode 100755
index 0000000000000000000000000000000000000000..5fea7126284f204079ee7f51235194bcc35bb920
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_ext.cpython-310-x86_64-linux-gnu.so
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a62c46bedbfc07ed066d6580f82254c0d9cfdc4e6b3509cbd44b6031a2933fd
+size 478280
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_version.py b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..528787cfc8ad81ed41822a8104b60b4896632906
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/_version.py
@@ -0,0 +1 @@
+__version__ = "3.0.0"
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/pointnet2_modules.py b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/pointnet2_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0ad4f6bc23f54ca2d61454e657a6f533e9b875c
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/pointnet2_modules.py
@@ -0,0 +1,209 @@
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pointnet2_ops import pointnet2_utils
+
+
+def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
+ layers = []
+ for i in range(1, len(mlp_spec)):
+ layers.append(
+ nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn)
+ )
+ if bn:
+ layers.append(nn.BatchNorm2d(mlp_spec[i]))
+ layers.append(nn.ReLU(True))
+
+ return nn.Sequential(*layers)
+
+
+class _PointnetSAModuleBase(nn.Module):
+ def __init__(self):
+ super(_PointnetSAModuleBase, self).__init__()
+ self.npoint = None
+ self.groupers = None
+ self.mlps = None
+
+ def forward(
+ self, xyz: torch.Tensor, features: Optional[torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Parameters
+ ----------
+ xyz : torch.Tensor
+ (B, N, 3) tensor of the xyz coordinates of the features
+ features : torch.Tensor
+ (B, C, N) tensor of the descriptors of the the features
+
+ Returns
+ -------
+ new_xyz : torch.Tensor
+ (B, npoint, 3) tensor of the new features' xyz
+ new_features : torch.Tensor
+ (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
+ """
+
+ new_features_list = []
+
+ xyz_flipped = xyz.transpose(1, 2).contiguous()
+ new_xyz = (
+ pointnet2_utils.gather_operation(
+ xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint)
+ )
+ .transpose(1, 2)
+ .contiguous()
+ if self.npoint is not None
+ else None
+ )
+
+ for i in range(len(self.groupers)):
+ new_features = self.groupers[i](
+ xyz, new_xyz, features
+ ) # (B, C, npoint, nsample)
+
+ new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
+ new_features = F.max_pool2d(
+ new_features, kernel_size=[1, new_features.size(3)]
+ ) # (B, mlp[-1], npoint, 1)
+ new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
+
+ new_features_list.append(new_features)
+
+ return new_xyz, torch.cat(new_features_list, dim=1)
+
+
+class PointnetSAModuleMSG(_PointnetSAModuleBase):
+ r"""Pointnet set abstrction layer with multiscale grouping
+
+ Parameters
+ ----------
+ npoint : int
+ Number of features
+ radii : list of float32
+ list of radii to group with
+ nsamples : list of int32
+ Number of samples in each ball query
+ mlps : list of list of int32
+ Spec of the pointnet before the global max_pool for each scale
+ bn : bool
+ Use batchnorm
+ """
+
+ def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
+ # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
+ super(PointnetSAModuleMSG, self).__init__()
+
+ assert len(radii) == len(nsamples) == len(mlps)
+
+ self.npoint = npoint
+ self.groupers = nn.ModuleList()
+ self.mlps = nn.ModuleList()
+ for i in range(len(radii)):
+ radius = radii[i]
+ nsample = nsamples[i]
+ self.groupers.append(
+ pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
+ if npoint is not None
+ else pointnet2_utils.GroupAll(use_xyz)
+ )
+ mlp_spec = mlps[i]
+ if use_xyz:
+ mlp_spec[0] += 3
+
+ self.mlps.append(build_shared_mlp(mlp_spec, bn))
+
+
+class PointnetSAModule(PointnetSAModuleMSG):
+ r"""Pointnet set abstrction layer
+
+ Parameters
+ ----------
+ npoint : int
+ Number of features
+ radius : float
+ Radius of ball
+ nsample : int
+ Number of samples in the ball query
+ mlp : list
+ Spec of the pointnet before the global max_pool
+ bn : bool
+ Use batchnorm
+ """
+
+ def __init__(
+ self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
+ ):
+ # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
+ super(PointnetSAModule, self).__init__(
+ mlps=[mlp],
+ npoint=npoint,
+ radii=[radius],
+ nsamples=[nsample],
+ bn=bn,
+ use_xyz=use_xyz,
+ )
+
+
+class PointnetFPModule(nn.Module):
+ r"""Propigates the features of one set to another
+
+ Parameters
+ ----------
+ mlp : list
+ Pointnet module parameters
+ bn : bool
+ Use batchnorm
+ """
+
+ def __init__(self, mlp, bn=True):
+ # type: (PointnetFPModule, List[int], bool) -> None
+ super(PointnetFPModule, self).__init__()
+ self.mlp = build_shared_mlp(mlp, bn=bn)
+
+ def forward(self, unknown, known, unknow_feats, known_feats):
+ # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
+ r"""
+ Parameters
+ ----------
+ unknown : torch.Tensor
+ (B, n, 3) tensor of the xyz positions of the unknown features
+ known : torch.Tensor
+ (B, m, 3) tensor of the xyz positions of the known features
+ unknow_feats : torch.Tensor
+ (B, C1, n) tensor of the features to be propigated to
+ known_feats : torch.Tensor
+ (B, C2, m) tensor of features to be propigated
+
+ Returns
+ -------
+ new_features : torch.Tensor
+ (B, mlp[-1], n) tensor of the features of the unknown features
+ """
+
+ if known is not None:
+ dist, idx = pointnet2_utils.three_nn(unknown, known)
+ dist_recip = 1.0 / (dist + 1e-8)
+ norm = torch.sum(dist_recip, dim=2, keepdim=True)
+ weight = dist_recip / norm
+
+ interpolated_feats = pointnet2_utils.three_interpolate(
+ known_feats, idx, weight
+ )
+ else:
+ interpolated_feats = known_feats.expand(
+ *(known_feats.size()[0:2] + [unknown.size(1)])
+ )
+
+ if unknow_feats is not None:
+ new_features = torch.cat(
+ [interpolated_feats, unknow_feats], dim=1
+ ) # (B, C2 + C1, n)
+ else:
+ new_features = interpolated_feats
+
+ new_features = new_features.unsqueeze(-1)
+ new_features = self.mlp(new_features)
+
+ return new_features.squeeze(-1)
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/pointnet2_utils.py b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/pointnet2_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..150fcccade21001971a76e6c3628963972305739
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/lib.linux-x86_64-cpython-310/pointnet2_ops/pointnet2_utils.py
@@ -0,0 +1,379 @@
+import torch
+import torch.nn as nn
+import warnings
+from torch.autograd import Function
+from typing import *
+
+try:
+ import pointnet2_ops._ext as _ext
+except ImportError:
+ from torch.utils.cpp_extension import load
+ import glob
+ import os.path as osp
+ import os
+
+ warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")
+
+ _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
+ _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
+ osp.join(_ext_src_root, "src", "*.cu")
+ )
+ _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
+
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"
+ _ext = load(
+ "_ext",
+ sources=_ext_sources,
+ extra_include_paths=[osp.join(_ext_src_root, "include")],
+ extra_cflags=["-O3"],
+ extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"],
+ with_cuda=True,
+ )
+
+
+class FurthestPointSampling(Function):
+ @staticmethod
+ def forward(ctx, xyz, npoint):
+ # type: (Any, torch.Tensor, int) -> torch.Tensor
+ r"""
+ Uses iterative furthest point sampling to select a set of npoint features that have the largest
+ minimum distance
+
+ Parameters
+ ----------
+ xyz : torch.Tensor
+ (B, N, 3) tensor where N > npoint
+ npoint : int32
+ number of features in the sampled set
+
+ Returns
+ -------
+ torch.Tensor
+ (B, npoint) tensor containing the set
+ """
+ out = _ext.furthest_point_sampling(xyz, npoint)
+
+ ctx.mark_non_differentiable(out)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ return ()
+
+
+furthest_point_sample = FurthestPointSampling.apply
+
+
+class GatherOperation(Function):
+ @staticmethod
+ def forward(ctx, features, idx):
+ # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
+ r"""
+
+ Parameters
+ ----------
+ features : torch.Tensor
+ (B, C, N) tensor
+
+ idx : torch.Tensor
+ (B, npoint) tensor of the features to gather
+
+ Returns
+ -------
+ torch.Tensor
+ (B, C, npoint) tensor
+ """
+
+ ctx.save_for_backward(idx, features)
+
+ return _ext.gather_points(features, idx)
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ idx, features = ctx.saved_tensors
+ N = features.size(2)
+
+ grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N)
+ return grad_features, None
+
+
+gather_operation = GatherOperation.apply
+
+
+class ThreeNN(Function):
+ @staticmethod
+ def forward(ctx, unknown, known):
+ # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
+ r"""
+ Find the three nearest neighbors of unknown in known
+ Parameters
+ ----------
+ unknown : torch.Tensor
+ (B, n, 3) tensor of known features
+ known : torch.Tensor
+ (B, m, 3) tensor of unknown features
+
+ Returns
+ -------
+ dist : torch.Tensor
+ (B, n, 3) l2 distance to the three nearest neighbors
+ idx : torch.Tensor
+ (B, n, 3) index of 3 nearest neighbors
+ """
+ dist2, idx = _ext.three_nn(unknown, known)
+ dist = torch.sqrt(dist2)
+
+ ctx.mark_non_differentiable(dist, idx)
+
+ return dist, idx
+
+ @staticmethod
+ def backward(ctx, grad_dist, grad_idx):
+ return ()
+
+
+three_nn = ThreeNN.apply
+
+
+class ThreeInterpolate(Function):
+ @staticmethod
+ def forward(ctx, features, idx, weight):
+ # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor
+ r"""
+ Performs weight linear interpolation on 3 features
+ Parameters
+ ----------
+ features : torch.Tensor
+ (B, c, m) Features descriptors to be interpolated from
+ idx : torch.Tensor
+ (B, n, 3) three nearest neighbors of the target features in features
+ weight : torch.Tensor
+ (B, n, 3) weights
+
+ Returns
+ -------
+ torch.Tensor
+ (B, c, n) tensor of the interpolated features
+ """
+ ctx.save_for_backward(idx, weight, features)
+
+ return _ext.three_interpolate(features, idx, weight)
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ r"""
+ Parameters
+ ----------
+ grad_out : torch.Tensor
+ (B, c, n) tensor with gradients of ouputs
+
+ Returns
+ -------
+ grad_features : torch.Tensor
+ (B, c, m) tensor with gradients of features
+
+ None
+
+ None
+ """
+ idx, weight, features = ctx.saved_tensors
+ m = features.size(2)
+
+ grad_features = _ext.three_interpolate_grad(
+ grad_out.contiguous(), idx, weight, m
+ )
+
+ return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
+
+
+three_interpolate = ThreeInterpolate.apply
+
+
+class GroupingOperation(Function):
+ @staticmethod
+ def forward(ctx, features, idx):
+ # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
+ r"""
+
+ Parameters
+ ----------
+ features : torch.Tensor
+ (B, C, N) tensor of features to group
+ idx : torch.Tensor
+ (B, npoint, nsample) tensor containing the indicies of features to group with
+
+ Returns
+ -------
+ torch.Tensor
+ (B, C, npoint, nsample) tensor
+ """
+ ctx.save_for_backward(idx, features)
+
+ return _ext.group_points(features, idx)
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor]
+ r"""
+
+ Parameters
+ ----------
+ grad_out : torch.Tensor
+ (B, C, npoint, nsample) tensor of the gradients of the output from forward
+
+ Returns
+ -------
+ torch.Tensor
+ (B, C, N) gradient of the features
+ None
+ """
+ idx, features = ctx.saved_tensors
+ N = features.size(2)
+
+ grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N)
+
+ return grad_features, torch.zeros_like(idx)
+
+
+grouping_operation = GroupingOperation.apply
+
+
+class BallQuery(Function):
+ @staticmethod
+ def forward(ctx, radius, nsample, xyz, new_xyz):
+ # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
+ r"""
+
+ Parameters
+ ----------
+ radius : float
+ radius of the balls
+ nsample : int
+ maximum number of features in the balls
+ xyz : torch.Tensor
+ (B, N, 3) xyz coordinates of the features
+ new_xyz : torch.Tensor
+ (B, npoint, 3) centers of the ball query
+
+ Returns
+ -------
+ torch.Tensor
+ (B, npoint, nsample) tensor with the indicies of the features that form the query balls
+ """
+ output = _ext.ball_query(new_xyz, xyz, radius, nsample)
+
+ ctx.mark_non_differentiable(output)
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ return ()
+
+
+ball_query = BallQuery.apply
+
+
+class QueryAndGroup(nn.Module):
+ r"""
+ Groups with a ball query of radius
+
+ Parameters
+ ---------
+ radius : float32
+ Radius of ball
+ nsample : int32
+ Maximum number of features to gather in the ball
+ """
+
+ def __init__(self, radius, nsample, use_xyz=True):
+ # type: (QueryAndGroup, float, int, bool) -> None
+ super(QueryAndGroup, self).__init__()
+ self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
+
+ def forward(self, xyz, new_xyz, features=None):
+ # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor]
+ r"""
+ Parameters
+ ----------
+ xyz : torch.Tensor
+ xyz coordinates of the features (B, N, 3)
+ new_xyz : torch.Tensor
+ centriods (B, npoint, 3)
+ features : torch.Tensor
+ Descriptors of the features (B, C, N)
+
+ Returns
+ -------
+ new_features : torch.Tensor
+ (B, 3 + C, npoint, nsample) tensor
+ """
+
+ idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
+ xyz_trans = xyz.transpose(1, 2).contiguous()
+ grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
+ grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
+
+ if features is not None:
+ grouped_features = grouping_operation(features, idx)
+ if self.use_xyz:
+ new_features = torch.cat(
+ [grouped_xyz, grouped_features], dim=1
+ ) # (B, C + 3, npoint, nsample)
+ else:
+ new_features = grouped_features
+ else:
+ assert (
+ self.use_xyz
+ ), "Cannot have not features and not use xyz as a feature!"
+ new_features = grouped_xyz
+
+ return new_features
+
+
+class GroupAll(nn.Module):
+ r"""
+ Groups all features
+
+ Parameters
+ ---------
+ """
+
+ def __init__(self, use_xyz=True):
+ # type: (GroupAll, bool) -> None
+ super(GroupAll, self).__init__()
+ self.use_xyz = use_xyz
+
+ def forward(self, xyz, new_xyz, features=None):
+ # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
+ r"""
+ Parameters
+ ----------
+ xyz : torch.Tensor
+ xyz coordinates of the features (B, N, 3)
+ new_xyz : torch.Tensor
+ Ignored
+ features : torch.Tensor
+ Descriptors of the features (B, C, N)
+
+ Returns
+ -------
+ new_features : torch.Tensor
+ (B, C + 3, 1, N) tensor
+ """
+
+ grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
+ if features is not None:
+ grouped_features = features.unsqueeze(2)
+ if self.use_xyz:
+ new_features = torch.cat(
+ [grouped_xyz, grouped_features], dim=1
+ ) # (B, 3 + C, 1, N)
+ else:
+ new_features = grouped_features
+ else:
+ new_features = grouped_xyz
+
+ return new_features
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/.ninja_deps b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/.ninja_deps
new file mode 100644
index 0000000000000000000000000000000000000000..500a42b447777fd4c12b2eda34f5c4b92614ea7f
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/.ninja_deps
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f28ff6ba39b4388f6201c4a0743d234c1702091e79f6349a2945e3de91c7986d
+size 3863592
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/.ninja_log b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/.ninja_log
new file mode 100644
index 0000000000000000000000000000000000000000..2a4da3e3e9c3d9f3445bc4e52afb731fe47eb31a
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/.ninja_log
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c3fa5c2d21e297e6c099158e7d8c3155c5ab187d27dae04e098aa12958d7b0f
+size 20074
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/build.ninja b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/build.ninja
new file mode 100644
index 0000000000000000000000000000000000000000..a5073ba81c9cd341d6b49e3e241e6acd1a13eb95
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/build.ninja
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:631e7aa6979183e35bef0ba4fc8464808aafb7073436ac22463fcf33e928f12c
+size 5286
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/ball_query.o b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/ball_query.o
new file mode 100644
index 0000000000000000000000000000000000000000..7b9b9ecb5da6ce041bd7e5f61798a0db582a5900
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/ball_query.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:292d6ed48dffe27004cf35f488ddd55a61ec42927d725eb60ef87508ac512e37
+size 239576
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/ball_query_gpu.o b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/ball_query_gpu.o
new file mode 100644
index 0000000000000000000000000000000000000000..60bce61dd4130c74c7bcec2c75ccf30987273fd3
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/ball_query_gpu.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a7bf41b639df087f20cc9c7e8463c4643616469a82b4f93dd454e692d959e1e
+size 24312
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/bindings.o b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/bindings.o
new file mode 100644
index 0000000000000000000000000000000000000000..aff88ca45a739b68cbf2096e5a8750704f190e88
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/bindings.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cdd69cb4ba517803ba65a0bfc70d10501877a3782b81a40f31a07555375e7995
+size 282840
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/group_points.o b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/group_points.o
new file mode 100644
index 0000000000000000000000000000000000000000..cbc67d021245c1fba9848e20ca60a3dab3790873
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/group_points.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:572d86964def0ea0f5e6956a18f4b0e3b1758207a875594411f3c74e8b504990
+size 245936
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/group_points_gpu.o b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/group_points_gpu.o
new file mode 100644
index 0000000000000000000000000000000000000000..f0aeb52dffe6b79a5430bf18ec008e3c5b44595a
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/group_points_gpu.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89c29b866e6e5660aa9d72e857f6ea0c73e130470ff85c5764de091a4f2a60dd
+size 52752
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/interpolate.o b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/interpolate.o
new file mode 100644
index 0000000000000000000000000000000000000000..d9e6bd83749afe90b932f9f4d73dc8f9f0c20634
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/interpolate.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7bedea13d6aa0dd2f5bf7b16dbece36d9192bbfd4829f67f47a9b94968506fd0
+size 254864
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/interpolate_gpu.o b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/interpolate_gpu.o
new file mode 100644
index 0000000000000000000000000000000000000000..5f6bfd97cfd06f15423697f35fa2c3b34e2fcba1
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/interpolate_gpu.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d4f37610f73d1c4f7be4cf194b5310c8f8b3fe9f4f6403cae43b3bb99facc0b8
+size 72704
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/sampling.o b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/sampling.o
new file mode 100644
index 0000000000000000000000000000000000000000..2700fe48fea615c85d1736f26874916cb1e51ffa
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/sampling.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7f171a7a29e93bc2614c486eade6280828973aaa4ae8bb563f7ca8a1e6c9a5a6
+size 251584
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/sampling_gpu.o b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/sampling_gpu.o
new file mode 100644
index 0000000000000000000000000000000000000000..f80d05f5406610a3c7f70698a9e1b05aa8a97143
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/build/temp.linux-x86_64-cpython-310/pointnet2_ops/_ext-src/src/sampling_gpu.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:abfd53165b20c261b9d34ce08b06f0790d7b20a203edbafc001235bb2bed3a0f
+size 121680
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/dist/pointnet2_ops-3.0.0-py3.10-linux-x86_64.egg b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/dist/pointnet2_ops-3.0.0-py3.10-linux-x86_64.egg
new file mode 100644
index 0000000000000000000000000000000000000000..ba7e2565d5f4c18ca05eecfd835944350907d00f
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/dist/pointnet2_ops-3.0.0-py3.10-linux-x86_64.egg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:edc71ebba3f0d879728570f7e3e29224db718793308d0b4daea8276965248c3f
+size 208857
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/PKG-INFO b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..f1b2a66a7e9446a1463ad2b883e417b9421f5f03
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/PKG-INFO
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:946f5a00fe646c66f428a0589a5a6ff91192b4ee3c82757703935b6e4004e18c
+size 143
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/SOURCES.txt b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..784201f9d9979d195bcee1e9909de5dc160d285d
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/SOURCES.txt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c7d532de32a43fbdc6b009ed5919971696e337e94d0b7b8a3f22ac57c6c37721
+size 974
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/dependency_links.txt b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fd481aab725cd00018d4e32426f317aa39787c65
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/dependency_links.txt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:01ba4719c80b6fe911b091a7c05124b64eeece964e09c058ef8f9805daca546b
+size 1
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/requires.txt b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..87b2b2d2bc8e4ac964c65991f30ebd99402f90bb
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/requires.txt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b27f0a8dddf6f20ed2ec121773b7164293306214d0f43de635e50f25ed5a06c6
+size 11
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/top_level.txt b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ba9309af82776cd8fc37ce7520fe1b839bb9da4a
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops.egg-info/top_level.txt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b88f822efaefe234d87a8ad854c64138ae3bf1be98d3149bbbb1a29670253126
+size 14
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/__init__.py b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fd361f9abbacc218f7699b8c439902b9d1bf745
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/__init__.py
@@ -0,0 +1,3 @@
+import pointnet2_ops.pointnet2_modules
+import pointnet2_ops.pointnet2_utils
+from pointnet2_ops._version import __version__
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h
new file mode 100644
index 0000000000000000000000000000000000000000..c178accd7ee197d3c68d943af02b849bb5a2caf0
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:595e651f7ac88e14248199cba61e9d0ddd3e18f8535e1ce38d471b30a246427a
+size 163
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..864e2bd95dceb3498a613ccdddaa2cc862cb7a5d
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de26c10e6748e61da7f6516863f2a20d923ccd78e539efdfe18db8a0aeb0a288
+size 1303
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h
new file mode 100644
index 0000000000000000000000000000000000000000..5b56f4bd6ade4fcdf00bfe2a7dc17412e52f0257
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e702934ff68b3030b052af38b469f00786ea7b3620e017ef58bcc4cf5f0d6325
+size 183
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h
new file mode 100644
index 0000000000000000000000000000000000000000..0178c2d76760423fbf64764634444f67edf1b478
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d96cc311411e9e95cc03e47ea6d63269a1b888342d54d074348b967c96d07c8
+size 386
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h
new file mode 100644
index 0000000000000000000000000000000000000000..14a57d2330f3dab172b079388f4f3fbec85f91dc
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e1b919a8bacf4c077dff6d86d2226df4eeb041e5d135be2a2d829c5de933d068
+size 260
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..e5c32ab2ff2102c7ad545a843069d1182e1e07f4
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0f94d69288a2010d12aa35494f127b7e3dccd154d9aebdecef84f05a6fc6fbe
+size 983
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4b1c841022ef5046019c7055a2cd38a36381b6e5
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff4b6076cfc3299faed905927391144511fdbfe807b14ae2970f26c1756f6aaa
+size 1037
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ebff1506053fe6660987faa9e2e2fc13d80a2a07
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7253c10da9db4b02d25e43b7a713ee956dae27b4ec25220f1b2abbaccc79530f
+size 1784
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..22f678d56e2fe7c42aecf1ef8b0071fa883dad47
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:542241ad191246365594c3861e93b1e1d625fb0fd1fd95eec113fed98c66d8b5
+size 570
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8cb6b4abedf5e8ed74c4e3475a674760f3f19808
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c8079fe09e1a35a6142545f3277c71c483b54e19653930378e2f2f1d3b9d7971
+size 1952
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..bc281cb478d1af1f2d89ac3dc7fcbadf24cb4ff2
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf0bc3c20f709ed5be113c5d2618d3d61dcc49fb2449d53323596c5d1d4857fb
+size 2885
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..922f3fe9aa2085d7c58acefaf0ff9555a9eb219c
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0fb9ca45eb12d9998618b69d42272f6a58814437a1a2803941e6b496944411a1
+size 3304
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..af4dabfb5c87f76086b9a5fb21334316219aad07
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e4a327f15d49bcefe62bd72adc407157f97039c3b34037f8fc767b9a15936173
+size 5141
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a264c69d7b9ae80d81e11b111a5e2c0a770a7fd7
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb800bdb22cc5b2a3742eae1c125cd6883969c329a3e36c8c4d3d9de98246312
+size 2894
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b50ac8f3fb305603d1a5791cc34ba0b97c8287b6
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a1ec92aabdb6e4ab3ad0b982ebb28fa79b1b1fc95e3accbfaa511ef9bc317fe4
+size 7019
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_version.py b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..528787cfc8ad81ed41822a8104b60b4896632906
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/_version.py
@@ -0,0 +1 @@
+__version__ = "3.0.0"
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0ad4f6bc23f54ca2d61454e657a6f533e9b875c
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py
@@ -0,0 +1,209 @@
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pointnet2_ops import pointnet2_utils
+
+
+def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
+ layers = []
+ for i in range(1, len(mlp_spec)):
+ layers.append(
+ nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn)
+ )
+ if bn:
+ layers.append(nn.BatchNorm2d(mlp_spec[i]))
+ layers.append(nn.ReLU(True))
+
+ return nn.Sequential(*layers)
+
+
+class _PointnetSAModuleBase(nn.Module):
+ def __init__(self):
+ super(_PointnetSAModuleBase, self).__init__()
+ self.npoint = None
+ self.groupers = None
+ self.mlps = None
+
+ def forward(
+ self, xyz: torch.Tensor, features: Optional[torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Parameters
+ ----------
+ xyz : torch.Tensor
+ (B, N, 3) tensor of the xyz coordinates of the features
+ features : torch.Tensor
+ (B, C, N) tensor of the descriptors of the the features
+
+ Returns
+ -------
+ new_xyz : torch.Tensor
+ (B, npoint, 3) tensor of the new features' xyz
+ new_features : torch.Tensor
+ (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
+ """
+
+ new_features_list = []
+
+ xyz_flipped = xyz.transpose(1, 2).contiguous()
+ new_xyz = (
+ pointnet2_utils.gather_operation(
+ xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint)
+ )
+ .transpose(1, 2)
+ .contiguous()
+ if self.npoint is not None
+ else None
+ )
+
+ for i in range(len(self.groupers)):
+ new_features = self.groupers[i](
+ xyz, new_xyz, features
+ ) # (B, C, npoint, nsample)
+
+ new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample)
+ new_features = F.max_pool2d(
+ new_features, kernel_size=[1, new_features.size(3)]
+ ) # (B, mlp[-1], npoint, 1)
+ new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint)
+
+ new_features_list.append(new_features)
+
+ return new_xyz, torch.cat(new_features_list, dim=1)
+
+
+class PointnetSAModuleMSG(_PointnetSAModuleBase):
+ r"""Pointnet set abstrction layer with multiscale grouping
+
+ Parameters
+ ----------
+ npoint : int
+ Number of features
+ radii : list of float32
+ list of radii to group with
+ nsamples : list of int32
+ Number of samples in each ball query
+ mlps : list of list of int32
+ Spec of the pointnet before the global max_pool for each scale
+ bn : bool
+ Use batchnorm
+ """
+
+ def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
+ # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None
+ super(PointnetSAModuleMSG, self).__init__()
+
+ assert len(radii) == len(nsamples) == len(mlps)
+
+ self.npoint = npoint
+ self.groupers = nn.ModuleList()
+ self.mlps = nn.ModuleList()
+ for i in range(len(radii)):
+ radius = radii[i]
+ nsample = nsamples[i]
+ self.groupers.append(
+ pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
+ if npoint is not None
+ else pointnet2_utils.GroupAll(use_xyz)
+ )
+ mlp_spec = mlps[i]
+ if use_xyz:
+ mlp_spec[0] += 3
+
+ self.mlps.append(build_shared_mlp(mlp_spec, bn))
+
+
+class PointnetSAModule(PointnetSAModuleMSG):
+ r"""Pointnet set abstrction layer
+
+ Parameters
+ ----------
+ npoint : int
+ Number of features
+ radius : float
+ Radius of ball
+ nsample : int
+ Number of samples in the ball query
+ mlp : list
+ Spec of the pointnet before the global max_pool
+ bn : bool
+ Use batchnorm
+ """
+
+ def __init__(
+ self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
+ ):
+ # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None
+ super(PointnetSAModule, self).__init__(
+ mlps=[mlp],
+ npoint=npoint,
+ radii=[radius],
+ nsamples=[nsample],
+ bn=bn,
+ use_xyz=use_xyz,
+ )
+
+
+class PointnetFPModule(nn.Module):
+ r"""Propigates the features of one set to another
+
+ Parameters
+ ----------
+ mlp : list
+ Pointnet module parameters
+ bn : bool
+ Use batchnorm
+ """
+
+ def __init__(self, mlp, bn=True):
+ # type: (PointnetFPModule, List[int], bool) -> None
+ super(PointnetFPModule, self).__init__()
+ self.mlp = build_shared_mlp(mlp, bn=bn)
+
+ def forward(self, unknown, known, unknow_feats, known_feats):
+ # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
+ r"""
+ Parameters
+ ----------
+ unknown : torch.Tensor
+ (B, n, 3) tensor of the xyz positions of the unknown features
+ known : torch.Tensor
+ (B, m, 3) tensor of the xyz positions of the known features
+ unknow_feats : torch.Tensor
+ (B, C1, n) tensor of the features to be propigated to
+ known_feats : torch.Tensor
+ (B, C2, m) tensor of features to be propigated
+
+ Returns
+ -------
+ new_features : torch.Tensor
+ (B, mlp[-1], n) tensor of the features of the unknown features
+ """
+
+ if known is not None:
+ dist, idx = pointnet2_utils.three_nn(unknown, known)
+ dist_recip = 1.0 / (dist + 1e-8)
+ norm = torch.sum(dist_recip, dim=2, keepdim=True)
+ weight = dist_recip / norm
+
+ interpolated_feats = pointnet2_utils.three_interpolate(
+ known_feats, idx, weight
+ )
+ else:
+ interpolated_feats = known_feats.expand(
+ *(known_feats.size()[0:2] + [unknown.size(1)])
+ )
+
+ if unknow_feats is not None:
+ new_features = torch.cat(
+ [interpolated_feats, unknow_feats], dim=1
+ ) # (B, C2 + C1, n)
+ else:
+ new_features = interpolated_feats
+
+ new_features = new_features.unsqueeze(-1)
+ new_features = self.mlp(new_features)
+
+ return new_features.squeeze(-1)
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..150fcccade21001971a76e6c3628963972305739
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py
@@ -0,0 +1,379 @@
+import torch
+import torch.nn as nn
+import warnings
+from torch.autograd import Function
+from typing import *
+
+try:
+ import pointnet2_ops._ext as _ext
+except ImportError:
+ from torch.utils.cpp_extension import load
+ import glob
+ import os.path as osp
+ import os
+
+ warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")
+
+ _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
+ _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
+ osp.join(_ext_src_root, "src", "*.cu")
+ )
+ _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
+
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"
+ _ext = load(
+ "_ext",
+ sources=_ext_sources,
+ extra_include_paths=[osp.join(_ext_src_root, "include")],
+ extra_cflags=["-O3"],
+ extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"],
+ with_cuda=True,
+ )
+
+
+class FurthestPointSampling(Function):
+ @staticmethod
+ def forward(ctx, xyz, npoint):
+ # type: (Any, torch.Tensor, int) -> torch.Tensor
+ r"""
+ Uses iterative furthest point sampling to select a set of npoint features that have the largest
+ minimum distance
+
+ Parameters
+ ----------
+ xyz : torch.Tensor
+ (B, N, 3) tensor where N > npoint
+ npoint : int32
+ number of features in the sampled set
+
+ Returns
+ -------
+ torch.Tensor
+ (B, npoint) tensor containing the set
+ """
+ out = _ext.furthest_point_sampling(xyz, npoint)
+
+ ctx.mark_non_differentiable(out)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ return ()
+
+
+furthest_point_sample = FurthestPointSampling.apply
+
+
+class GatherOperation(Function):
+ @staticmethod
+ def forward(ctx, features, idx):
+ # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
+ r"""
+
+ Parameters
+ ----------
+ features : torch.Tensor
+ (B, C, N) tensor
+
+ idx : torch.Tensor
+ (B, npoint) tensor of the features to gather
+
+ Returns
+ -------
+ torch.Tensor
+ (B, C, npoint) tensor
+ """
+
+ ctx.save_for_backward(idx, features)
+
+ return _ext.gather_points(features, idx)
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ idx, features = ctx.saved_tensors
+ N = features.size(2)
+
+ grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N)
+ return grad_features, None
+
+
+gather_operation = GatherOperation.apply
+
+
+class ThreeNN(Function):
+ @staticmethod
+ def forward(ctx, unknown, known):
+ # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
+ r"""
+ Find the three nearest neighbors of unknown in known
+ Parameters
+ ----------
+ unknown : torch.Tensor
+ (B, n, 3) tensor of known features
+ known : torch.Tensor
+ (B, m, 3) tensor of unknown features
+
+ Returns
+ -------
+ dist : torch.Tensor
+ (B, n, 3) l2 distance to the three nearest neighbors
+ idx : torch.Tensor
+ (B, n, 3) index of 3 nearest neighbors
+ """
+ dist2, idx = _ext.three_nn(unknown, known)
+ dist = torch.sqrt(dist2)
+
+ ctx.mark_non_differentiable(dist, idx)
+
+ return dist, idx
+
+ @staticmethod
+ def backward(ctx, grad_dist, grad_idx):
+ return ()
+
+
+three_nn = ThreeNN.apply
+
+
+class ThreeInterpolate(Function):
+ @staticmethod
+ def forward(ctx, features, idx, weight):
+ # type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor
+ r"""
+ Performs weight linear interpolation on 3 features
+ Parameters
+ ----------
+ features : torch.Tensor
+ (B, c, m) Features descriptors to be interpolated from
+ idx : torch.Tensor
+ (B, n, 3) three nearest neighbors of the target features in features
+ weight : torch.Tensor
+ (B, n, 3) weights
+
+ Returns
+ -------
+ torch.Tensor
+ (B, c, n) tensor of the interpolated features
+ """
+ ctx.save_for_backward(idx, weight, features)
+
+ return _ext.three_interpolate(features, idx, weight)
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+ r"""
+ Parameters
+ ----------
+ grad_out : torch.Tensor
+ (B, c, n) tensor with gradients of ouputs
+
+ Returns
+ -------
+ grad_features : torch.Tensor
+ (B, c, m) tensor with gradients of features
+
+ None
+
+ None
+ """
+ idx, weight, features = ctx.saved_tensors
+ m = features.size(2)
+
+ grad_features = _ext.three_interpolate_grad(
+ grad_out.contiguous(), idx, weight, m
+ )
+
+ return grad_features, torch.zeros_like(idx), torch.zeros_like(weight)
+
+
+three_interpolate = ThreeInterpolate.apply
+
+
+class GroupingOperation(Function):
+ @staticmethod
+ def forward(ctx, features, idx):
+ # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
+ r"""
+
+ Parameters
+ ----------
+ features : torch.Tensor
+ (B, C, N) tensor of features to group
+ idx : torch.Tensor
+ (B, npoint, nsample) tensor containing the indicies of features to group with
+
+ Returns
+ -------
+ torch.Tensor
+ (B, C, npoint, nsample) tensor
+ """
+ ctx.save_for_backward(idx, features)
+
+ return _ext.group_points(features, idx)
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor]
+ r"""
+
+ Parameters
+ ----------
+ grad_out : torch.Tensor
+ (B, C, npoint, nsample) tensor of the gradients of the output from forward
+
+ Returns
+ -------
+ torch.Tensor
+ (B, C, N) gradient of the features
+ None
+ """
+ idx, features = ctx.saved_tensors
+ N = features.size(2)
+
+ grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N)
+
+ return grad_features, torch.zeros_like(idx)
+
+
+grouping_operation = GroupingOperation.apply
+
+
+class BallQuery(Function):
+ @staticmethod
+ def forward(ctx, radius, nsample, xyz, new_xyz):
+ # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
+ r"""
+
+ Parameters
+ ----------
+ radius : float
+ radius of the balls
+ nsample : int
+ maximum number of features in the balls
+ xyz : torch.Tensor
+ (B, N, 3) xyz coordinates of the features
+ new_xyz : torch.Tensor
+ (B, npoint, 3) centers of the ball query
+
+ Returns
+ -------
+ torch.Tensor
+ (B, npoint, nsample) tensor with the indicies of the features that form the query balls
+ """
+ output = _ext.ball_query(new_xyz, xyz, radius, nsample)
+
+ ctx.mark_non_differentiable(output)
+
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ return ()
+
+
+ball_query = BallQuery.apply
+
+
+class QueryAndGroup(nn.Module):
+ r"""
+ Groups with a ball query of radius
+
+ Parameters
+ ---------
+ radius : float32
+ Radius of ball
+ nsample : int32
+ Maximum number of features to gather in the ball
+ """
+
+ def __init__(self, radius, nsample, use_xyz=True):
+ # type: (QueryAndGroup, float, int, bool) -> None
+ super(QueryAndGroup, self).__init__()
+ self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
+
+ def forward(self, xyz, new_xyz, features=None):
+ # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor]
+ r"""
+ Parameters
+ ----------
+ xyz : torch.Tensor
+ xyz coordinates of the features (B, N, 3)
+ new_xyz : torch.Tensor
+ centriods (B, npoint, 3)
+ features : torch.Tensor
+ Descriptors of the features (B, C, N)
+
+ Returns
+ -------
+ new_features : torch.Tensor
+ (B, 3 + C, npoint, nsample) tensor
+ """
+
+ idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
+ xyz_trans = xyz.transpose(1, 2).contiguous()
+ grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample)
+ grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
+
+ if features is not None:
+ grouped_features = grouping_operation(features, idx)
+ if self.use_xyz:
+ new_features = torch.cat(
+ [grouped_xyz, grouped_features], dim=1
+ ) # (B, C + 3, npoint, nsample)
+ else:
+ new_features = grouped_features
+ else:
+ assert (
+ self.use_xyz
+ ), "Cannot have not features and not use xyz as a feature!"
+ new_features = grouped_xyz
+
+ return new_features
+
+
+class GroupAll(nn.Module):
+ r"""
+ Groups all features
+
+ Parameters
+ ---------
+ """
+
+ def __init__(self, use_xyz=True):
+ # type: (GroupAll, bool) -> None
+ super(GroupAll, self).__init__()
+ self.use_xyz = use_xyz
+
+ def forward(self, xyz, new_xyz, features=None):
+ # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
+ r"""
+ Parameters
+ ----------
+ xyz : torch.Tensor
+ xyz coordinates of the features (B, N, 3)
+ new_xyz : torch.Tensor
+ Ignored
+ features : torch.Tensor
+ Descriptors of the features (B, C, N)
+
+ Returns
+ -------
+ new_features : torch.Tensor
+ (B, C + 3, 1, N) tensor
+ """
+
+ grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
+ if features is not None:
+ grouped_features = features.unsqueeze(2)
+ if self.use_xyz:
+ new_features = torch.cat(
+ [grouped_xyz, grouped_features], dim=1
+ ) # (B, 3 + C, 1, N)
+ else:
+ new_features = grouped_features
+ else:
+ new_features = grouped_xyz
+
+ return new_features
diff --git a/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/setup.py b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..faf715418ebbab275ae08b253607188a270acea0
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pointnet2_ops_lib/setup.py
@@ -0,0 +1,39 @@
+import glob
+import os
+import os.path as osp
+
+from setuptools import find_packages, setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+this_dir = osp.dirname(osp.abspath(__file__))
+_ext_src_root = osp.join("pointnet2_ops", "_ext-src")
+_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
+ osp.join(_ext_src_root, "src", "*.cu")
+)
+_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
+
+requirements = ["torch>=1.4"]
+
+exec(open(osp.join("pointnet2_ops", "_version.py")).read())
+
+os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"
+setup(
+ name="pointnet2_ops",
+ version=__version__,
+ author="Erik Wijmans",
+ packages=find_packages(),
+ install_requires=requirements,
+ ext_modules=[
+ CUDAExtension(
+ name="pointnet2_ops._ext",
+ sources=_ext_sources,
+ extra_compile_args={
+ "cxx": ["-O3"],
+ "nvcc": ["-O3", "-Xfatbin", "-compress-all"],
+ },
+ include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
+ )
+ ],
+ cmdclass={"build_ext": BuildExtension},
+ include_package_data=True,
+)
diff --git a/pointnet++/Pointnet2_PyTorch/pyproject.toml b/pointnet++/Pointnet2_PyTorch/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..c1652f11b14334b103b534a3205a79183ccdc9d6
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/pyproject.toml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eabba26c63405dc2e5dbabff6d4eb2222a7fe78755481432932abc6439281648
+size 646
diff --git a/pointnet++/Pointnet2_PyTorch/requirements.txt b/pointnet++/Pointnet2_PyTorch/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ec6ea1605315b4bb9f418a3d8250078b2024101b
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/requirements.txt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f5092fd619265614bf642ba868ea117621da6bf3f28dd477b1ab2129d6f77bb
+size 98
diff --git a/pointnet++/Pointnet2_PyTorch/setup.py b/pointnet++/Pointnet2_PyTorch/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..fac260ad71f5640ab7da06ecad0daa6be06e1dfb
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/setup.py
@@ -0,0 +1,16 @@
+import os.path as osp
+
+from setuptools import find_packages, setup
+
+requirements = ["hydra-core==0.11.3", "pytorch-lightning==0.7.1"]
+
+
+exec(open(osp.join("pointnet2", "_version.py")).read())
+
+setup(
+ name="pointnet2",
+ version=__version__,
+ author="Erik Wijmans",
+ packages=find_packages(),
+ install_requires=requirements,
+)
diff --git a/pointnet++/Pointnet2_PyTorch/tests/conftest.py b/pointnet++/Pointnet2_PyTorch/tests/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..94adf6b9cb063ce55e459861e9ccd4ab2a591d4b
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/tests/conftest.py
@@ -0,0 +1,60 @@
+import os
+
+import hydra
+import hydra.experimental
+import numpy as np
+import pytest
+import torch
+
+pytest_plugins = ["helpers_namespace"]
+
+hydra.experimental.initialize(
+ os.path.join(os.path.dirname(__file__), "../pointnet2/config")
+)
+
+
+@pytest.helpers.register
+def build_cfg(overrides=[]):
+ return hydra.experimental.compose("config.yaml", overrides)
+
+
+@pytest.helpers.register
+def get_model(overrides=[]):
+ cfg = build_cfg(overrides)
+ return hydra.utils.instantiate(cfg.task_model, cfg)
+
+
+def _test_loop(model, inputs, labels):
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
+
+ prev_loss = 1e10
+ for _ in range(5):
+ optimizer.zero_grad()
+ res = model.training_step((inputs, labels), None)
+ loss = res["loss"]
+ loss.backward()
+ optimizer.step()
+
+ assert loss.item() < prev_loss + 1.0, "Loss spiked upwards"
+
+ prev_loss = loss.item()
+
+
+@pytest.helpers.register
+def cls_test(model):
+ B, N = 4, 2048
+ inputs = torch.randn(B, N, 6).cuda()
+ labels = torch.from_numpy(np.random.randint(0, 3, size=B)).cuda()
+ model.cuda()
+
+ _test_loop(model, inputs, labels)
+
+
+@pytest.helpers.register
+def semseg_test(model):
+ B, N = 4, 2048
+ inputs = torch.randn(B, N, 9).cuda()
+ labels = torch.from_numpy(np.random.randint(0, 3, size=B * N)).view(B, N).cuda()
+ model.cuda()
+
+ _test_loop(model, inputs, labels)
diff --git a/pointnet++/Pointnet2_PyTorch/tests/test_cls.py b/pointnet++/Pointnet2_PyTorch/tests/test_cls.py
new file mode 100644
index 0000000000000000000000000000000000000000..34f7bef16d1a4e0a5ea58bf56f64011b3885eee4
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/tests/test_cls.py
@@ -0,0 +1,10 @@
+import pytest
+
+
+@pytest.mark.parametrize("use_xyz", ["True", "False"])
+@pytest.mark.parametrize("model", ["ssg", "msg"])
+def test_cls(use_xyz, model):
+ model = pytest.helpers.get_model(
+ ["task=cls", f"model={model}", f"model.use_xyz={use_xyz}"]
+ )
+ pytest.helpers.cls_test(model)
diff --git a/pointnet++/Pointnet2_PyTorch/tests/test_semseg.py b/pointnet++/Pointnet2_PyTorch/tests/test_semseg.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0a3cf902f955bb5b961d9ad236d27e21a84c206
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/tests/test_semseg.py
@@ -0,0 +1,10 @@
+import pytest
+
+
+@pytest.mark.parametrize("use_xyz", ["True", "False"])
+@pytest.mark.parametrize("model", ["ssg", "msg"])
+def test_semseg(use_xyz, model):
+ model = pytest.helpers.get_model(
+ ["task=semseg", f"model={model}", f"model.use_xyz={use_xyz}"]
+ )
+ pytest.helpers.semseg_test(model)
diff --git a/pointnet++/Pointnet2_PyTorch/tox.ini b/pointnet++/Pointnet2_PyTorch/tox.ini
new file mode 100644
index 0000000000000000000000000000000000000000..deeca6f39668ea9d06f891dc5625431a240e99ab
--- /dev/null
+++ b/pointnet++/Pointnet2_PyTorch/tox.ini
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:484b1f1c953193599a7a48fabf552432df0294e19e7dd56faea29e31dc2f50af
+size 206
diff --git a/pretrained_weight/Uni3D_PC_encoder/.gitattributes b/pretrained_weight/Uni3D_PC_encoder/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..489d0cf0f2b10ab1ab15b57e9cb384b3b7a0abdf
--- /dev/null
+++ b/pretrained_weight/Uni3D_PC_encoder/.gitattributes
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11ad7efa24975ee4b0c3c3a38ed18737f0658a5f75a0a96787b576a78a023361
+size 1519
diff --git a/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-base/model.pt b/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-base/model.pt
new file mode 100644
index 0000000000000000000000000000000000000000..fa1e5f9de9c6d172a196cec21beb126e227c9f34
--- /dev/null
+++ b/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-base/model.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb533f918957e939f0e500376bef5166437883443da10fd9adccde0d1b512fba
+size 178008217
diff --git a/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-giant/model.pt b/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-giant/model.pt
new file mode 100644
index 0000000000000000000000000000000000000000..6cbadc0ffff5226997e50f919b8b5f53904800ae
--- /dev/null
+++ b/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-giant/model.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0d8ecd047935b037fcc40bdfd66bef632631a968c37fef0f90f1351641573c3
+size 2034909361
diff --git a/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-large/model.pt b/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-large/model.pt
new file mode 100644
index 0000000000000000000000000000000000000000..5cffb62b626cd7938fe60b0a9d121943856268f5
--- /dev/null
+++ b/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-large/model.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:540a3cafac52c251cbda1844d109bc598f142b4f5a2118db87cd722b01800c20
+size 614861649
diff --git a/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-small/model.pt b/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-small/model.pt
new file mode 100644
index 0000000000000000000000000000000000000000..4bca5757ce2b2ec9abddd3e8630b6373452ea04b
--- /dev/null
+++ b/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-small/model.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a6342fda2245f00fa0200eb7fc667ee67fee9a953c82961e04d242130a3350f
+size 45713996
diff --git a/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-tiny/model.pt b/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-tiny/model.pt
new file mode 100644
index 0000000000000000000000000000000000000000..135103b3b6f8cf667b5fde53cd050f58fe316e78
--- /dev/null
+++ b/pretrained_weight/Uni3D_PC_encoder/modelzoo/uni3d-tiny/model.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6ff7e821e6997cb814b529e6dc3b224ad8349500471c8517e491685a718688e0
+size 12835148
diff --git a/pretrained_weight/clip_used_in_Uni3D/.gitattributes b/pretrained_weight/clip_used_in_Uni3D/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..78596b12a8a92779be9564dc3276b38fdcb7008a
--- /dev/null
+++ b/pretrained_weight/clip_used_in_Uni3D/.gitattributes
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db7c0371f46f0840b8f25794c1e3321c9b5820a8cb6ba9694a46fc64b8fae5a6
+size 1477
diff --git a/pretrained_weight/clip_used_in_Uni3D/README.md b/pretrained_weight/clip_used_in_Uni3D/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b3a85054346c09db40f959c831dfa4e42f3881f3
--- /dev/null
+++ b/pretrained_weight/clip_used_in_Uni3D/README.md
@@ -0,0 +1,8 @@
+---
+license: mit
+library_name: open_clip
+tags:
+- zero-shot-image-classification
+- clip
+---
+# Model card for eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k
diff --git a/pretrained_weight/clip_used_in_Uni3D/merges.txt b/pretrained_weight/clip_used_in_Uni3D/merges.txt
new file mode 100644
index 0000000000000000000000000000000000000000..68d6d97fa98260a11c613ccb92cc0d02c4524b55
--- /dev/null
+++ b/pretrained_weight/clip_used_in_Uni3D/merges.txt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f526393189112391ce6f9795d4695f704121ce452c3aad1f5335cc41337eba85
+size 524657
diff --git a/pretrained_weight/clip_used_in_Uni3D/open_clip_config.json b/pretrained_weight/clip_used_in_Uni3D/open_clip_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..75a2fb9ccba8fe57d516e07095b1a1a89ede300a
--- /dev/null
+++ b/pretrained_weight/clip_used_in_Uni3D/open_clip_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7dd195154d6f3c5d1d00ef8a80f84dd54c6384e34e3df1dbe71b3142c505a2ac
+size 584
diff --git a/pretrained_weight/clip_used_in_Uni3D/open_clip_pytorch_model.bin b/pretrained_weight/clip_used_in_Uni3D/open_clip_pytorch_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..5cf543e6da5253fddc49d18036c196ea3f2249e7
--- /dev/null
+++ b/pretrained_weight/clip_used_in_Uni3D/open_clip_pytorch_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e215717b06f1cb61315f3bda24bfeacb8a5d6ad04c1ea5b58693b78f0220541
+size 10090277147
diff --git a/pretrained_weight/clip_used_in_Uni3D/special_tokens_map.json b/pretrained_weight/clip_used_in_Uni3D/special_tokens_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..4c7213d97d558ccf75ed47851714b2a0e7066dd7
--- /dev/null
+++ b/pretrained_weight/clip_used_in_Uni3D/special_tokens_map.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c4864a9376a8401918425bed71fc14fc0e81f9b59ec45c1cf96cccb2df508eac
+size 472
diff --git a/pretrained_weight/clip_used_in_Uni3D/tokenizer.json b/pretrained_weight/clip_used_in_Uni3D/tokenizer.json
new file mode 100644
index 0000000000000000000000000000000000000000..3f9bdd75e6c854854017ca2eb114528624846085
--- /dev/null
+++ b/pretrained_weight/clip_used_in_Uni3D/tokenizer.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d8b124290bc4bcd18cd3f72747f525e2a1d8c266cf3089e52da38ee417564ac5
+size 2224053
diff --git a/pretrained_weight/clip_used_in_Uni3D/tokenizer_config.json b/pretrained_weight/clip_used_in_Uni3D/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..7aeffe852d0161afb2345f7de55f1310c50f168d
--- /dev/null
+++ b/pretrained_weight/clip_used_in_Uni3D/tokenizer_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fdd524218f0554acb77cdecf36807c01f4d1ad2574fd6345d0bdef48873c3d14
+size 755
diff --git a/pretrained_weight/clip_used_in_Uni3D/vocab.json b/pretrained_weight/clip_used_in_Uni3D/vocab.json
new file mode 100644
index 0000000000000000000000000000000000000000..85384c5941a6cd82152114823d430ec656b9d9ff
--- /dev/null
+++ b/pretrained_weight/clip_used_in_Uni3D/vocab.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5047b556ce86ccaf6aa22b3ffccfc52d391ea4accdab9c2f2407da5b742d4363
+size 862328
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/.gitattributes b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..a3b2bfd5ca8cc2dd5585d06eca159ec88bf61e97
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/.gitattributes
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98ccb431c012ebfe976280fbd45aea4cec7409935868ccecf3954370f96732a1
+size 1229
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/1_Pooling/config.json b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/1_Pooling/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..bc7312058a481e74a0a42d5e428263980eba874a
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/1_Pooling/config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a37f83ada23e7887be6b88f4998927dbeac0038af301553c7cd5461413bf1a56
+size 190
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/README.md b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0a624adbb92ec093d62f228b9e9d9e9ced190a79
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/README.md
@@ -0,0 +1,177 @@
+---
+language: en
+license: apache-2.0
+library_name: sentence-transformers
+tags:
+- sentence-transformers
+- feature-extraction
+- sentence-similarity
+- transformers
+datasets:
+- s2orc
+- flax-sentence-embeddings/stackexchange_xml
+- ms_marco
+- gooaq
+- yahoo_answers_topics
+- code_search_net
+- search_qa
+- eli5
+- snli
+- multi_nli
+- wikihow
+- natural_questions
+- trivia_qa
+- embedding-data/sentence-compression
+- embedding-data/flickr30k-captions
+- embedding-data/altlex
+- embedding-data/simple-wiki
+- embedding-data/QQP
+- embedding-data/SPECTER
+- embedding-data/PAQ_pairs
+- embedding-data/WikiAnswers
+pipeline_tag: sentence-similarity
+---
+
+
+# all-mpnet-base-v2
+This is a [sentence-transformers](https://www.SBERT.net) model: It maps sentences & paragraphs to a 768 dimensional dense vector space and can be used for tasks like clustering or semantic search.
+
+## Usage (Sentence-Transformers)
+Using this model becomes easy when you have [sentence-transformers](https://www.SBERT.net) installed:
+
+```
+pip install -U sentence-transformers
+```
+
+Then you can use the model like this:
+```python
+from sentence_transformers import SentenceTransformer
+sentences = ["This is an example sentence", "Each sentence is converted"]
+
+model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
+embeddings = model.encode(sentences)
+print(embeddings)
+```
+
+## Usage (HuggingFace Transformers)
+Without [sentence-transformers](https://www.SBERT.net), you can use the model like this: First, you pass your input through the transformer model, then you have to apply the right pooling-operation on-top of the contextualized word embeddings.
+
+```python
+from transformers import AutoTokenizer, AutoModel
+import torch
+import torch.nn.functional as F
+
+#Mean Pooling - Take attention mask into account for correct averaging
+def mean_pooling(model_output, attention_mask):
+ token_embeddings = model_output[0] #First element of model_output contains all token embeddings
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
+# Sentences we want sentence embeddings for
+sentences = ['This is an example sentence', 'Each sentence is converted']
+
+# Load model from HuggingFace Hub
+tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
+model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
+
+# Tokenize sentences
+encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
+
+# Compute token embeddings
+with torch.no_grad():
+ model_output = model(**encoded_input)
+
+# Perform pooling
+sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+
+# Normalize embeddings
+sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
+
+print("Sentence embeddings:")
+print(sentence_embeddings)
+```
+
+## Evaluation Results
+
+For an automated evaluation of this model, see the *Sentence Embeddings Benchmark*: [https://seb.sbert.net](https://seb.sbert.net?model_name=sentence-transformers/all-mpnet-base-v2)
+
+------
+
+## Background
+
+The project aims to train sentence embedding models on very large sentence level datasets using a self-supervised
+contrastive learning objective. We used the pretrained [`microsoft/mpnet-base`](https://huggingface.co/microsoft/mpnet-base) model and fine-tuned in on a
+1B sentence pairs dataset. We use a contrastive learning objective: given a sentence from the pair, the model should predict which out of a set of randomly sampled other sentences, was actually paired with it in our dataset.
+
+We developped this model during the
+[Community week using JAX/Flax for NLP & CV](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104),
+organized by Hugging Face. We developped this model as part of the project:
+[Train the Best Sentence Embedding Model Ever with 1B Training Pairs](https://discuss.huggingface.co/t/train-the-best-sentence-embedding-model-ever-with-1b-training-pairs/7354). We benefited from efficient hardware infrastructure to run the project: 7 TPUs v3-8, as well as intervention from Googles Flax, JAX, and Cloud team member about efficient deep learning frameworks.
+
+## Intended uses
+
+Our model is intented to be used as a sentence and short paragraph encoder. Given an input text, it ouptuts a vector which captures
+the semantic information. The sentence vector may be used for information retrieval, clustering or sentence similarity tasks.
+
+By default, input text longer than 384 word pieces is truncated.
+
+
+## Training procedure
+
+### Pre-training
+
+We use the pretrained [`microsoft/mpnet-base`](https://huggingface.co/microsoft/mpnet-base) model. Please refer to the model card for more detailed information about the pre-training procedure.
+
+### Fine-tuning
+
+We fine-tune the model using a contrastive objective. Formally, we compute the cosine similarity from each possible sentence pairs from the batch.
+We then apply the cross entropy loss by comparing with true pairs.
+
+#### Hyper parameters
+
+We trained ou model on a TPU v3-8. We train the model during 100k steps using a batch size of 1024 (128 per TPU core).
+We use a learning rate warm up of 500. The sequence length was limited to 128 tokens. We used the AdamW optimizer with
+a 2e-5 learning rate. The full training script is accessible in this current repository: `train_script.py`.
+
+#### Training data
+
+We use the concatenation from multiple datasets to fine-tune our model. The total number of sentence pairs is above 1 billion sentences.
+We sampled each dataset given a weighted probability which configuration is detailed in the `data_config.json` file.
+
+
+| Dataset | Paper | Number of training tuples |
+|--------------------------------------------------------|:----------------------------------------:|:--------------------------:|
+| [Reddit comments (2015-2018)](https://github.com/PolyAI-LDN/conversational-datasets/tree/master/reddit) | [paper](https://arxiv.org/abs/1904.06472) | 726,484,430 |
+| [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Abstracts) | [paper](https://aclanthology.org/2020.acl-main.447/) | 116,288,806 |
+| [WikiAnswers](https://github.com/afader/oqa#wikianswers-corpus) Duplicate question pairs | [paper](https://doi.org/10.1145/2623330.2623677) | 77,427,422 |
+| [PAQ](https://github.com/facebookresearch/PAQ) (Question, Answer) pairs | [paper](https://arxiv.org/abs/2102.07033) | 64,371,441 |
+| [S2ORC](https://github.com/allenai/s2orc) Citation pairs (Titles) | [paper](https://aclanthology.org/2020.acl-main.447/) | 52,603,982 |
+| [S2ORC](https://github.com/allenai/s2orc) (Title, Abstract) | [paper](https://aclanthology.org/2020.acl-main.447/) | 41,769,185 |
+| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Body) pairs | - | 25,316,456 |
+| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title+Body, Answer) pairs | - | 21,396,559 |
+| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) (Title, Answer) pairs | - | 21,396,559 |
+| [MS MARCO](https://microsoft.github.io/msmarco/) triplets | [paper](https://doi.org/10.1145/3404835.3462804) | 9,144,553 |
+| [GOOAQ: Open Question Answering with Diverse Answer Types](https://github.com/allenai/gooaq) | [paper](https://arxiv.org/pdf/2104.08727.pdf) | 3,012,496 |
+| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 1,198,260 |
+| [Code Search](https://huggingface.co/datasets/code_search_net) | - | 1,151,414 |
+| [COCO](https://cocodataset.org/#home) Image captions | [paper](https://link.springer.com/chapter/10.1007%2F978-3-319-10602-1_48) | 828,395|
+| [SPECTER](https://github.com/allenai/specter) citation triplets | [paper](https://doi.org/10.18653/v1/2020.acl-main.207) | 684,100 |
+| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Question, Answer) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 681,164 |
+| [Yahoo Answers](https://www.kaggle.com/soumikrakshit/yahoo-answers-dataset) (Title, Question) | [paper](https://proceedings.neurips.cc/paper/2015/hash/250cf8b51c773f3f8dc8b4be867a9a02-Abstract.html) | 659,896 |
+| [SearchQA](https://huggingface.co/datasets/search_qa) | [paper](https://arxiv.org/abs/1704.05179) | 582,261 |
+| [Eli5](https://huggingface.co/datasets/eli5) | [paper](https://doi.org/10.18653/v1/p19-1346) | 325,475 |
+| [Flickr 30k](https://shannon.cs.illinois.edu/DenotationGraph/) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/229/33) | 317,695 |
+| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles) | | 304,525 |
+| AllNLI ([SNLI](https://nlp.stanford.edu/projects/snli/) and [MultiNLI](https://cims.nyu.edu/~sbowman/multinli/) | [paper SNLI](https://doi.org/10.18653/v1/d15-1075), [paper MultiNLI](https://doi.org/10.18653/v1/n18-1101) | 277,230 |
+| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (bodies) | | 250,519 |
+| [Stack Exchange](https://huggingface.co/datasets/flax-sentence-embeddings/stackexchange_xml) Duplicate questions (titles+bodies) | | 250,460 |
+| [Sentence Compression](https://github.com/google-research-datasets/sentence-compression) | [paper](https://www.aclweb.org/anthology/D13-1155/) | 180,000 |
+| [Wikihow](https://github.com/pvl/wikihow_pairs_dataset) | [paper](https://arxiv.org/abs/1810.09305) | 128,542 |
+| [Altlex](https://github.com/chridey/altlex/) | [paper](https://aclanthology.org/P16-1135.pdf) | 112,696 |
+| [Quora Question Triplets](https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs) | - | 103,663 |
+| [Simple Wikipedia](https://cs.pomona.edu/~dkauchak/simplification/) | [paper](https://www.aclweb.org/anthology/P11-2117/) | 102,225 |
+| [Natural Questions (NQ)](https://ai.google.com/research/NaturalQuestions) | [paper](https://transacl.org/ojs/index.php/tacl/article/view/1455) | 100,231 |
+| [SQuAD2.0](https://rajpurkar.github.io/SQuAD-explorer/) | [paper](https://aclanthology.org/P18-2124.pdf) | 87,599 |
+| [TriviaQA](https://huggingface.co/datasets/trivia_qa) | - | 73,346 |
+| **Total** | | **1,170,060,424** |
\ No newline at end of file
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/config.json b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..2c338b6a75180db77122b059820d483b07fd7bfe
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d46a3e04ded82bba22528424480697d394eeda6a27484e08c5bb2bdf5906cfa0
+size 571
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/config_sentence_transformers.json b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/config_sentence_transformers.json
new file mode 100644
index 0000000000000000000000000000000000000000..7d444dde4563ae591e58580aa70e31e3a263b658
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/config_sentence_transformers.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:061ca9d39661d6c6d6de5ba27f79a1cd5770ea247f8d46412a68a498dc5ac9f3
+size 116
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/data_config.json b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/data_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..24e535805a5c3ad4ef36ec6059a4bea5f67fc676
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/data_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:32edcb108fc2516b920734a862ae0692bcae1c5d45d5f8d972cb0d53434a4c54
+size 39265
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/model.safetensors b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..8a4d902ecd2fe30af6f3ccb0d37b7e280992d8ad
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:78c0197b6159d92658e319bc1d72e4c73a9a03dd03815e70e555c5ef05615658
+size 437971872
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/modules.json b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/modules.json
new file mode 100644
index 0000000000000000000000000000000000000000..79350d624ca0fe094b42fff0f1fea2c063a94c57
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/modules.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:84e40c8e006c9b1d6c122e02cba9b02458120b5fb0c87b746c41e0207cf642cf
+size 349
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model.onnx b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..f2c572c400f18dfce7d6dd28788b47c8b6fc8d33
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74187b16d9c946fea252e120cfd7a12c5779d8b8b86838a2e4c56573c47941bd
+size 435826548
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_O1.onnx b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_O1.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..5c8850a2679393189f75be4113fa653791d27cfd
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_O1.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c0b47004076ab40bf15a2c52b98a53e985ebb84faaeeb6d2551768f96e384b0
+size 435730180
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_O2.onnx b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_O2.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..148b23cd035aa6b6aad9f15e9e71879323f1b8a8
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_O2.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:14d01256f5f3d2245b15b596173bca4367c9405fde5700dd7fb4e110708c1793
+size 435666661
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_O3.onnx b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_O3.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..cb220e8df4b826b7fef4d6dcdd709456964bf4aa
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_O3.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd55510706038d0817b7d41bf2078f01472e4865190584ad624e8ab79bbcb310
+size 435666516
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_O4.onnx b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_O4.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..37e2d7943e64361ddc7131075fe16cb708b46372
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_O4.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cab2a54139fc4fd5b8e2a23cb5729ee28dc44cfde685ad3356d533653e635310
+size 217894954
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_qint8_arm64.onnx b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_qint8_arm64.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..78a22a50d1ed39d96186de6dd8ac0c2cff7c5131
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_qint8_arm64.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c392a9c545c7d4438a16fed8287a76a576b27eaf029c1c23bbf78a7a666d197f
+size 110124379
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_qint8_avx512.onnx b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_qint8_avx512.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..78a22a50d1ed39d96186de6dd8ac0c2cff7c5131
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_qint8_avx512.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c392a9c545c7d4438a16fed8287a76a576b27eaf029c1c23bbf78a7a666d197f
+size 110124379
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_qint8_avx512_vnni.onnx b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_qint8_avx512_vnni.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..78a22a50d1ed39d96186de6dd8ac0c2cff7c5131
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_qint8_avx512_vnni.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c392a9c545c7d4438a16fed8287a76a576b27eaf029c1c23bbf78a7a666d197f
+size 110124379
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_quint8_avx2.onnx b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_quint8_avx2.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..38be8d1e74075429c6ade2f16bd74a59a328895e
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/onnx/model_quint8_avx2.onnx
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa5c27172d77bbd1cbae3628cbac4b26d7c12adabff25d2d4285d0f29159b237
+size 110207323
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/openvino/openvino_model.bin b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/openvino/openvino_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..6dbc2ce15f36a06ab7d9870dca7f031cc2261408
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/openvino/openvino_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5c3279d833888eaab745e24b652126c5a71375af185ac21aa47e112e2468dec0
+size 435583684
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/openvino/openvino_model.xml b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/openvino/openvino_model.xml
new file mode 100644
index 0000000000000000000000000000000000000000..dfeaa2053722ecb6bb7be747ea602cac687b4616
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/openvino/openvino_model.xml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2912e3dbd3426b77984992953998d8026a3d2377104093079e810b53fc51bf6
+size 432773
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/openvino/openvino_model_qint8_quantized.bin b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/openvino/openvino_model_qint8_quantized.bin
new file mode 100644
index 0000000000000000000000000000000000000000..2e8a02f7325f31f0ade43ee1ec36ac70391b2cfa
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/openvino/openvino_model_qint8_quantized.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fde0c650018f5e244f793316b666aaf4758d4e19072f430e59eb2bcc414895ce
+size 109974792
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/openvino/openvino_model_qint8_quantized.xml b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/openvino/openvino_model_qint8_quantized.xml
new file mode 100644
index 0000000000000000000000000000000000000000..78d53d70a8791fc4bd5a3713664e43d6fca13c5c
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/openvino/openvino_model_qint8_quantized.xml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:930bc2a849d48941bb4752d8dac018f0c0ee8709ba023e47aeab4f8bb9c25b59
+size 741875
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/pytorch_model.bin b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/pytorch_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..30f684bec488ea3bf1b2c7816f997243b6e86434
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/pytorch_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8fd120b1a0032e70ff3d4b8ab8e46a6d01c2cb08ffe7c007a021c1788928146
+size 438011953
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/sentence_bert_config.json b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/sentence_bert_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..337ee72d2d0466155ca9b398771efb2bd77b7f4a
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/sentence_bert_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cabfacded9272091a06ff595a46ef027a76ddf4ac9e77d0fcf11c605748f1667
+size 53
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/special_tokens_map.json b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/special_tokens_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..5596be0bc1a599cd864e701a913cafce06f1ccd6
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/special_tokens_map.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ef40e9c160511bf3f46ceb71f1471dafa1e9473d5120bb816c36b2efa75f8ba
+size 239
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/tokenizer.json b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/tokenizer.json
new file mode 100644
index 0000000000000000000000000000000000000000..0351f0921b98d81a82773453af9fa8516d5f0b4d
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/tokenizer.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8be2c30ba5dd723a6d5ee26d013da103d5408d92ddcb23747622f9e48f1d842
+size 466021
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/tokenizer_config.json b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..4d38726da50aff9696af73cc9ee6e302d795d8f3
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/tokenizer_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:67f2ff7e223518e729869bb3a70f0caf8368fe549383fc11cfe2dfb42fffc268
+size 363
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/train_script.py b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/train_script.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ac8dde434b7e7919dff242b49154d1af9d1b620
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/train_script.py
@@ -0,0 +1,344 @@
+"""
+Train script for a single file
+
+Need to set the TPU address first:
+export XRT_TPU_CONFIG="localservice;0;localhost:51011"
+"""
+
+import torch.multiprocessing as mp
+import threading
+import time
+import random
+import sys
+import argparse
+import gzip
+import json
+import logging
+import tqdm
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+import torch
+import torch_xla
+import torch_xla.core
+import torch_xla.core.functions
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.distributed.parallel_loader as pl
+import os
+from shutil import copyfile
+
+
+from transformers import (
+ AdamW,
+ AutoModel,
+ AutoTokenizer,
+ get_linear_schedule_with_warmup,
+ set_seed,
+)
+
+class AutoModelForSentenceEmbedding(nn.Module):
+ def __init__(self, model_name, tokenizer, normalize=True):
+ super(AutoModelForSentenceEmbedding, self).__init__()
+
+ self.model = AutoModel.from_pretrained(model_name)
+ self.normalize = normalize
+ self.tokenizer = tokenizer
+
+ def forward(self, **kwargs):
+ model_output = self.model(**kwargs)
+ embeddings = self.mean_pooling(model_output, kwargs['attention_mask'])
+ if self.normalize:
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+ return embeddings
+
+ def mean_pooling(self, model_output, attention_mask):
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+ def save_pretrained(self, output_path):
+ if xm.is_master_ordinal():
+ self.tokenizer.save_pretrained(output_path)
+ self.model.config.save_pretrained(output_path)
+
+ xm.save(self.model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
+
+
+
+
+def train_function(index, args, queue):
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
+ model = AutoModelForSentenceEmbedding(args.model, tokenizer)
+
+
+ ### Train Loop
+ device = xm.xla_device()
+ model = model.to(device)
+
+ # Instantiate optimizer
+ optimizer = AdamW(params=model.parameters(), lr=2e-5, correct_bias=True)
+
+ lr_scheduler = get_linear_schedule_with_warmup(
+ optimizer=optimizer,
+ num_warmup_steps=500,
+ num_training_steps=args.steps,
+ )
+
+ # Now we train the model
+ cross_entropy_loss = nn.CrossEntropyLoss()
+ max_grad_norm = 1
+
+ model.train()
+
+ for global_step in tqdm.trange(args.steps, disable=not xm.is_master_ordinal()):
+ #### Get the batch data
+ batch = queue.get()
+ #print(index, "batch {}x{}".format(len(batch), ",".join([str(len(b)) for b in batch])))
+
+
+ if len(batch[0]) == 2: #(anchor, positive)
+ text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
+ text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
+
+ ### Compute embeddings
+ embeddings_a = model(**text1.to(device))
+ embeddings_b = model(**text2.to(device))
+
+ ### Gather all embedings
+ embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
+ embeddings_b = torch_xla.core.functions.all_gather(embeddings_b)
+
+ ### Compute similarity scores 512 x 512
+ scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
+
+ ### Compute cross-entropy loss
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
+
+ ## Symmetric loss as in CLIP
+ loss = (cross_entropy_loss(scores, labels) + cross_entropy_loss(scores.transpose(0, 1), labels)) / 2
+
+ else: #(anchor, positive, negative)
+ text1 = tokenizer([b[0] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
+ text2 = tokenizer([b[1] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
+ text3 = tokenizer([b[2] for b in batch], return_tensors="pt", max_length=args.max_length, truncation=True, padding="max_length")
+
+ embeddings_a = model(**text1.to(device))
+ embeddings_b1 = model(**text2.to(device))
+ embeddings_b2 = model(**text3.to(device))
+
+ embeddings_a = torch_xla.core.functions.all_gather(embeddings_a)
+ embeddings_b1 = torch_xla.core.functions.all_gather(embeddings_b1)
+ embeddings_b2 = torch_xla.core.functions.all_gather(embeddings_b2)
+
+ embeddings_b = torch.cat([embeddings_b1, embeddings_b2])
+
+ ### Compute similarity scores 512 x 1024
+ scores = torch.mm(embeddings_a, embeddings_b.transpose(0, 1)) * args.scale
+
+ ### Compute cross-entropy loss
+ labels = torch.tensor(range(len(scores)), dtype=torch.long, device=embeddings_a.device) # Example a[i] should match with b[i]
+
+ ## One-way loss
+ loss = cross_entropy_loss(scores, labels)
+
+
+ # Backward pass
+ optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
+
+ xm.optimizer_step(optimizer, barrier=True)
+ lr_scheduler.step()
+
+
+ #Save model
+ if (global_step+1) % args.save_steps == 0:
+ output_path = os.path.join(args.output, str(global_step+1))
+ xm.master_print("save model: "+output_path)
+ model.save_pretrained(output_path)
+
+
+ output_path = os.path.join(args.output, "final")
+ xm.master_print("save model final: "+ output_path)
+ model.save_pretrained(output_path)
+
+
+def produce_data(args, queue, filepaths, dataset_indices):
+ global_batch_size = args.batch_size*args.nprocs #Global batch size
+ size_per_dataset = int(global_batch_size / args.datasets_per_batch) #How many datasets per batch
+ num_same_dataset = int(size_per_dataset / args.batch_size)
+ print("producer", "global_batch_size", global_batch_size)
+ print("producer", "size_per_dataset", size_per_dataset)
+ print("producer", "num_same_dataset", num_same_dataset)
+
+ datasets = []
+ for filepath in filepaths:
+ if "reddit_" in filepath: #Special dataset class for Reddit files
+ data_obj = RedditDataset(filepath)
+ else:
+ data_obj = Dataset(filepath)
+ datasets.append(iter(data_obj))
+
+ # Store if dataset is in a 2 col or 3 col format
+ num_cols = {idx: len(next(dataset)) for idx, dataset in enumerate(datasets)}
+
+ while True:
+ texts_in_batch = set()
+ batch_format = None #2 vs 3 col format for this batch
+
+ #Add data from several sub datasets
+ for _ in range(args.datasets_per_batch):
+ valid_dataset = False #Check that datasets have the same 2/3 col format
+ while not valid_dataset:
+ data_idx = random.choice(dataset_indices)
+ if batch_format is None:
+ batch_format = num_cols[data_idx]
+ valid_dataset = True
+ else: #Check that this dataset has the same format
+ valid_dataset = (batch_format == num_cols[data_idx])
+
+ #Get data from this dataset
+ dataset = datasets[data_idx]
+ for _ in range(num_same_dataset):
+ for _ in range(args.nprocs):
+ batch_device = [] #A batch for one device
+ while len(batch_device) < args.batch_size:
+ sample = next(dataset)
+ in_batch = False
+ for text in sample:
+ if text in texts_in_batch:
+ in_batch = True
+ break
+
+ if not in_batch:
+ for text in sample:
+ texts_in_batch.add(text)
+ batch_device.append(sample)
+
+ queue.put(batch_device)
+
+
+class RedditDataset:
+ """
+ A class that handles the reddit data files
+ """
+ def __init__(self, filepath):
+ self.filepath = filepath
+
+ def __iter__(self):
+ while True:
+ with gzip.open(self.filepath, "rt") as fIn:
+ for line in fIn:
+ data = json.loads(line)
+
+ if "response" in data and "context" in data:
+ yield [data["response"], data["context"]]
+
+class Dataset:
+ """
+ A class that handles one dataset
+ """
+ def __init__(self, filepath):
+ self.filepath = filepath
+
+ def __iter__(self):
+ max_dataset_size = 10*1000*1000 #Cache small datasets in memory
+ dataset = []
+ data_format = None
+
+ while dataset is None or len(dataset) == 0:
+ with gzip.open(self.filepath, "rt") as fIn:
+ for line in fIn:
+ data = json.loads(line)
+ if isinstance(data, dict):
+ data = data['texts']
+
+ if data_format is None:
+ data_format = len(data)
+
+ #Ensure that all entries are of the same 2/3 col format
+ assert len(data) == data_format
+
+ if dataset is not None:
+ dataset.append(data)
+ if len(dataset) >= max_dataset_size:
+ dataset = None
+
+ yield data
+
+ # Data loaded. Now stream to the queue
+ # Shuffle for each epoch
+ while True:
+ random.shuffle(dataset)
+ for data in dataset:
+ yield data
+
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model', default='nreimers/MiniLM-L6-H384-uncased')
+ parser.add_argument('--steps', type=int, default=2000)
+ parser.add_argument('--save_steps', type=int, default=10000)
+ parser.add_argument('--batch_size', type=int, default=64)
+ parser.add_argument('--max_length', type=int, default=128)
+ parser.add_argument('--nprocs', type=int, default=8)
+ parser.add_argument('--datasets_per_batch', type=int, default=2, help="Number of datasets per batch")
+ parser.add_argument('--scale', type=float, default=20, help="Use 20 for cossim, and 1 when you work with unnormalized embeddings with dot product")
+ parser.add_argument('--data_folder', default="/data", help="Folder with your dataset files")
+ parser.add_argument('data_config', help="A data_config.json file")
+ parser.add_argument('output')
+ args = parser.parse_args()
+
+ # Ensure global batch size is divisble by data_sample_size
+ assert (args.batch_size*args.nprocs) % args.datasets_per_batch == 0
+
+ logging.info("Output: "+args.output)
+ if os.path.exists(args.output):
+ print("Output folder already exists.")
+ input("Continue?")
+
+ # Write train script to output path
+ os.makedirs(args.output, exist_ok=True)
+
+ data_config_path = os.path.join(args.output, 'data_config.json')
+ copyfile(args.data_config, data_config_path)
+
+ train_script_path = os.path.join(args.output, 'train_script.py')
+ copyfile(__file__, train_script_path)
+ with open(train_script_path, 'a') as fOut:
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
+
+
+
+ #Load data config
+ with open(args.data_config) as fIn:
+ data_config = json.load(fIn)
+
+ queue = mp.Queue(maxsize=100*args.nprocs)
+
+ filepaths = []
+ dataset_indices = []
+ for idx, data in enumerate(data_config):
+ filepaths.append(os.path.join(os.path.expanduser(args.data_folder), data['name']))
+ dataset_indices.extend([idx]*data['weight'])
+
+ # Start producer
+ p = mp.Process(target=produce_data, args=(args, queue, filepaths, dataset_indices))
+ p.start()
+
+ # Run training
+ print("Start processes:", args.nprocs)
+ xmp.spawn(train_function, args=(args, queue), nprocs=args.nprocs, start_method='fork')
+ print("Training done")
+ print("It might be that not all processes exit automatically. In that case you must manually kill this process.")
+ print("With 'pkill python' you can kill all remaining python processes")
+ p.kill()
+ exit()
+
+
+
+# Script was called via:
+#python train_many_data_files_v2.py --steps 1000000 --batch_size 64 --model microsoft/mpnet-base train_data_configs/all_datasets_v4.json output/all_datasets_v4_mpnet-base
\ No newline at end of file
diff --git a/pretrained_weight/eval_model_weight/all-mpnet-base-v2/vocab.txt b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/vocab.txt
new file mode 100644
index 0000000000000000000000000000000000000000..660b0e95c930c7834bdd971d42cc4cee7de09586
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/all-mpnet-base-v2/vocab.txt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dbd90cb94e2247bd4d4ccaecbf616d2290e66691d7d5e5bb81f063c2d0649ada
+size 231536
diff --git a/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/.gitattributes b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..4201df72315fb56981be878f531bb178f173541d
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/.gitattributes
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:989510633921a32875e2290fefbac3b82b1803a98c40d489ee3e7d7c47ae7d21
+size 736
diff --git a/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/README.md b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..378fb7869177b643993ed1896dd0c065877dd9be
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/README.md
@@ -0,0 +1,168 @@
+---
+
+tags:
+- feature-extraction
+
+---
+# Model Card for sup-simcse-roberta-large
+
+
+# Model Details
+
+## Model Description
+
+
+
+- **Developed by:** Princeton-nlp
+- **Shared by [Optional]:** More information needed
+- **Model type:** Feature Extraction
+- **Language(s) (NLP):** More information needed
+- **License:** More information needed
+- **Related Models:**
+ - **Parent Model:** RoBERTa-large
+- **Resources for more information:**
+ - [GitHub Repo](https://github.com/princeton-nlp/SimCSE)
+ - [Associated Paper](https://arxiv.org/abs/2104.08821)
+ - [Blog Post]({0})
+
+# Uses
+
+
+## Direct Use
+
+This model can be used for the task of Feature Extraction
+
+## Downstream Use [Optional]
+
+More information needed
+
+## Out-of-Scope Use
+
+The model should not be used to intentionally create hostile or alienating environments for people.
+
+# Bias, Risks, and Limitations
+
+Significant research has explored bias and fairness issues with language models (see, e.g., [Sheng et al. (2021)](https://aclanthology.org/2021.acl-long.330.pdf) and [Bender et al. (2021)](https://dl.acm.org/doi/pdf/10.1145/3442188.3445922)). Predictions generated by the model may include disturbing and harmful stereotypes across protected classes; identity characteristics; and sensitive, social, and occupational groups.
+
+
+## Recommendations
+
+Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+
+# Training Details
+
+## Training Data
+The model craters note in the [Github Repository](https://github.com/princeton-nlp/SimCSE/blob/main/README.md)
+> We train unsupervised SimCSE on 106 randomly sampled sentences from English Wikipedia, and train supervised SimCSE on the combination of MNLI and SNLI datasets (314k).
+
+## Training Procedure
+
+
+### Preprocessing
+
+More information needed
+
+### Speeds, Sizes, Times
+
+More information needed
+
+# Evaluation
+
+
+## Testing Data, Factors & Metrics
+
+### Testing Data
+
+ The model craters note in the [associated paper](https://arxiv.org/pdf/2104.08821.pdf)
+> Our evaluation code for sentence embeddings is based on a modified version of [SentEval](https://github.com/facebookresearch/SentEval). It evaluates sentence embeddings on semantic textual similarity (STS) tasks and downstream transfer tasks. For STS tasks, our evaluation takes the "all" setting, and report Spearman's correlation. See [associated paper](https://arxiv.org/pdf/2104.08821.pdf) (Appendix B) for evaluation details.
+
+### Factors
+
+
+### Metrics
+
+More information needed
+## Results
+
+More information needed
+
+# Model Examination
+
+More information needed
+
+# Environmental Impact
+
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** More information needed
+- **Hours used:** More information needed
+- **Cloud Provider:** More information needed
+- **Compute Region:** More information needed
+- **Carbon Emitted:** More information needed
+
+# Technical Specifications [optional]
+
+## Model Architecture and Objective
+
+More information needed
+
+## Compute Infrastructure
+
+More information needed
+
+### Hardware
+
+More information needed
+
+### Software
+More information needed
+
+# Citation
+
+
+**BibTeX:**
+
+ ```bibtex
+@inproceedings{gao2021simcse,
+ title={{SimCSE}: Simple Contrastive Learning of Sentence Embeddings},
+ author={Gao, Tianyu and Yao, Xingcheng and Chen, Danqi},
+ booktitle={Empirical Methods in Natural Language Processing (EMNLP)},
+ year={2021}
+}
+
+```
+
+
+# Glossary [optional]
+More information needed
+
+# More Information [optional]
+
+If you have any questions related to the code or the paper, feel free to email Tianyu (`tianyug@cs.princeton.edu`) and Xingcheng (`yxc18@mails.tsinghua.edu.cn`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please try to specify the problem with details so we can help you better and quicker!
+# Model Card Authors [optional]
+
+
+Princeton NLP group in collaboration with Ezi Ozoani and the Hugging Face team
+
+# Model Card Contact
+
+More information needed
+
+# How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+
+ Click to expand
+
+```python
+from transformers import AutoTokenizer, AutoModel
+
+tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-roberta-large")
+
+model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-roberta-large")
+
+```
+
diff --git a/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/config.json b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..92521a0efd16b2326c1a55e7baf7e82ce573cf06
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b23fc051229f885bcb1d5361cd6e8b516bfccc322067c803d7c886de206ac8c
+size 664
diff --git a/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/flax_model.msgpack b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/flax_model.msgpack
new file mode 100644
index 0000000000000000000000000000000000000000..b71e9284cc83609d737a48e30506f21807c93622
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/flax_model.msgpack
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:33223b2f3c2fcee6351cb0027bc3451e3db1b4e8f16fef29f245592c381bdf2e
+size 1421452955
diff --git a/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/merges.txt b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/merges.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8e987df11095d23a649a1049f122df5907ed8076
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/merges.txt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe36cab26d4f4421ed725e10a2e9ddb7f799449c603a96e7f29b5a3c82a95862
+size 456356
diff --git a/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/pytorch_model.bin b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/pytorch_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..573bdebe3dc5a691fbbbe83af4e8bc97df903048
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/pytorch_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b97bbd5aa01a5ab66f6e2d8bb96bb78aa01f81238787cfa9dc28b3f950f3da78
+size 1421571527
diff --git a/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/special_tokens_map.json b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/special_tokens_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..b1650435b544827c64d83d5c907e0cbbf1e065e4
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/special_tokens_map.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:378eb3bf733eb16e65792d7e3fda5b8a4631387ca04d2015199c4d4f22ae554d
+size 239
diff --git a/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/tokenizer_config.json b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..582bd0ff9d09d8f08f66c54a9863f0d3577f8e06
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/tokenizer_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e5fdce4c4d2f83496b184197b9ea8005c0ff88b71797a8bd764b1024bc534871
+size 256
diff --git a/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/vocab.json b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/vocab.json
new file mode 100644
index 0000000000000000000000000000000000000000..4afa3b68cfe3da8a92bade86267bfb823f32b62e
--- /dev/null
+++ b/pretrained_weight/eval_model_weight/sup-simcse-roberta-large/vocab.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed19656ea1707df69134c4af35c8ceda2cc9860bf2c3495026153a133670ab5e
+size 798293
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..8731f6d4485e394fafcb8c2d6419d63892ddf450
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0246d717202152e507ed8602942bbda9732d13a2b464d4942989518e62e70d5c
+size 1480
diff --git a/release/5M_data_seting/scripts/test/release_5M_stage_2.sh b/release/5M_data_seting/scripts/test/release_5M_stage_2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5fdb5a85884c841e87da78b2eadf13736ce28942
--- /dev/null
+++ b/release/5M_data_seting/scripts/test/release_5M_stage_2.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b91a317642ca07374e25b40dfb614ddd3acf450e325e346fc65d4d680add1c1
+size 3387
diff --git a/release/5M_data_seting/scripts/test/release_5M_stage_3_lr_1e5.sh b/release/5M_data_seting/scripts/test/release_5M_stage_3_lr_1e5.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b954a88cb7a439cd9b5fc697f03ff49f6176923c
--- /dev/null
+++ b/release/5M_data_seting/scripts/test/release_5M_stage_3_lr_1e5.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8ce461f97271bcb0c6f5ed55719e6631c790c3087487e68749dd5b060d08f672
+size 3482
diff --git a/release/5M_data_seting/scripts/train/2.sh b/release/5M_data_seting/scripts/train/2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..776d5fd2a49b3b1c551b88bdfd66a9d38c581873
--- /dev/null
+++ b/release/5M_data_seting/scripts/train/2.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b48fd8e8efb201b1b33412880ae0b381911f2224f3a54fd16ab27b655beb0935
+size 1377
diff --git a/release/5M_data_seting/scripts/train/3.sh b/release/5M_data_seting/scripts/train/3.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e2e1bed391caac1bc6f79eb159e13afcb94f9baa
--- /dev/null
+++ b/release/5M_data_seting/scripts/train/3.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:65878631a10b3844334e6d8b13f26dba9a1f11269bc8dcd770b5004a79e7748f
+size 1621
diff --git a/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/README.md b/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..21dd51d3571961e4f3ccd051d1cd12f26ad65ace
--- /dev/null
+++ b/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/README.md
@@ -0,0 +1,202 @@
+---
+library_name: peft
+base_model: ./lava-vicuna_2024_4_Phi-3-mini-4k-instruct
+---
+
+# Model Card for Model ID
+
+
+
+
+
+## Model Details
+
+### Model Description
+
+
+
+
+
+- **Developed by:** [More Information Needed]
+- **Funded by [optional]:** [More Information Needed]
+- **Shared by [optional]:** [More Information Needed]
+- **Model type:** [More Information Needed]
+- **Language(s) (NLP):** [More Information Needed]
+- **License:** [More Information Needed]
+- **Finetuned from model [optional]:** [More Information Needed]
+
+### Model Sources [optional]
+
+
+
+- **Repository:** [More Information Needed]
+- **Paper [optional]:** [More Information Needed]
+- **Demo [optional]:** [More Information Needed]
+
+## Uses
+
+
+
+### Direct Use
+
+
+
+[More Information Needed]
+
+### Downstream Use [optional]
+
+
+
+[More Information Needed]
+
+### Out-of-Scope Use
+
+
+
+[More Information Needed]
+
+## Bias, Risks, and Limitations
+
+
+
+[More Information Needed]
+
+### Recommendations
+
+
+
+Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+## How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+[More Information Needed]
+
+## Training Details
+
+### Training Data
+
+
+
+[More Information Needed]
+
+### Training Procedure
+
+
+
+#### Preprocessing [optional]
+
+[More Information Needed]
+
+
+#### Training Hyperparameters
+
+- **Training regime:** [More Information Needed]
+
+#### Speeds, Sizes, Times [optional]
+
+
+
+[More Information Needed]
+
+## Evaluation
+
+
+
+### Testing Data, Factors & Metrics
+
+#### Testing Data
+
+
+
+[More Information Needed]
+
+#### Factors
+
+
+
+[More Information Needed]
+
+#### Metrics
+
+
+
+[More Information Needed]
+
+### Results
+
+[More Information Needed]
+
+#### Summary
+
+
+
+## Model Examination [optional]
+
+
+
+[More Information Needed]
+
+## Environmental Impact
+
+
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** [More Information Needed]
+- **Hours used:** [More Information Needed]
+- **Cloud Provider:** [More Information Needed]
+- **Compute Region:** [More Information Needed]
+- **Carbon Emitted:** [More Information Needed]
+
+## Technical Specifications [optional]
+
+### Model Architecture and Objective
+
+[More Information Needed]
+
+### Compute Infrastructure
+
+[More Information Needed]
+
+#### Hardware
+
+[More Information Needed]
+
+#### Software
+
+[More Information Needed]
+
+## Citation [optional]
+
+
+
+**BibTeX:**
+
+[More Information Needed]
+
+**APA:**
+
+[More Information Needed]
+
+## Glossary [optional]
+
+
+
+[More Information Needed]
+
+## More Information [optional]
+
+[More Information Needed]
+
+## Model Card Authors [optional]
+
+[More Information Needed]
+
+## Model Card Contact
+
+[More Information Needed]
+### Framework versions
+
+- PEFT 0.11.1
\ No newline at end of file
diff --git a/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/adapter_config.json b/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/adapter_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..c5af7103894cb466246ba80734c2f2ddd5af53f1
--- /dev/null
+++ b/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/adapter_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f12367bed6b7c9b5b4aa7b907c644dd366b24069737cba2d7bb0bedbf9f70b7
+size 708
diff --git a/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/adapter_model.safetensors b/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/adapter_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..d28e349dcec861f183dbeda9949ea2499946748f
--- /dev/null
+++ b/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/adapter_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:616d06ddf4832cf1f1cb6fb913ca09412152a5c28781d57e70a67098913f1cab
+size 100697984
diff --git a/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/config.json b/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..06c2744ef561ae64bb2fe90d61db0961bd7e12d3
--- /dev/null
+++ b/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1a653de7a3096215504ef831de6e9ecb4a66e83c7b6807e68184dff3d4dbbb05
+size 1280
diff --git a/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/non_lora_trainables.bin b/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/non_lora_trainables.bin
new file mode 100644
index 0000000000000000000000000000000000000000..663b710824285282fbd68630b11e5e802ead1c2a
--- /dev/null
+++ b/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/non_lora_trainables.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85e4a3e7e3127bee3b8ce3b0667783204da685066daf999b191216da1e947f7e
+size 25179879
diff --git a/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/trainer_state.json b/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/trainer_state.json
new file mode 100644
index 0000000000000000000000000000000000000000..d66155797d229f3c64685d121b791a43fdccccdc
--- /dev/null
+++ b/release/5M_data_seting/weight/stage_2/5M_low_lr_1e5/trainer_state.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb182543e8bdb3f4ee1c4b8c2d7ead79ca4164617c7fed60417df57d91b66729
+size 88102847
diff --git a/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/README.md b/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..21dd51d3571961e4f3ccd051d1cd12f26ad65ace
--- /dev/null
+++ b/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/README.md
@@ -0,0 +1,202 @@
+---
+library_name: peft
+base_model: ./lava-vicuna_2024_4_Phi-3-mini-4k-instruct
+---
+
+# Model Card for Model ID
+
+
+
+
+
+## Model Details
+
+### Model Description
+
+
+
+
+
+- **Developed by:** [More Information Needed]
+- **Funded by [optional]:** [More Information Needed]
+- **Shared by [optional]:** [More Information Needed]
+- **Model type:** [More Information Needed]
+- **Language(s) (NLP):** [More Information Needed]
+- **License:** [More Information Needed]
+- **Finetuned from model [optional]:** [More Information Needed]
+
+### Model Sources [optional]
+
+
+
+- **Repository:** [More Information Needed]
+- **Paper [optional]:** [More Information Needed]
+- **Demo [optional]:** [More Information Needed]
+
+## Uses
+
+
+
+### Direct Use
+
+
+
+[More Information Needed]
+
+### Downstream Use [optional]
+
+
+
+[More Information Needed]
+
+### Out-of-Scope Use
+
+
+
+[More Information Needed]
+
+## Bias, Risks, and Limitations
+
+
+
+[More Information Needed]
+
+### Recommendations
+
+
+
+Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+## How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+[More Information Needed]
+
+## Training Details
+
+### Training Data
+
+
+
+[More Information Needed]
+
+### Training Procedure
+
+
+
+#### Preprocessing [optional]
+
+[More Information Needed]
+
+
+#### Training Hyperparameters
+
+- **Training regime:** [More Information Needed]
+
+#### Speeds, Sizes, Times [optional]
+
+
+
+[More Information Needed]
+
+## Evaluation
+
+
+
+### Testing Data, Factors & Metrics
+
+#### Testing Data
+
+
+
+[More Information Needed]
+
+#### Factors
+
+
+
+[More Information Needed]
+
+#### Metrics
+
+
+
+[More Information Needed]
+
+### Results
+
+[More Information Needed]
+
+#### Summary
+
+
+
+## Model Examination [optional]
+
+
+
+[More Information Needed]
+
+## Environmental Impact
+
+
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** [More Information Needed]
+- **Hours used:** [More Information Needed]
+- **Cloud Provider:** [More Information Needed]
+- **Compute Region:** [More Information Needed]
+- **Carbon Emitted:** [More Information Needed]
+
+## Technical Specifications [optional]
+
+### Model Architecture and Objective
+
+[More Information Needed]
+
+### Compute Infrastructure
+
+[More Information Needed]
+
+#### Hardware
+
+[More Information Needed]
+
+#### Software
+
+[More Information Needed]
+
+## Citation [optional]
+
+
+
+**BibTeX:**
+
+[More Information Needed]
+
+**APA:**
+
+[More Information Needed]
+
+## Glossary [optional]
+
+
+
+[More Information Needed]
+
+## More Information [optional]
+
+[More Information Needed]
+
+## Model Card Authors [optional]
+
+[More Information Needed]
+
+## Model Card Contact
+
+[More Information Needed]
+### Framework versions
+
+- PEFT 0.11.1
\ No newline at end of file
diff --git a/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/adapter_config.json b/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/adapter_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..c5af7103894cb466246ba80734c2f2ddd5af53f1
--- /dev/null
+++ b/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/adapter_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f12367bed6b7c9b5b4aa7b907c644dd366b24069737cba2d7bb0bedbf9f70b7
+size 708
diff --git a/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/adapter_model.safetensors b/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/adapter_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..f776eb29e4e4f4a6f81e432bb7d5e71229256f2f
--- /dev/null
+++ b/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/adapter_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4754c82a02ea2ecb63380d39deba3d0c2a039c8dae0107f752d5199ac2a732c
+size 100697984
diff --git a/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/config.json b/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..45cc321b4ea1d68174bc0999c3121f346c48f18d
--- /dev/null
+++ b/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1368592eb17bac832760f4f92ed7e3ab0095270b29d71f2a7221b4d6f08b1e43
+size 1280
diff --git a/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/non_lora_trainables.bin b/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/non_lora_trainables.bin
new file mode 100644
index 0000000000000000000000000000000000000000..e31184f0a8d23d3f5cce9fea92e43f8f415c911f
--- /dev/null
+++ b/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/non_lora_trainables.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8e414eb12c85d500f75e229f5b9f8b8d94989ff9451ab56037fa347b18aa5db8
+size 25971143
diff --git a/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/trainer_state.json b/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/trainer_state.json
new file mode 100644
index 0000000000000000000000000000000000000000..a106ab2857ff3d969469def30944b1f9d531f0a8
--- /dev/null
+++ b/release/5M_data_seting/weight/stage_3/use_stage_5M_data_lr_1e5/trainer_state.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:51da303959735ca7ddfeddce69df7176936f29e589048a73f29db1b717bc1e4b
+size 1865017
diff --git a/release/paper/scripts/test/release_stage_2.sh b/release/paper/scripts/test/release_stage_2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ff861f2e806574ce3322bc3e0c87617a8aef89cf
--- /dev/null
+++ b/release/paper/scripts/test/release_stage_2.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d0b237b43bb2d91745aa9f4ec1d9742f6fcf971285ae9f8fdac8f72124386e8
+size 2836
diff --git a/release/paper/scripts/test/release_stage_3.sh b/release/paper/scripts/test/release_stage_3.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ec747a63340829fc43bfa00e49a4bab4d409c8f8
--- /dev/null
+++ b/release/paper/scripts/test/release_stage_3.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3c38be7ae473f2c8456d7be5e08e77570cddcfcc74da1a9f2420b70acc5fa80
+size 2851
diff --git a/release/paper/scripts/train/1.sh b/release/paper/scripts/train/1.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0b0fe850400a4cf75b7073e6c4cc101f5f0051bb
--- /dev/null
+++ b/release/paper/scripts/train/1.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d3c954266e4452b05aae0b9334bf94d06c62749c46c5e99373619270d49bece1
+size 1166
diff --git a/release/paper/scripts/train/2.sh b/release/paper/scripts/train/2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b882ff4b6e73521eb201b658cfd81d0887cf1b91
--- /dev/null
+++ b/release/paper/scripts/train/2.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:066ecb988ef6ce9cb1c7cca513495e19f7f307ffd45cbb20bb1102d3d4678a6d
+size 1373
diff --git a/release/paper/scripts/train/3.sh b/release/paper/scripts/train/3.sh
new file mode 100644
index 0000000000000000000000000000000000000000..537f8f121aa3fe9d21d4e04df5cf43abb8b1a5ae
--- /dev/null
+++ b/release/paper/scripts/train/3.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0dc1aa903b2aafa6099bd598fcbbc2f2f4801a7ea4f857e68d476f488eb4a2b
+size 1612
diff --git a/release/paper/weight/stage_1/config.json b/release/paper/weight/stage_1/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..9d0eccb980d400c578b08ebfc41eff1a678c5f7e
--- /dev/null
+++ b/release/paper/weight/stage_1/config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e749274e786379e53e67e923557dc3763d5a4b31c7d6cfec26807a5e6d15a6f7
+size 1254
diff --git a/release/paper/weight/stage_1/mm_projector.bin b/release/paper/weight/stage_1/mm_projector.bin
new file mode 100644
index 0000000000000000000000000000000000000000..2a3010e6194ae3a4e95b29e00c029395eb859a68
--- /dev/null
+++ b/release/paper/weight/stage_1/mm_projector.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:72bcfd1c43a25b0cc81f91c20fd3bf45b0c2f87e9df3ec4a5f84049d169c47ae
+size 25179773
diff --git a/release/paper/weight/stage_1/trainer_state.json b/release/paper/weight/stage_1/trainer_state.json
new file mode 100644
index 0000000000000000000000000000000000000000..8dfa8f93092bff89f1f6cc533ef7a859178dec41
--- /dev/null
+++ b/release/paper/weight/stage_1/trainer_state.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:20dcf1d0d76140b63ea1d2638d76e7ec0cc6b232b87aa45dd4a14ac9db55a2d9
+size 10606471
diff --git a/release/paper/weight/stage_2/README.md b/release/paper/weight/stage_2/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..21dd51d3571961e4f3ccd051d1cd12f26ad65ace
--- /dev/null
+++ b/release/paper/weight/stage_2/README.md
@@ -0,0 +1,202 @@
+---
+library_name: peft
+base_model: ./lava-vicuna_2024_4_Phi-3-mini-4k-instruct
+---
+
+# Model Card for Model ID
+
+
+
+
+
+## Model Details
+
+### Model Description
+
+
+
+
+
+- **Developed by:** [More Information Needed]
+- **Funded by [optional]:** [More Information Needed]
+- **Shared by [optional]:** [More Information Needed]
+- **Model type:** [More Information Needed]
+- **Language(s) (NLP):** [More Information Needed]
+- **License:** [More Information Needed]
+- **Finetuned from model [optional]:** [More Information Needed]
+
+### Model Sources [optional]
+
+
+
+- **Repository:** [More Information Needed]
+- **Paper [optional]:** [More Information Needed]
+- **Demo [optional]:** [More Information Needed]
+
+## Uses
+
+
+
+### Direct Use
+
+
+
+[More Information Needed]
+
+### Downstream Use [optional]
+
+
+
+[More Information Needed]
+
+### Out-of-Scope Use
+
+
+
+[More Information Needed]
+
+## Bias, Risks, and Limitations
+
+
+
+[More Information Needed]
+
+### Recommendations
+
+
+
+Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+## How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+[More Information Needed]
+
+## Training Details
+
+### Training Data
+
+
+
+[More Information Needed]
+
+### Training Procedure
+
+
+
+#### Preprocessing [optional]
+
+[More Information Needed]
+
+
+#### Training Hyperparameters
+
+- **Training regime:** [More Information Needed]
+
+#### Speeds, Sizes, Times [optional]
+
+
+
+[More Information Needed]
+
+## Evaluation
+
+
+
+### Testing Data, Factors & Metrics
+
+#### Testing Data
+
+
+
+[More Information Needed]
+
+#### Factors
+
+
+
+[More Information Needed]
+
+#### Metrics
+
+
+
+[More Information Needed]
+
+### Results
+
+[More Information Needed]
+
+#### Summary
+
+
+
+## Model Examination [optional]
+
+
+
+[More Information Needed]
+
+## Environmental Impact
+
+
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** [More Information Needed]
+- **Hours used:** [More Information Needed]
+- **Cloud Provider:** [More Information Needed]
+- **Compute Region:** [More Information Needed]
+- **Carbon Emitted:** [More Information Needed]
+
+## Technical Specifications [optional]
+
+### Model Architecture and Objective
+
+[More Information Needed]
+
+### Compute Infrastructure
+
+[More Information Needed]
+
+#### Hardware
+
+[More Information Needed]
+
+#### Software
+
+[More Information Needed]
+
+## Citation [optional]
+
+
+
+**BibTeX:**
+
+[More Information Needed]
+
+**APA:**
+
+[More Information Needed]
+
+## Glossary [optional]
+
+
+
+[More Information Needed]
+
+## More Information [optional]
+
+[More Information Needed]
+
+## Model Card Authors [optional]
+
+[More Information Needed]
+
+## Model Card Contact
+
+[More Information Needed]
+### Framework versions
+
+- PEFT 0.11.1
\ No newline at end of file
diff --git a/release/paper/weight/stage_2/adapter_config.json b/release/paper/weight/stage_2/adapter_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..289818d06977ef1b9a2a41da474e6033fed45d40
--- /dev/null
+++ b/release/paper/weight/stage_2/adapter_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb6facd911ef35c5e43f51b5718d65cbfe1e71581398f64667cfda247e252aeb
+size 708
diff --git a/release/paper/weight/stage_2/adapter_model.safetensors b/release/paper/weight/stage_2/adapter_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..9dd751f206dfb51808ba3080ca3df1b91562f5f5
--- /dev/null
+++ b/release/paper/weight/stage_2/adapter_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75f3be40f841560cb1c1a4fb911a53b7d6d1a60c99713dcd44f81cad6cf30384
+size 100697984
diff --git a/release/paper/weight/stage_2/config.json b/release/paper/weight/stage_2/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..97ef2fd51e0adf01dc612a9f03d98d186d950046
--- /dev/null
+++ b/release/paper/weight/stage_2/config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7377a1bfbec5fb51ffab6cbeddcd5d8d3e6abe445c42436b30283cce80138e71
+size 1253
diff --git a/release/paper/weight/stage_2/non_lora_trainables.bin b/release/paper/weight/stage_2/non_lora_trainables.bin
new file mode 100644
index 0000000000000000000000000000000000000000..1e01f6e88fa199fdf986b1522a3ace98927c4d55
--- /dev/null
+++ b/release/paper/weight/stage_2/non_lora_trainables.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:669a448d691c9fdfbdf3821407fe2c396f7bb9e9025867c0e2195ed05f5067b2
+size 25179879
diff --git a/release/paper/weight/stage_2/trainer_state.json b/release/paper/weight/stage_2/trainer_state.json
new file mode 100644
index 0000000000000000000000000000000000000000..fa9b13031e934ce74ada395201a36744d85e13dd
--- /dev/null
+++ b/release/paper/weight/stage_2/trainer_state.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:26ca85eb72c6874cc9e0f536a8c0f513aec53aad745f41ad2842714b389353a4
+size 2561685
diff --git a/release/paper/weight/stage_3/README.md b/release/paper/weight/stage_3/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..21dd51d3571961e4f3ccd051d1cd12f26ad65ace
--- /dev/null
+++ b/release/paper/weight/stage_3/README.md
@@ -0,0 +1,202 @@
+---
+library_name: peft
+base_model: ./lava-vicuna_2024_4_Phi-3-mini-4k-instruct
+---
+
+# Model Card for Model ID
+
+
+
+
+
+## Model Details
+
+### Model Description
+
+
+
+
+
+- **Developed by:** [More Information Needed]
+- **Funded by [optional]:** [More Information Needed]
+- **Shared by [optional]:** [More Information Needed]
+- **Model type:** [More Information Needed]
+- **Language(s) (NLP):** [More Information Needed]
+- **License:** [More Information Needed]
+- **Finetuned from model [optional]:** [More Information Needed]
+
+### Model Sources [optional]
+
+
+
+- **Repository:** [More Information Needed]
+- **Paper [optional]:** [More Information Needed]
+- **Demo [optional]:** [More Information Needed]
+
+## Uses
+
+
+
+### Direct Use
+
+
+
+[More Information Needed]
+
+### Downstream Use [optional]
+
+
+
+[More Information Needed]
+
+### Out-of-Scope Use
+
+
+
+[More Information Needed]
+
+## Bias, Risks, and Limitations
+
+
+
+[More Information Needed]
+
+### Recommendations
+
+
+
+Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+## How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+[More Information Needed]
+
+## Training Details
+
+### Training Data
+
+
+
+[More Information Needed]
+
+### Training Procedure
+
+
+
+#### Preprocessing [optional]
+
+[More Information Needed]
+
+
+#### Training Hyperparameters
+
+- **Training regime:** [More Information Needed]
+
+#### Speeds, Sizes, Times [optional]
+
+
+
+[More Information Needed]
+
+## Evaluation
+
+
+
+### Testing Data, Factors & Metrics
+
+#### Testing Data
+
+
+
+[More Information Needed]
+
+#### Factors
+
+
+
+[More Information Needed]
+
+#### Metrics
+
+
+
+[More Information Needed]
+
+### Results
+
+[More Information Needed]
+
+#### Summary
+
+
+
+## Model Examination [optional]
+
+
+
+[More Information Needed]
+
+## Environmental Impact
+
+
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** [More Information Needed]
+- **Hours used:** [More Information Needed]
+- **Cloud Provider:** [More Information Needed]
+- **Compute Region:** [More Information Needed]
+- **Carbon Emitted:** [More Information Needed]
+
+## Technical Specifications [optional]
+
+### Model Architecture and Objective
+
+[More Information Needed]
+
+### Compute Infrastructure
+
+[More Information Needed]
+
+#### Hardware
+
+[More Information Needed]
+
+#### Software
+
+[More Information Needed]
+
+## Citation [optional]
+
+
+
+**BibTeX:**
+
+[More Information Needed]
+
+**APA:**
+
+[More Information Needed]
+
+## Glossary [optional]
+
+
+
+[More Information Needed]
+
+## More Information [optional]
+
+[More Information Needed]
+
+## Model Card Authors [optional]
+
+[More Information Needed]
+
+## Model Card Contact
+
+[More Information Needed]
+### Framework versions
+
+- PEFT 0.11.1
\ No newline at end of file
diff --git a/release/paper/weight/stage_3/adapter_config.json b/release/paper/weight/stage_3/adapter_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..a73d23b26e84c32ca21935d486eaea0de9c18600
--- /dev/null
+++ b/release/paper/weight/stage_3/adapter_config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf0aea7cba52e3c7f44a913d7eaec8e96eba97c5dfb135e3353c4531ccd0667d
+size 708
diff --git a/release/paper/weight/stage_3/adapter_model.safetensors b/release/paper/weight/stage_3/adapter_model.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..b928415c5e00916ecb7af3dfaca404269634ac58
--- /dev/null
+++ b/release/paper/weight/stage_3/adapter_model.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0584b6cd64a58d95deac5a09df8a23999adfd005e7e1e7b1bd7a73de5b66e7df
+size 100697984
diff --git a/release/paper/weight/stage_3/config.json b/release/paper/weight/stage_3/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..45cc321b4ea1d68174bc0999c3121f346c48f18d
--- /dev/null
+++ b/release/paper/weight/stage_3/config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1368592eb17bac832760f4f92ed7e3ab0095270b29d71f2a7221b4d6f08b1e43
+size 1280
diff --git a/release/paper/weight/stage_3/non_lora_trainables.bin b/release/paper/weight/stage_3/non_lora_trainables.bin
new file mode 100644
index 0000000000000000000000000000000000000000..5585fced732024ccaf4182a4fe94fe936a14401f
--- /dev/null
+++ b/release/paper/weight/stage_3/non_lora_trainables.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c050e3dfb710f18ee911915c30bd49ca86a092b9c0f47df35b74778d9b235798
+size 25971143
diff --git a/release/paper/weight/stage_3/trainer_state.json b/release/paper/weight/stage_3/trainer_state.json
new file mode 100644
index 0000000000000000000000000000000000000000..7848e4229409ef95c22666db2280c7258eb8cc25
--- /dev/null
+++ b/release/paper/weight/stage_3/trainer_state.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2eaad25dfe522b0db542cfea2090b5e10b13904abc8d2154f10332c924c741fb
+size 1866424
diff --git a/scripts/convert_gqa_for_eval.py b/scripts/convert_gqa_for_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d46c8b876df618faac548e9b369109d541f4f23
--- /dev/null
+++ b/scripts/convert_gqa_for_eval.py
@@ -0,0 +1,18 @@
+import os
+import json
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--src", type=str)
+parser.add_argument("--dst", type=str)
+args = parser.parse_args()
+
+all_answers = []
+for line_idx, line in enumerate(open(args.src)):
+ res = json.loads(line)
+ question_id = res['question_id']
+ text = res['text'].rstrip('.').lower()
+ all_answers.append({"questionId": question_id, "prediction": text})
+
+with open(args.dst, 'w') as f:
+ json.dump(all_answers, f)
diff --git a/scripts/convert_mmbench_for_submission.py b/scripts/convert_mmbench_for_submission.py
new file mode 100644
index 0000000000000000000000000000000000000000..27baec12f9ef48d4e3df41e15b1d2644aab4174b
--- /dev/null
+++ b/scripts/convert_mmbench_for_submission.py
@@ -0,0 +1,27 @@
+import os
+import json
+import argparse
+import pandas as pd
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--annotation-file", type=str, required=True)
+ parser.add_argument("--result-dir", type=str, required=True)
+ parser.add_argument("--upload-dir", type=str, required=True)
+ parser.add_argument("--experiment", type=str, required=True)
+
+ return parser.parse_args()
+
+if __name__ == "__main__":
+ args = get_args()
+
+ df = pd.read_table(args.annotation_file)
+
+ cur_df = df.copy()
+ cur_df = cur_df.drop(columns=['hint', 'category', 'source', 'image', 'comment', 'l2-category'])
+ cur_df.insert(6, 'prediction', None)
+ for pred in open(os.path.join(args.result_dir, f"{args.experiment}.jsonl")):
+ pred = json.loads(pred)
+ cur_df.loc[df['index'] == pred['question_id'], 'prediction'] = pred['text']
+
+ cur_df.to_excel(os.path.join(args.upload_dir, f"{args.experiment}.xlsx"), index=False, engine='openpyxl')
diff --git a/scripts/convert_mmvet_for_eval.py b/scripts/convert_mmvet_for_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f5cfb7fb7691ef3921e3e6afc6d82ec54d4c6c
--- /dev/null
+++ b/scripts/convert_mmvet_for_eval.py
@@ -0,0 +1,18 @@
+import os
+import json
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--src", type=str)
+parser.add_argument("--dst", type=str)
+args = parser.parse_args()
+
+cur_result = {}
+
+for line in open(args.src):
+ data = json.loads(line)
+ qid = data['question_id']
+ cur_result[f'v1_{qid}'] = data['text']
+
+with open(args.dst, 'w') as f:
+ json.dump(cur_result, f, indent=2)
diff --git a/scripts/convert_seed_for_submission.py b/scripts/convert_seed_for_submission.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae903e63087516bc8ae77142532196be6a85589c
--- /dev/null
+++ b/scripts/convert_seed_for_submission.py
@@ -0,0 +1,74 @@
+import os
+import json
+import argparse
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--annotation-file", type=str)
+ parser.add_argument("--result-file", type=str)
+ parser.add_argument("--result-upload-file", type=str)
+ return parser.parse_args()
+
+
+def eval_single(result_file, eval_only_type=None):
+ results = {}
+ for line in open(result_file):
+ row = json.loads(line)
+ results[row['question_id']] = row
+
+ type_counts = {}
+ correct_counts = {}
+ for question_data in data['questions']:
+ if eval_only_type is not None and question_data['data_type'] != eval_only_type: continue
+ data_type = question_data['question_type_id']
+ type_counts[data_type] = type_counts.get(data_type, 0) + 1
+ try:
+ question_id = int(question_data['question_id'])
+ except:
+ question_id = question_data['question_id']
+ if question_id not in results:
+ correct_counts[data_type] = correct_counts.get(data_type, 0)
+ continue
+ row = results[question_id]
+ if row['text'] == question_data['answer']:
+ correct_counts[data_type] = correct_counts.get(data_type, 0) + 1
+
+ total_count = 0
+ total_correct = 0
+ for data_type in sorted(type_counts.keys()):
+ accuracy = correct_counts[data_type] / type_counts[data_type] * 100
+ if eval_only_type is None:
+ print(f"{ques_type_id_to_name[data_type]}: {accuracy:.2f}%")
+
+ total_count += type_counts[data_type]
+ total_correct += correct_counts[data_type]
+
+ total_accuracy = total_correct / total_count * 100
+ if eval_only_type is None:
+ print(f"Total accuracy: {total_accuracy:.2f}%")
+ else:
+ print(f"{eval_only_type} accuracy: {total_accuracy:.2f}%")
+
+ return results
+
+if __name__ == "__main__":
+ args = get_args()
+ data = json.load(open(args.annotation_file))
+ ques_type_id_to_name = {id:n for n,id in data['question_type'].items()}
+
+ results = eval_single(args.result_file)
+ eval_single(args.result_file, eval_only_type='image')
+ eval_single(args.result_file, eval_only_type='video')
+
+ with open(args.result_upload_file, 'w') as fp:
+ for question in data['questions']:
+ qid = question['question_id']
+ if qid in results:
+ result = results[qid]
+ else:
+ result = results[int(qid)]
+ fp.write(json.dumps({
+ 'question_id': qid,
+ 'prediction': result['text']
+ }) + '\n')
diff --git a/scripts/convert_sqa_to_llava.py b/scripts/convert_sqa_to_llava.py
new file mode 100644
index 0000000000000000000000000000000000000000..26fe3002413a23b5029e540c8b338ebb14307bf6
--- /dev/null
+++ b/scripts/convert_sqa_to_llava.py
@@ -0,0 +1,88 @@
+import json
+import os
+import fire
+import re
+from convert_sqa_to_llava_base_prompt import build_prompt_chatbot
+
+
+def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"):
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
+ problems = json.load(open(os.path.join(base_dir, "problems.json")))
+
+ split_problems = build_prompt_chatbot(
+ problems, split_indices, prompt_format,
+ use_caption=False, is_test=False)
+
+ target_format = []
+ for prob_id, (input, output) in split_problems.items():
+ if input.startswith('Question: '):
+ input = input.replace('Question: ', '')
+ if output.startswith('Answer: '):
+ output = output.replace('Answer: ', '')
+
+ raw_prob_data = problems[prob_id]
+ if raw_prob_data['image'] is None:
+ target_format.append({
+ "id": prob_id,
+ "conversations": [
+ {'from': 'human', 'value': f"{input}"},
+ {'from': 'gpt', 'value': f"{output}"},
+ ],
+ })
+
+ else:
+ target_format.append({
+ "id": prob_id,
+ "image": os.path.join(prob_id, raw_prob_data['image']),
+ "conversations": [
+ {'from': 'human', 'value': f"{input}\n"},
+ {'from': 'gpt', 'value': f"{output}"},
+ ],
+ })
+
+ print(f'Number of samples: {len(target_format)}')
+
+ with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f:
+ json.dump(target_format, f, indent=2)
+
+
+def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"):
+ split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
+ problems = json.load(open(os.path.join(base_dir, "problems.json")))
+
+ split_problems = build_prompt_chatbot(
+ problems, split_indices, prompt_format,
+ use_caption=False, is_test=False)
+
+ writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w")
+ for prob_id, (input, output) in split_problems.items():
+ if input.startswith('Question: '):
+ input = input.replace('Question: ', '')
+ if output.startswith('Answer: '):
+ output = output.replace('Answer: ', '')
+
+ raw_prob_data = problems[prob_id]
+ if raw_prob_data['image'] is None:
+ data = {
+ "id": prob_id,
+ "instruction": f"{input}",
+ "output": f"{output}",
+ }
+
+ else:
+ data = {
+ "id": prob_id,
+ "image": os.path.join(prob_id, raw_prob_data['image']),
+ "instruction": f"{input}\n",
+ "output": f"{output}",
+ }
+ writer.write(json.dumps(data) + '\n')
+ writer.close()
+
+
+def main(task, **kwargs):
+ globals()[task](**kwargs)
+
+
+if __name__ == "__main__":
+ fire.Fire(main)
diff --git a/scripts/convert_sqa_to_llava_base_prompt.py b/scripts/convert_sqa_to_llava_base_prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..b327fcc29eb44d7fe68be35da25bafa0e1d6feba
--- /dev/null
+++ b/scripts/convert_sqa_to_llava_base_prompt.py
@@ -0,0 +1,334 @@
+def get_question_text(problem):
+ question = problem['question']
+ return question
+
+
+def get_context_text(problem, use_caption):
+ txt_context = problem['hint']
+ img_context = problem['caption'] if use_caption else ""
+ context = " ".join([txt_context, img_context]).strip()
+ if context == "":
+ context = "N/A"
+ return context
+
+
+def get_choice_text(probelm, options):
+ choices = probelm['choices']
+ choice_list = []
+ for i, c in enumerate(choices):
+ choice_list.append("({}) {}".format(options[i], c))
+ choice_txt = " ".join(choice_list)
+ #print(choice_txt)
+ return choice_txt
+
+
+def get_answer(problem, options):
+ return options[problem['answer']]
+
+
+def get_lecture_text(problem):
+ # \\n: GPT-3 can generate the lecture with more tokens.
+ lecture = problem['lecture'].replace("\n", "\\n")
+ return lecture
+
+
+def get_solution_text(problem):
+ # \\n: GPT-3 can generate the solution with more tokens
+ solution = problem['solution'].replace("\n", "\\n")
+ return solution
+
+
+def create_one_example_chatbot(format, question, context, choice, answer, lecture, solution, test_example=True):
+
+ input_format, output_format = format.split("-")
+
+ ## Inputs
+ if input_format == "CQM":
+ input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
+ elif input_format == "QCM":
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
+ # upper bound experiment
+ elif input_format == "QCML":
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
+ elif input_format == "QCME":
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
+ elif input_format == "QCMLE":
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
+
+ elif input_format == "QCLM":
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
+ elif input_format == "QCEM":
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
+ elif input_format == "QCLEM":
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
+
+ # Outputs
+ if test_example:
+ output = "Answer:"
+ elif output_format == 'A':
+ output = f"Answer: The answer is {answer}."
+
+ elif output_format == 'AL':
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
+ elif output_format == 'AE':
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
+ elif output_format == 'ALE':
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
+ elif output_format == 'AEL':
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
+
+ elif output_format == 'LA':
+ output = f"Answer: {lecture} The answer is {answer}."
+ elif output_format == 'EA':
+ output = f"Answer: {solution} The answer is {answer}."
+ elif output_format == 'LEA':
+ output = f"Answer: {lecture} {solution} The answer is {answer}."
+ elif output_format == 'ELA':
+ output = f"Answer: {solution} {lecture} The answer is {answer}."
+ elif output_format == 'LEPA':
+ output = ''
+ if len(lecture.strip()) > 0:
+ output += f"LECTURE: {lecture}\n"
+ if len(solution.strip()) > 0:
+ output += f"SOLUTION: {solution}\n"
+ output += '###\n'
+ output += f"ANSWER: {answer}."
+
+ input = input.replace(" ", " ").strip()
+ output = output.replace(" ", " ").strip()
+ if input.endswith("BECAUSE:"):
+ input = input.replace("BECAUSE:", "").strip()
+ if output.endswith("BECAUSE:"):
+ output = output.replace("BECAUSE:", "").strip()
+ return input, output
+
+
+def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True):
+
+ input_format, output_format = format.split("-")
+
+ ## Inputs
+ if input_format == "CQM":
+ input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
+ elif input_format == "QCM":
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
+ # upper bound experiment
+ elif input_format == "QCML":
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
+ elif input_format == "QCME":
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
+ elif input_format == "QCMLE":
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
+
+ elif input_format == "QCLM":
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
+ elif input_format == "QCEM":
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
+ elif input_format == "QCLEM":
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
+
+ # Outputs
+ if test_example:
+ output = "Answer:"
+ elif output_format == 'A':
+ output = f"Answer: The answer is {answer}."
+
+ elif output_format == 'AL':
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
+ elif output_format == 'AE':
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
+ elif output_format == 'ALE':
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
+ elif output_format == 'AEL':
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
+
+ elif output_format == 'LA':
+ output = f"Answer: {lecture} The answer is {answer}."
+ elif output_format == 'EA':
+ output = f"Answer: {solution} The answer is {answer}."
+ elif output_format == 'LEA':
+ output = f"Answer: {lecture} {solution} The answer is {answer}."
+ elif output_format == 'ELA':
+ output = f"Answer: {solution} {lecture} The answer is {answer}."
+
+ text = input + output
+ text = text.replace(" ", " ").strip()
+ if text.endswith("BECAUSE:"):
+ text = text.replace("BECAUSE:", "").strip()
+ return text
+
+
+
+def create_one_example_gpt4(format, question, context, choice, answer, lecture, solution, test_example=True):
+
+ input_format, output_format = format.split("-")
+
+ ## Inputs
+ if input_format == "CQM":
+ input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
+ elif input_format == "QCM":
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
+ # upper bound experiment
+ elif input_format == "QCML":
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
+ elif input_format == "QCME":
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
+ elif input_format == "QCMLE":
+ input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
+
+ elif input_format == "QCLM":
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
+ elif input_format == "QCEM":
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
+ elif input_format == "QCLEM":
+ input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
+
+ # Outputs
+ if test_example:
+ output = "Answer:"
+ elif output_format == 'A':
+ output = f"Answer: The answer is {answer}."
+
+ elif output_format == 'AL':
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
+ elif output_format == 'AE':
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
+ elif output_format == 'ALE':
+ output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
+ elif output_format == 'AEL':
+ output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
+
+ elif output_format == 'LA':
+ output = f"Answer: {lecture} The answer is {answer}."
+ elif output_format == 'EA':
+ output = f"Answer: {solution} The answer is {answer}."
+ elif output_format == 'LEA':
+ output = f"Answer: {lecture} {solution} The answer is {answer}."
+ elif output_format == 'ELA':
+ output = f"Answer: {solution} {lecture} The answer is {answer}."
+
+ input = input.replace(" ", " ").strip()
+ output = output.replace(" ", " ").strip()
+ if output.endswith("BECAUSE:"):
+ output = output.replace("BECAUSE:", "").strip()
+
+ user_prompt = {"role": "user", "content": f"Can you explain {input}?"}
+ assistant_prompt = {"role": "assistant", "content": f"{output}"}
+
+ return user_prompt, assistant_prompt
+
+
+def build_prompt_chatbot(problems, shot_qids, prompt_format, use_caption=False, options=["A", "B", "C", "D", "E"], is_test=False):
+ examples = {}
+
+ for qid in shot_qids:
+ question = get_question_text(problems[qid])
+ context = get_context_text(problems[qid], use_caption)
+ choice = get_choice_text(problems[qid], options)
+ answer = get_answer(problems[qid], options)
+ lecture = get_lecture_text(problems[qid]).replace('\\n', '\n')
+ solution = get_solution_text(problems[qid]).replace('\\n', '\n')
+
+ train_example = create_one_example_chatbot(prompt_format,
+ question,
+ context,
+ choice,
+ answer,
+ lecture,
+ solution,
+ test_example=is_test)
+ examples[qid] = train_example
+ return examples
+
+
+def build_prompt(problems, shot_qids, test_qid, args):
+
+ examples = []
+
+ # n-shot training examples
+ for qid in shot_qids:
+ question = get_question_text(problems[qid])
+ context = get_context_text(problems[qid], args.use_caption)
+ choice = get_choice_text(problems[qid], args.options)
+ answer = get_answer(problems[qid], args.options)
+ lecture = get_lecture_text(problems[qid])
+ solution = get_solution_text(problems[qid])
+
+ train_example = create_one_example(args.prompt_format,
+ question,
+ context,
+ choice,
+ answer,
+ lecture,
+ solution,
+ test_example=False)
+ examples.append(train_example)
+
+ # test example
+ question = get_question_text(problems[test_qid])
+ context = get_context_text(problems[test_qid], args.use_caption)
+ choice = get_choice_text(problems[test_qid], args.options)
+ answer = get_answer(problems[test_qid], args.options)
+ lecture = get_lecture_text(problems[test_qid])
+ solution = get_solution_text(problems[test_qid])
+
+ test_example = create_one_example(args.prompt_format,
+ question,
+ context,
+ choice,
+ answer,
+ lecture,
+ solution,
+ test_example=True)
+ examples.append(test_example)
+
+ # create the prompt input
+ prompt_input = '\n\n'.join(examples)
+
+ return prompt_input
+
+
+def build_prompt_gpt4(problems, shot_qids, test_qid, args):
+
+ prompt_array = [{"role": "system", "content": "You are a helpful assistant."}]
+
+ # n-shot training examples
+ for qid in shot_qids:
+ question = get_question_text(problems[qid])
+ context = get_context_text(problems[qid], args.use_caption)
+ choice = get_choice_text(problems[qid], args.options)
+ answer = get_answer(problems[qid], args.options)
+ lecture = get_lecture_text(problems[qid])
+ solution = get_solution_text(problems[qid])
+
+ user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format,
+ question,
+ context,
+ choice,
+ answer,
+ lecture,
+ solution,
+ test_example=False)
+ prompt_array.append(user_prompt)
+ prompt_array.append(assistant_prompt)
+
+ # test example
+ question = get_question_text(problems[test_qid])
+ context = get_context_text(problems[test_qid], args.use_caption)
+ choice = get_choice_text(problems[test_qid], args.options)
+ answer = get_answer(problems[test_qid], args.options)
+ lecture = get_lecture_text(problems[test_qid])
+ solution = get_solution_text(problems[test_qid])
+
+ user_prompt, assistant_prompt = create_one_example_gpt4(args.prompt_format,
+ question,
+ context,
+ choice,
+ answer,
+ lecture,
+ solution,
+ test_example=True)
+ prompt_array.append(user_prompt)
+ prompt_array.append(assistant_prompt)
+
+ return prompt_array
\ No newline at end of file
diff --git a/scripts/convert_vizwiz_for_submission.py b/scripts/convert_vizwiz_for_submission.py
new file mode 100644
index 0000000000000000000000000000000000000000..7836d19f573d30e4224f2f89a53104acf03efb91
--- /dev/null
+++ b/scripts/convert_vizwiz_for_submission.py
@@ -0,0 +1,47 @@
+import os
+import argparse
+import json
+
+from llava.eval.m4c_evaluator import EvalAIAnswerProcessor
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--annotation-file', type=str, required=True)
+ parser.add_argument('--result-file', type=str, required=True)
+ parser.add_argument('--result-upload-file', type=str, required=True)
+ return parser.parse_args()
+
+
+if __name__ == '__main__':
+
+ args = parse_args()
+
+ os.makedirs(os.path.dirname(args.result_upload_file), exist_ok=True)
+
+ results = []
+ error_line = 0
+ for line_idx, line in enumerate(open(args.result_file)):
+ try:
+ results.append(json.loads(line))
+ except:
+ error_line += 1
+ results = {x['question_id']: x['text'] for x in results}
+ test_split = [json.loads(line) for line in open(args.annotation_file)]
+ split_ids = set([x['question_id'] for x in test_split])
+
+ print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}')
+
+ all_answers = []
+
+ answer_processor = EvalAIAnswerProcessor()
+
+ for x in test_split:
+ assert x['question_id'] in results
+ all_answers.append({
+ 'image': x['image'],
+ 'answer': answer_processor(results[x['question_id']])
+ })
+
+ with open(args.result_upload_file, 'w') as f:
+ json.dump(all_answers, f)
diff --git a/scripts/convert_vqav2_for_submission.py b/scripts/convert_vqav2_for_submission.py
new file mode 100644
index 0000000000000000000000000000000000000000..05f67b33a73e17c683dbf9c09f84bacd10f285f5
--- /dev/null
+++ b/scripts/convert_vqav2_for_submission.py
@@ -0,0 +1,56 @@
+import os
+import argparse
+import json
+
+from llava.eval.m4c_evaluator import EvalAIAnswerProcessor
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dir', type=str, default="./playground/data/eval/vqav2")
+ parser.add_argument('--ckpt', type=str, required=True)
+ parser.add_argument('--split', type=str, required=True)
+ return parser.parse_args()
+
+
+if __name__ == '__main__':
+
+ args = parse_args()
+
+ src = os.path.join(args.dir, 'answers', args.split, args.ckpt, 'merge.jsonl')
+ test_split = os.path.join(args.dir, 'llava_vqav2_mscoco_test2015.jsonl')
+ dst = os.path.join(args.dir, 'answers_upload', args.split, f'{args.ckpt}.json')
+ os.makedirs(os.path.dirname(dst), exist_ok=True)
+
+ results = []
+ error_line = 0
+ for line_idx, line in enumerate(open(src)):
+ try:
+ results.append(json.loads(line))
+ except:
+ error_line += 1
+
+ results = {x['question_id']: x['text'] for x in results}
+ test_split = [json.loads(line) for line in open(test_split)]
+ split_ids = set([x['question_id'] for x in test_split])
+
+ print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}')
+
+ all_answers = []
+
+ answer_processor = EvalAIAnswerProcessor()
+
+ for x in test_split:
+ if x['question_id'] not in results:
+ all_answers.append({
+ 'question_id': x['question_id'],
+ 'answer': ''
+ })
+ else:
+ all_answers.append({
+ 'question_id': x['question_id'],
+ 'answer': answer_processor(results[x['question_id']])
+ })
+
+ with open(dst, 'w') as f:
+ json.dump(all_answers, open(dst, 'w'))
diff --git a/scripts/extract_mm_projector.py b/scripts/extract_mm_projector.py
new file mode 100644
index 0000000000000000000000000000000000000000..45be31e896e9c087093bd9bcb6d355ec6dfd11ab
--- /dev/null
+++ b/scripts/extract_mm_projector.py
@@ -0,0 +1,47 @@
+"""
+This is just a utility that I use to extract the projector for quantized models.
+It is NOT necessary at all to train, or run inference/serve demos.
+Use this script ONLY if you fully understand its implications.
+"""
+
+
+import os
+import argparse
+import torch
+import json
+from collections import defaultdict
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Extract MMProjector weights')
+ parser.add_argument('--model-path', type=str, help='model folder')
+ parser.add_argument('--output', type=str, help='output file')
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ keys_to_match = ['mm_projector']
+ ckpt_to_key = defaultdict(list)
+ try:
+ model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json')))
+ for k, v in model_indices['weight_map'].items():
+ if any(key_match in k for key_match in keys_to_match):
+ ckpt_to_key[v].append(k)
+ except FileNotFoundError:
+ # Smaller models or model checkpoints saved by DeepSpeed.
+ v = 'pytorch_model.bin'
+ for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys():
+ if any(key_match in k for key_match in keys_to_match):
+ ckpt_to_key[v].append(k)
+
+ loaded_weights = {}
+
+ for ckpt_name, weight_keys in ckpt_to_key.items():
+ ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
+ for k in weight_keys:
+ loaded_weights[k] = ckpt[k]
+
+ torch.save(loaded_weights, args.output)
diff --git a/scripts/finetune.sh b/scripts/finetune.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f3a7f76bb031a9a37c02d1d54d9503af7fd50541
--- /dev/null
+++ b/scripts/finetune.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17b3cbade91dfd020ef064c5c000022038a6c58731d6de66770cbea32e99b7b6
+size 1642
diff --git a/scripts/finetune_full_schedule.sh b/scripts/finetune_full_schedule.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fe9ec7fa78d385c7eb744a33ccba972347d06d0e
--- /dev/null
+++ b/scripts/finetune_full_schedule.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:01702a093f32e50db3e4fe28347d6707514da1c1935777fa99a981b52bc3f4cf
+size 1643
diff --git a/scripts/finetune_lora.sh b/scripts/finetune_lora.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e5203b16b561521aaca865a7132703a1fa05342f
--- /dev/null
+++ b/scripts/finetune_lora.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:33382961235b7a526b98c153f9367f64a1405107d392f1769d2a516edd9bf46b
+size 1672
diff --git a/scripts/finetune_qlora.sh b/scripts/finetune_qlora.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5f9f56baf47a370d67b5d6e00f07936e45eaccc8
--- /dev/null
+++ b/scripts/finetune_qlora.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bab9f76a16fc03611368ccff460e1a3ce60bb88d6e4cb5f0a7e7f338c1a0c88
+size 1687
diff --git a/scripts/finetune_sqa.sh b/scripts/finetune_sqa.sh
new file mode 100644
index 0000000000000000000000000000000000000000..037d1058da724960e292c486d5619a778ee5d8eb
--- /dev/null
+++ b/scripts/finetune_sqa.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b771ef3aace67984f2b4ea5425cba738054dec6e485930beb5274e95c41723dd
+size 1346
diff --git a/scripts/merge_lora_weights.py b/scripts/merge_lora_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..b19eea8808553b6cc20cf10da47781d844d35e7a
--- /dev/null
+++ b/scripts/merge_lora_weights.py
@@ -0,0 +1,22 @@
+import argparse
+from llava.model.builder import load_pretrained_model
+from llava.mm_utils import get_model_name_from_path
+
+
+def merge_lora(args):
+ model_name = get_model_name_from_path(args.model_path)
+ tokenizer, model, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu')
+
+ model.save_pretrained(args.save_model_path)
+ tokenizer.save_pretrained(args.save_model_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, required=True)
+ parser.add_argument("--model-base", type=str, required=True)
+ parser.add_argument("--save-model-path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ merge_lora(args)
diff --git a/scripts/pretrain.sh b/scripts/pretrain.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c7b4d10a51150193fb3d0f84d073b05b659d5aca
--- /dev/null
+++ b/scripts/pretrain.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bd52a4b49c485971ad8bc586a29b4e7ca8011560277b3d1b5dedbf36b3ef987
+size 1460
diff --git a/scripts/pretrain_xformers.sh b/scripts/pretrain_xformers.sh
new file mode 100644
index 0000000000000000000000000000000000000000..55aebab929fce3128880c73bbb076c0a36a38524
--- /dev/null
+++ b/scripts/pretrain_xformers.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29b146d008589213bf12a5f19833b8b40fe233adb16e718fca5b589d5379e167
+size 1380
diff --git a/scripts/sqa_eval_batch.sh b/scripts/sqa_eval_batch.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c53435d9a06d00ef8e14ed92d504890b6d762df8
--- /dev/null
+++ b/scripts/sqa_eval_batch.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:35dcef6688e0f3922f5569a852d933cd863860fb0161d0937b62415684e9e972
+size 524
diff --git a/scripts/sqa_eval_gather.sh b/scripts/sqa_eval_gather.sh
new file mode 100644
index 0000000000000000000000000000000000000000..501c8d31f1bbd4e2201f9193058af540f8e7f2d7
--- /dev/null
+++ b/scripts/sqa_eval_gather.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9a504f39618fa0fe7b0d3dcc0f93002fbf977a4016d1a8185006b0e9a8e824e6
+size 518
diff --git a/scripts/upload_pypi.sh b/scripts/upload_pypi.sh
new file mode 100755
index 0000000000000000000000000000000000000000..c49d08f5a5c251b43be9fb8d75d095ca04580861
--- /dev/null
+++ b/scripts/upload_pypi.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:22eb41f462a97a1249e2c65bd5bb7ff5ba684d3d57c7a922dc5ef2d9f7b3e9cb
+size 387
diff --git a/scripts/v1_5/eval/gqa.sh b/scripts/v1_5/eval/gqa.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6b462c352da23ea60f463dcfbadff657ed47c5fd
--- /dev/null
+++ b/scripts/v1_5/eval/gqa.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b03e13e153b8e9ed310f664cf357306c056c43abd51fa5bcd947ccc909e5b61
+size 1228
diff --git a/scripts/v1_5/eval/llavabench.sh b/scripts/v1_5/eval/llavabench.sh
new file mode 100644
index 0000000000000000000000000000000000000000..66211ba963bab3d5a216e977a15c1ff35c4221be
--- /dev/null
+++ b/scripts/v1_5/eval/llavabench.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fd6ad2eb6ba97bdd1db2d59912f26aa05cdc3dedfdbe812193b520b627a7632a
+size 1093
diff --git a/scripts/v1_5/eval/mmbench.sh b/scripts/v1_5/eval/mmbench.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fa1388246a97198d95f9182a579763a7fffc6af0
--- /dev/null
+++ b/scripts/v1_5/eval/mmbench.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7e1dcf1c3783ea16a8e58170149d002118c29d166796f9a94e74d9abb7d55cd
+size 704
diff --git a/scripts/v1_5/eval/mmbench_cn.sh b/scripts/v1_5/eval/mmbench_cn.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0db93ec8a67e5298cf567747cf27a0b4eb1a7183
--- /dev/null
+++ b/scripts/v1_5/eval/mmbench_cn.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8427091301bf56195bec6238199dfdb4fb07cd601cf179f4674a13b94124df96
+size 738
diff --git a/scripts/v1_5/eval/mme.sh b/scripts/v1_5/eval/mme.sh
new file mode 100644
index 0000000000000000000000000000000000000000..829f82466efea43104440f3917a3a51d95dd9f6d
--- /dev/null
+++ b/scripts/v1_5/eval/mme.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02250d26f33bd09c6a9c15e1777b1e5172a285d00eb613af5c2cd80aaa0d1966
+size 532
diff --git a/scripts/v1_5/eval/mmvet.sh b/scripts/v1_5/eval/mmvet.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e2c3c1da4619eade4e250858dcf1ce41fc9e052a
--- /dev/null
+++ b/scripts/v1_5/eval/mmvet.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a8f0dbf4a1aa417d9ba662e751d4f69e55ff758c56aea66026b5fa2483d81099
+size 580
diff --git a/scripts/v1_5/eval/pope.sh b/scripts/v1_5/eval/pope.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ff9a231518aaaf80cf03dda21536b4551cb38d6a
--- /dev/null
+++ b/scripts/v1_5/eval/pope.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f1260968ec7997d184665321008609b6f4286d3ae106c7b8abe585c201584f7
+size 590
diff --git a/scripts/v1_5/eval/qbench.sh b/scripts/v1_5/eval/qbench.sh
new file mode 100644
index 0000000000000000000000000000000000000000..991fdff2cb06f5bba01f1174bd84807719cc0f8d
--- /dev/null
+++ b/scripts/v1_5/eval/qbench.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45d688ca1fc92d1721f91aa903854f41850b8ae388c49fe11d49c66446f983e2
+size 578
diff --git a/scripts/v1_5/eval/qbench_zh.sh b/scripts/v1_5/eval/qbench_zh.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d9a70f9dce6be5cac9246e0534b1bc894b450ed5
--- /dev/null
+++ b/scripts/v1_5/eval/qbench_zh.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45f1ae8699c5a38fdde87df34b5caa8d5e772d709947781a34c153327c79888e
+size 641
diff --git a/scripts/v1_5/eval/seed.sh b/scripts/v1_5/eval/seed.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f16ed636e4d7c9781d155df4d4f02ec9f2dce7c9
--- /dev/null
+++ b/scripts/v1_5/eval/seed.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7056c730f42533fee364f757430e18f31b0532e7890331dc68f8f7ea91f2baa1
+size 1264
diff --git a/scripts/v1_5/eval/sqa.sh b/scripts/v1_5/eval/sqa.sh
new file mode 100644
index 0000000000000000000000000000000000000000..5d0055769ca226df32e46cd0471f454207f8726e
--- /dev/null
+++ b/scripts/v1_5/eval/sqa.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b0350dd4cfc49c602fa649c41ce7a4a878f238f11a8e68c02516bb00fcb2899
+size 749
diff --git a/scripts/v1_5/eval/textvqa.sh b/scripts/v1_5/eval/textvqa.sh
new file mode 100644
index 0000000000000000000000000000000000000000..52f8239f023e38628c419015355ac30e431e0d55
--- /dev/null
+++ b/scripts/v1_5/eval/textvqa.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f38445df60738977d67b1e3c48a7fc467dd1c6009aa094640ddaf403cdf0bccb
+size 571
diff --git a/scripts/v1_5/eval/vizwiz.sh b/scripts/v1_5/eval/vizwiz.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8adef4d9f40d8fe8a748cc3a41e255608551b973
--- /dev/null
+++ b/scripts/v1_5/eval/vizwiz.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c50e1e835238d67ebf8e433f896cdb450dd32161472f1a7cf222c26e5b4757d
+size 642
diff --git a/scripts/v1_5/eval/vqav2.sh b/scripts/v1_5/eval/vqav2.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8f4990683b8ec00ee7335091c5906d7ec2398185
--- /dev/null
+++ b/scripts/v1_5/eval/vqav2.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89c8705edee2b9be132b94a29f36f6d5114fab507de5dca1d621910d4d28ad7f
+size 1113
diff --git a/scripts/v1_5/finetune.sh b/scripts/v1_5/finetune.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d782a6dd5c9dbfd180aa5d9d7cae7b448a72f34c
--- /dev/null
+++ b/scripts/v1_5/finetune.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aef5d0f6737ad76143205ec297f441367c2a4c8d876632fe08d60d58e619e320
+size 1234
diff --git a/scripts/v1_5/finetune_lora.sh b/scripts/v1_5/finetune_lora.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c05d824598c36e2fa6170e4724e5d84316571c7f
--- /dev/null
+++ b/scripts/v1_5/finetune_lora.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:208bdd804d856674ee580897b7085c2233eb04a35e785f77193aff1984527bd0
+size 1317
diff --git a/scripts/v1_5/finetune_task.sh b/scripts/v1_5/finetune_task.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7f787581909592f78e2c2776e3a8ecbd274de2d5
--- /dev/null
+++ b/scripts/v1_5/finetune_task.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4637ae27c5e4f1d62216fc9df82d806b6ecb300d14576a32c98c73e58464c5e5
+size 1156
diff --git a/scripts/v1_5/finetune_task_lora.sh b/scripts/v1_5/finetune_task_lora.sh
new file mode 100644
index 0000000000000000000000000000000000000000..124f83408653432b001d544edf6a919699bd3fae
--- /dev/null
+++ b/scripts/v1_5/finetune_task_lora.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b335f8c77e293c2b943402496c638af0932043791849714c9de761ce72e0233c
+size 1239
diff --git a/scripts/v1_5/pretrain.sh b/scripts/v1_5/pretrain.sh
new file mode 100644
index 0000000000000000000000000000000000000000..ed39afa9140d6fc41783e1c8e03dc50183385321
--- /dev/null
+++ b/scripts/v1_5/pretrain.sh
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7282a88558b5b9f87a8c8602991f33f151ec5f23b3abeadd08b950047e99ad0
+size 1164
diff --git a/scripts/zero2.json b/scripts/zero2.json
new file mode 100644
index 0000000000000000000000000000000000000000..7438c382833573ceab65a346d534b5dbdd537d8d
--- /dev/null
+++ b/scripts/zero2.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bb7d7a3360f8181c132e8666e6045cb75fa8a1a225d2b79148c1136fe99eaaa4
+size 556
diff --git a/scripts/zero3.json b/scripts/zero3.json
new file mode 100644
index 0000000000000000000000000000000000000000..eab316d14ba604a50f13fd19502a29977de6a0d4
--- /dev/null
+++ b/scripts/zero3.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d25d1f104be517b09152d6398a7352a8db4956f956565ec2ed8af58feba7a81
+size 801
diff --git a/scripts/zero3_offload.json b/scripts/zero3_offload.json
new file mode 100644
index 0000000000000000000000000000000000000000..962b166c8aa2ec173ac010628d56eb164d7310d4
--- /dev/null
+++ b/scripts/zero3_offload.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75da1bf2bb9220e7ab1a18a4b6a1ca2bb210835d373978aa94ebc6600b5e007c
+size 1279