Survey of Fine Tuning Techniques
Large Language Models are the buzz everywhere. There are hundreds of blog posts and how-to guides. I started this series to record my own understanding.
Large models (language, vision, or image) are trained on vast amounts of data. In a sense, they record generic "world" knowledge, which makes them great for a chat interface or generic tasks. However, these models do not perform well on specialized tasks. As an example, if I am a paralegal wanting to look up US court cases where there was a land dispute between neighbors over a tree in Seattle, ChatGPT cannot provide this answer.
Assuming I have all the prior court transcripts since 1776, there are two choices for me:
Train a new model
Fine-tune an existing base model
Training a new model is possible - it may not be "large", depending on the dataset. It might not have the same chat-based interface either, since it has no language data outside of the training set. This would be an expensive process, requiring the same rigor that Meta, OpenAI, and Google applied to generate their base models.
Fine-tuning an existing model is what we will cover in this article. We will look at fine-tuning approaches that are suitable for most use cases and have a lower barrier to entry.
Contents
… Retraining LLMs
…….. Fine tuning entire LLMs
…….. Lightweight fine-tuning (Adapter Tuning)
… In-Context Learning (Prompting) & RAG
…….. Zero Shot prompting
…….. Few Shot prompting
…….. Chain of Thought (CoT) prompting
…….. Retrieval Augmented Generation (RAG)
… Parameter Efficient Fine Tuning (PEFT)
…….. Prompt Tuning
……..…….. Hard Prompting
……..…….. Soft Prompting
…….. Prefix-tuning
…….. P-tuning
… Reparameterization
…….. Low-Rank Adaptation (LoRA) of LLMs
…….. Quantized LoRA (QLoRA)
Retraining LLMs
The most foundational way to fine-tune LLMs is to retrain the layers with domain-specific data. Following the approaches from Sebastian's excellent entry, training can target specific layers or all layers.
Fine tuning entire LLMs
This is covered as Finetuning II in the above diagram. It involves training a whole model per task, updating the weights of all layers in the LM. Some paper references are below:
For NLG
fine-tune masked language models (e.g., BERT; Devlin et al., 2019) and encoder-decoder models (e.g., BART; Lewis et al., 2020) respectively (Zhong et al., 2020; Liu and Lapata, 2019; Raffel et al., 2020)
For machine translation
(Zhang et al., 2020c; Stickland et al., 2020; Zhu et al., 2020; Liu et al., 2020).
Lightweight fine-tuning (Adapter Tuning)
Lightweight fine-tuning freezes most of the pre-trained parameters and augments the pre-trained model with small trainable modules. For example, we can use the frozen LM as a feature extractor and train a logistic regression classifier on the extracted features (#1 in the figure above).
Another example is to update only the output layers, keeping all other layers frozen (#2 in the figure above).
There are many more approaches, depending on the use case.
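As an illustration of the first example above, here is a minimal sketch that uses a frozen BERT model as a feature extractor and trains a scikit-learn logistic regression head on the extracted features. The model name and toy data are placeholders, not from the original article.

```python
# A minimal sketch: frozen BERT as a feature extractor, with a scikit-learn
# logistic regression head trained on the extracted features.
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.linear_model import LogisticRegression

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
backbone = AutoModel.from_pretrained("bert-base-uncased")
backbone.eval()  # the pre-trained LM stays frozen

texts = ["The court ruled for the plaintiff.", "The motion was denied."]  # toy data
labels = [1, 0]

with torch.no_grad():
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    hidden = backbone(**batch).last_hidden_state      # (batch, seq_len, hidden_dim)
    features = hidden[:, 0, :].numpy()                # [CLS] embedding per text

clf = LogisticRegression(max_iter=1000).fit(features, labels)
print(clf.predict(features))
```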
In-Context Learning (Prompting) & RAG
Zero Shot prompting
In zero-shot prompting, we prepend a specific instruction to the user query without providing any direct examples to the model. Below is an example of the prepended instruction.
Provide a support solution based on the user concern.
User concern:
An example usage with a user prompt would be:
Provide a support solution based on the user concern.
User concern: My lights went out. What should I do?
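A minimal sketch of how such a zero-shot prompt could be assembled in code; the helper function is illustrative, and the template wording mirrors the example above.

```python
# Illustrative helper: prepend the fixed instruction, no examples included.
INSTRUCTION = "Provide a support solution based on the user concern."

def build_zero_shot_prompt(user_concern: str) -> str:
    return f"{INSTRUCTION}\n\nUser concern: {user_concern}\nSolution:"

prompt = build_zero_shot_prompt("My lights went out. What should I do?")
print(prompt)  # this string is then sent to the LLM of your choice
```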
Few Shot prompting
In few-shot prompting, a few worked examples are prepended to the user's query. Many manually crafted prompt variations have shown better empirical results. One such example is shown above.
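For concreteness, here is a hypothetical few-shot prompt builder in the same support-ticket style as the zero-shot example; the examples are made up for illustration.

```python
# Hypothetical few-shot examples prepended before the real user concern.
FEW_SHOT_EXAMPLES = [
    ("My internet is down.", "Restart the router, then check the provider's outage page."),
    ("I forgot my account password.", "Use the 'Forgot password' link to request a reset email."),
]

def build_few_shot_prompt(user_concern: str) -> str:
    lines = ["Provide a support solution based on the user concern.", ""]
    for concern, solution in FEW_SHOT_EXAMPLES:
        lines += [f"User concern: {concern}", f"Solution: {solution}", ""]
    lines += [f"User concern: {user_concern}", "Solution:"]
    return "\n".join(lines)

print(build_few_shot_prompt("My lights went out. What should I do?"))
```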
Chain of Thought (CoT) prompting
This is a layer built over few-shot prompting. CoT prompting allows for detailed problem-solving by guiding the model through intermediate steps. The prompt gives the LLM examples of how to break its thinking into steps and cross-check itself.
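A sketch of a chain-of-thought prompt; the worked example and its arithmetic are illustrative, not taken from the article.

```python
# Illustrative chain-of-thought example: the worked answer shows intermediate steps.
COT_EXAMPLE = (
    "Q: A customer bought 3 licenses at $40 each and has a $20 credit. What do they owe?\n"
    "A: Let's think step by step. 3 licenses x $40 = $120. "
    "$120 - $20 credit = $100. The answer is $100.\n"
)

def build_cot_prompt(question: str) -> str:
    # "Let's think step by step" nudges the model to reproduce the same reasoning pattern.
    return f"{COT_EXAMPLE}\nQ: {question}\nA: Let's think step by step."

print(build_cot_prompt("A customer bought 2 licenses at $55 each with a $10 credit. What do they owe?"))
```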
Retrieval Augmented Generation (RAG)
Introduced by Meta researchers, Retrieval Augmented Generation (RAG) a) retrieves relevant data from outside the language model (non-parametric) and b) augments the prompt to the LLM with that data as context. By grounding the model on additional information, it allows for more accurate and context-aware responses.
The data can come from multiple sources depending on the use case. To be usable, the data must be small enough to fit into the LLM's query context, and there must be a way to identify relevance. It is typical to split the data into chunks and then compute an embedding for each chunk. To avoid critical context being lost between two chunks, the chunks have a small overlap with each other. Converting them to embeddings and storing them in a vector database enables a similarity search against the user query. The top-k most similar chunks are appended to the query before it is sent to the LLM.
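A minimal retrieval sketch of that pipeline, assuming the sentence-transformers package and an in-memory corpus instead of a real vector database; the model name, chunk sizes, and documents are placeholders.

```python
# Chunk -> embed -> similarity search -> build the augmented prompt.
from sentence_transformers import SentenceTransformer, util

def chunk(text: str, size: int = 500, overlap: int = 50) -> list[str]:
    # Overlapping chunks so context spanning a boundary is not lost.
    return [text[i:i + size] for i in range(0, len(text), size - overlap)]

embedder = SentenceTransformer("all-MiniLM-L6-v2")
documents = ["...court transcript 1...", "...court transcript 2..."]  # placeholder corpus
chunks = [c for doc in documents for c in chunk(doc)]
chunk_embeddings = embedder.encode(chunks, convert_to_tensor=True)

query = "land dispute between neighbors over a tree in Seattle"
query_embedding = embedder.encode(query, convert_to_tensor=True)
top_hits = util.semantic_search(query_embedding, chunk_embeddings, top_k=3)[0]

context = "\n".join(chunks[hit["corpus_id"]] for hit in top_hits)
prompt = f"Answer using only the context below.\n\nContext:\n{context}\n\nQuestion: {query}"
```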
Advantages of RAG
Output Quality – When the model makes ‘best-guess’ assumptions, essentially filling in what it ‘doesn’t know’, the output can be incorrect or pure nonsense. Compared to simple prompt engineering, RAG produces results that are more accurate, with a lower chance of hallucinations.
Data freshness – RAG can adapt in situations where facts could evolve over time, making it useful for generating responses that require up-to-date information.
Observability – Using RAG, the source of the LLM’s answer can be pinpointed. Having traceability regarding the source of an answer can be beneficial for internal monitoring, quality assurance, or addressing customer disputes.
Cost – There is no training cost, since no fine-tuning is needed
Parameter Efficient Fine Tuning (PEFT)
PEFT is a family of techniques that augment a pre-trained LLM with a small set of new parameters and fine-tune only these new parameters, keeping the pre-trained LLM frozen. With models growing bigger and bigger (BLOOM has a whopping 176 billion parameters), it is very costly to run a full fine-tuning on all layers. Furthermore, to maintain data freshness, retraining the full model takes a long time, while PEFT updates can be done in short cycles close to production.
Prompt Tuning
Prompt tuning refers to ways of modifying the query prompt to the LLM to improve modeling results. These are generally classified into four types:
Hard Prompting
Soft Prompting
Prefix-Tuning
P-Tuning
Hard Prompting
Hard prompts are similar to [[#Zero Shot prompting]], except that hard prompts modify the input tokens of the user input. It is splitting hairs - they are so similar that you wouldn't be wrong to use the terms interchangeably. Some people also call them discrete prompts.
An example template: From English to Spanish: {english} -> {spanish}
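In code, a hard prompt is just a fixed, discrete template whose slots are filled with the user's text (a toy illustration):

```python
# A hard prompt is a fixed, human-written template; only the slots change.
HARD_PROMPT_TEMPLATE = "From English to Spanish: {english} ->"

prompt = HARD_PROMPT_TEMPLATE.format(english="Where is the library?")
# "Tuning" a hard prompt means manually rewording the discrete template itself.
```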
Soft Prompting
Soft prompts (Lester et al. 2021) concatenate the embeddings of the input tokens with a trainable tensor that can be optimized via backpropagation to improve modeling performance on a target task. These embedding vectors are not drawn from the LM's embedding space (token vocabulary). Therefore, these prefixes are also referred to as "virtual tokens", since they do not share the LM's embedding space and are much more compact.
This tensor is learned during prompt training and is maintained and stored separately from the LLM. Depending on the number of tasks/features involved, a separate per-task tensor can be maintained and loaded during inference.
Soft prompts differ from discrete text prompts in that they are acquired through backpropagation and are thus adjusted based on loss feedback from a labeled dataset. They are often confused with prefix-tuning, as we will see below.
Annotated figure from https://arxiv.org/abs/2104.08691.
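Conceptually, soft prompting boils down to prepending a trainable tensor to the input embeddings. The PyTorch sketch below uses toy dimensions and random stand-ins for the LM's embeddings; it is meant to show the shape bookkeeping, not a full training loop.

```python
import torch
import torch.nn as nn

# Toy dimensions; a real LM would supply embed_dim and the token embeddings.
num_virtual_tokens, embed_dim, batch_size, seq_len = 20, 768, 2, 16

# Trainable "virtual tokens" that live outside the model's vocabulary.
soft_prompt = nn.Parameter(torch.randn(num_virtual_tokens, embed_dim))

token_embeds = torch.randn(batch_size, seq_len, embed_dim)  # stand-in for the LM's input embeddings
inputs_embeds = torch.cat(
    [soft_prompt.unsqueeze(0).expand(batch_size, -1, -1), token_embeds], dim=1
)
# inputs_embeds is fed to the frozen LM; only soft_prompt receives gradient updates.
print(inputs_embeds.shape)  # torch.Size([2, 36, 768])
```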
Prefix-tuning
Prefix-tuning was developed by Stanford researchers (Li & Liang 2021) in parallel with Google researchers' soft prompting (Lester et al. 2021). Both techniques learn task-specific virtual tokens during training. Prefix-tuning adds trainable tensors to each transformer block, while soft prompting adds them only to the input embeddings.
There are no comparison studies that I can find between soft prompting and prefix-tuning, but one can assume that since soft prompting is applied at a single layer, it is more efficient but less accurate than prefix-tuning.
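In practice, prefix-tuning is commonly applied via the Hugging Face peft library; the sketch below is illustrative (the base model and the number of virtual tokens are arbitrary choices, not from the paper).

```python
from transformers import AutoModelForCausalLM
from peft import PrefixTuningConfig, TaskType, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative base model
config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=20)

model = get_peft_model(base, config)   # base weights stay frozen
model.print_trainable_parameters()     # only the per-layer prefix tensors are trainable
```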
Pros & Cons of Soft Prompting & Prefix-tuning
Advantages
A single LM (Language Model) serves all tasks: a. no large training run, b. no additional large storage
Inference can be batched, run in parallel
Multiple tasks can be batched together in parallel, since the LM is unmodified
Can be extended to provide history (useful in use cases like chatbots; not relevant where historical context does not help, e.g. text classification)
Limitations
Models have a limited context size
The number of input tokens increases with the additional in-context tokens. Processing more tokens means longer processing times and higher inference cost
P-tuning
P-tuning is a variation of the soft prompt method. It also adds a trainable embedding tensor. There are a few differences:
Training: P-tuning uses a prompt encoder (a bidirectional LSTM) to generate the virtual tokens, which makes the learned prompts better optimized, and it therefore scores better than prefix-tuning.
Prompt tokens are added only to the input (like soft prompting), instead of to every layer of the model (as in prefix-tuning)
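A P-tuning configuration sketch using the peft library's prompt encoder; the LSTM reparameterization and sizes below are illustrative choices, not taken from the paper.

```python
from transformers import AutoModelForCausalLM
from peft import PromptEncoderConfig, TaskType, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative base model
config = PromptEncoderConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=20,
    encoder_reparameterization_type="LSTM",  # the prompt encoder generates the virtual tokens
    encoder_hidden_size=128,
)
model = get_peft_model(base, config)
model.print_trainable_parameters()  # only the prompt encoder parameters are trainable
```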
Reparameterization
Reparameterization is a group of techniques that freeze the pre-trained model weights and instead add additional parameters alongside existing layers. Here we review only LoRA and QLoRA.
Classic fine-tuning of Large Language Models typically changes most or all of the model's weights, which requires a lot of resources. LoRA- and QLoRA-based fine-tuning freezes the original weights and trains only a small number of parameters, making training much more efficient.
Low-Rank Adaptation (LoRA) of LLMs
Source - https://arxiv.org/pdf/2106.09685.pdf
LoRA is the most popular technique in use today for efficient fine-tuning, and it provides results near-comparable to full fine-tuning. It introduces two smaller matrices (called update matrices) in each transformer layer; these are trained and then merged with the frozen model weights to produce new merged weights. The original weight matrix remains frozen and doesn't receive any further adjustments. During inference, the merged weights are used in each layer.
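The idea can be written as W_merged = W + (alpha / r) * B A, where A and B are the low-rank update matrices. The toy PyTorch sketch below shows the bookkeeping; the dimensions and scaling values are illustrative.

```python
import torch

d, k, r = 512, 512, 8            # original weight shape and the low rank r
W = torch.randn(d, k)            # frozen pre-trained weight (not updated)
A = torch.randn(r, k) * 0.01     # update matrix A: small random initialization
B = torch.zeros(d, r)            # update matrix B: zeros, so training starts from W unchanged
alpha = 16                       # scaling hyperparameter

# Only A and B are trained; for inference the update is merged into the weights.
W_merged = W + (alpha / r) * (B @ A)
```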
Advantages
Original model is frozen
LoRA training (and retraining) is fast and low cost compared to full fine-tuning, since only the A & B matrices need to be updated
You can have multiple lightweight LoRA adapters, one per task. During LoRA training, the update matrices A and B are computed separately for each task/feature-set.
Performance of LoRA is comparable to full fine-tuned models
LoRA does not add any inference latency because the adapter weights are merged with the base model
In principle, LoRA can be applied to any subset of weight matrices in a neural network to reduce the number of trainable parameters. However, for simplicity and further parameter efficiency, in Transformer models LoRA is typically applied to attention blocks only. - @HuggingFace
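A typical LoRA setup with the Hugging Face peft library looks roughly like the sketch below; the base model and target_modules names are illustrative and depend on the architecture.

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative model
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8, lora_alpha=16, lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],   # attention blocks only, per the quote above
)
model = get_peft_model(base, config)
model.print_trainable_parameters()         # typically a small fraction of the total parameters
```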
Quantized LoRA (QLoRA)
Source - [[https://arxiv.org/pdf/2305.14314.pdf]]
Models store each weight as a floating-point number, typically a 32-bit float. Quantization is an optimization step that reduces the number of bits per weight. For example, a weight stored as a 4-bit NormalFloat can take only one of 16 discrete values.
QLoRA works by first quantizing the pre-trained weights to 4-bit precision. This reduces the memory footprint of the LLM, making it possible to finetune it on a single GPU. QLoRA then adds a small set of learnable Low-rank Adapter weights to the quantized LLM. These adapters are updated during finetuning by backpropagating gradients through the quantized weights.
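A QLoRA-style setup sketch combining 4-bit NF4 quantization (via bitsandbytes through transformers) with LoRA adapters; the model name and hyperparameters are illustrative.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # 4-bit NormalFloat
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",             # illustrative model
    quantization_config=bnb_config,
    device_map="auto",
)
base = prepare_model_for_kbit_training(base)

lora = LoraConfig(task_type=TaskType.CAUSAL_LM, r=16, lora_alpha=32,
                  target_modules=["q_proj", "v_proj"], lora_dropout=0.05)
model = get_peft_model(base, lora)  # gradients flow through the frozen 4-bit weights into the adapters
```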
Performance-wise, this dramatically reduces GPU memory requirements: 16-bit finetuning of a LLaMA 65B parameter model requires more than 780 GB of GPU memory, while QLoRA needs less than 48 GB.
Summary
The fine-tuning options covered above are some of the most popular ones. However, every problem domain and dataset is different. Follow best practices and use metrics instead of shooting in the dark. Some people assume that retraining all layers is the gold standard, but I have found a few cases where that is not true.