This repository is the implementation of the Large Language Model From Power Law Decoder Representations (PLDR-LLM) with KV-cache and G-cache detailed in the research article: PLDR-LLMs Learn A Generalizable Tensor Operator That Can Replace Its Own Deep Neural Net At Inference.
The Large Language Model From Power Law Decoder Representations is a deductive-inductive LLM that uses the decoder layers first developed for the Power Law Graph Transformer (PLGT). It was first introduced in the research article: PLDR-LLM: Large Language Model From Power Law Decoder Representations.
The deductive outputs of the PLDR-LLM are generated by the Power Law Graph Attention (PLGA) mechanism at each decoder layer. The PLDR-LLM takes advantage of the metric tensor, the potential tensor and the energy-curvature tensor to observe and assess the model response, and it can be regularized with a DAG loss to modify the model behaviour.
The deductive outputs of the PLDR-LLM have a unique characteristic: the model learns an invariant tensor operator that can replace the deep neural net of the PLGA section with one of the deductive outputs, GLM, at inference. This characteristic also makes it straightforward to implement KV-cache and G-cache to improve inference time. An LLM with Scaled Dot-Product Attention (SDPA), as widely used in the literature, is a special case of the PLDR-LLM with GLM equal to the identity.
The PLDR-LLM with KV-cache and G-cache support is implemented in PyTorch. For distributed training, it uses the Fully Sharded Data Parallel (FSDP) approach provided by the PyTorch framework.
The output and training procedure of the PLDR-LLM are similar to those of LLMs that use decoders with SDPA. The inductive output is the same as the transductive output of an LLM with SDPA. At inference time, it is straightforward to replace an LLM with SDPA by a PLDR-LLM.
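As a rough illustration of the GLM-as-identity special case, below is a minimal sketch (not the repository code) of per-head attention scores formed as a bilinear form between queries and keys through a learned operator G; with G set to the identity, the scores reduce to scaled dot-product attention. The function name and tensor shapes are illustrative assumptions, and the full PLGA computation in the repository involves additional deductive outputs.

import torch

# Minimal, illustrative sketch: per-head attention scores q G k^T / sqrt(dk).
# With G equal to the identity this reduces to standard scaled dot-product attention.
def attention_scores(q, k, G):
    # q, k: (batch, heads, seq_len, dk); G: (heads, dk, dk) learned tensor operator
    dk = q.shape[-1]
    scores = torch.einsum("bhqd,hde,bhke->bhqk", q, G, k) / dk**0.5
    return torch.softmax(scores, dim=-1)

q = torch.randn(1, 2, 5, 8)
k = torch.randn(1, 2, 5, 8)
G_identity = torch.eye(8).expand(2, 8, 8)   # identity operator -> SDPA behaviour
weights = attention_scores(q, k, G_identity)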
- The PLDR-LLMs that were pretrained and studied in the research paper can be found at huggingface.co/fromthesky (PLDR-LLM-v51, PLDR-LLM-v51G and PLDR-LLM-v51-DAG versions).
- A fork of the LM Evaluation Harness suite with PyTorch-based PLDR-LLM support is available at lm-evaluation-harness-with-PLDR-LLM-kvg-cache.
- Support for KV-cache and G-cache for faster inference.
- A flexible interface to build, customize and train PLDR-LLMs with deep layers of decoders by specifying hyperparameters such as number of layers, number of attention heads, embedding dimension, vocabulary size and more through a dictionary of hyperparameters.
- Generate continuation text from input text through the following random sampling techniques: temperature, top-k or top-p (nucleus) sampling.
- A train interface that keeps track of global loss/accuracy and running loss/accuracy for the training and validation datasets, as well as DAG loss values for each deductive output, within a single epoch.
- Run scripts to train PLDR-LLMs on a single GPU or multiple GPUs with Fully Sharded Data Parallelism.
- Data preparation optimized for pretraining PLDR-LLMs. The implementation is optimized for the RefinedWeb dataset used in the research paper.
- pldr_model_v510: This model type is the state-of-the-art PLDR-LLM implementation with the full PLGA network, KV-cache and G-cache. It learns the deductive outputs A, ALM, AP and GLM, along with other outputs, through custom weights and biases.
- pldr_model_v510_dag: This is pldr_model_v510 with DAG regularization support.
- ablation/pldr_model_v510G: This version was used for ablation studies in the research paper. It replaces the deep neural network (residual layers, custom weights and biases) in PLGA with a predefined GLM during initialization. For GLM defined as the identity, it reduces to an LLM with SDPA.
- ablation/pldr_model_v510Gi: This version was also used for ablation studies in the research paper. It accepts configuration for a pldr_model_v510 model and uses its weights to initialize a pldr_model_v510G model for inference. This version is not intended for training.
Predefined GLM values for ablation: Contents of ablation/predefined-G_LM can be used as input to pldr_model_v510G as described in the research paper. Similarly, contents of ablation/negative-test-G-init can be used to initialize a predefined GLM using pldr_model_v510.
A PLDR-LLM model can be trained using a sample model train script in the scripts/ folder of each model type. The script accepts parameters for data preparation, model hyperparameters, and other settings for distributed training. It runs a main Python script to start training and writes the output to a log file:
- For pldr_model_v510, it runs the following Python script. The shell script sample-train-script-pldrv510.sh has more details and sample values for the parameters.
python dist_pldr_v510_train_main.py --master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
--batch_size=$BATCH_SIZE \
--tok_model=$TOKEN_MODEL \
--context_length=$CONTEXT_LENGTH \
--train_sample_interval "${TRAIN_SAMPLE_INTERVAL[@]}" \
--val_sample_size=$VAL_SAMPLE_SIZE \
--buffer_size=$BUFFER_SIZE \
--dataset_file=$DATASET_FILE \
--dataset_column_label=$DATASET_COLUMN_LABEL \
--load_dataset \
--load_from_train \
--split_style=$SPLIT_STYLE \
--batch_agg_count=$BATCH_AGG_COUNT \
--padding_type=$PADDING_TYPE \
--trust_remote_code \
--num_layers=$NUM_LAYERS \
--num_heads=$NUM_HEADS \
--dk=$DK \
--num_reslayerA=$NUM_RESLAYERA \
--num_denseA=$NUM_DENSEA \
--Adff=$ADFF \
--epochs=$EPOCHS \
--save_model_path=$SAVE_MODEL_PATH \
--warmup_steps=$WARMUP_STEPS \
--train_batches_cnt=$TRAIN_BATCHES_CNT \
--val_batches_cnt=$VAL_BATCHES_CNT \
--learning_rate=$LEARNING_RATE \
--lr_alpha=$LR_ALPHA \
--adamw_decay=$ADAMW_DECAY \
--checkpoint_path=$CHECKPOINT_PATH \
--chkpt_batches "${CHKPT_BATCHES[@]}" \
--fsdp_sharding_strategy=$FSDP_SHARDING_STRATEGY \
--backward_prefetch=$BACKWARD_PREFETCH \
--verbose_freq=$VERBOSE_FREQ \
--val_verbose_freq=$VAL_VERBOSE_FREQ \
--is_train \
--device=$DEVICE \
--save_type=$SAVE_TYPE 2>&1 | tee $logfile
# other options that are not specified here and use their default values:
# --enable_batch_count
# --fsdp_cpu_offload
# --disable_amp
# --disable_fsdp_mixed_precision
# --split_names
# --test_offset
# --shuffle_set
# --auto_size_minimum
# --chkpt_epochs
# --enable_full_dist_load
# --dff
After a model is trained, a training checkpoint saved at checkpoint_path can be loaded on a single GPU as follows:
import pldr_run_model_v510 as pldr_run_model
e2e_obj=pldr_run_model.dist_pldr_model_e2e(rank=0,world_size=1,
inp_obj_src=inp_obj,
hpdict=hpdict,
checkpoint_path="./chkpt-saved-models",
load_ckpt="/path/to/pldrllm/chkpt/file.pth",
is_train=False,
device='cuda',
enable_full_dist_load=False)
The above method needs a dist_pldr_data_prep object as inp_obj for the tokenizer and a dictionary hpdict with model parameters that match the hyperparameters of the saved model.
The inp_obj is obtained by initializing a dist_pldr_data_prep object without loading the training dataset.
import pldr_data_prep
inp_obj = pldr_data_prep.dist_pldr_data_prep(rank=0,
WORLD_SIZE=1,
load_dataset=False,
tok_model = "/path/to/sentencepiece/tokenizer.model"
)
The hyperparameters are prepared by providing a dictionary of the following format for pldr_model_v510:
import numpy as np
import torch.nn.functional as F
num_layers=5
num_heads=14
dk=64
adff=170
hpdict={"num_layers": num_layers,
"d_model": int(num_heads*dk),
"num_heads": num_heads,
"dff": int(np.floor(num_heads*dk*4*2/3)),
"num_reslayerA":8,
"num_denseA":2,
"A_dff":adff,
"input_vocab_size": inp_obj.tokenizer.vocab_size,
"max_seq_len":1024,
"epochs":1,
"save_model_path": "default_pldr_model",
"warmup_steps": 2000,
"lr_total_steps": 250000,
"learning_rate": 1e-3,
"lr_alpha":0.1,
"adamw_decay":0.1,
"activation":F.silu,
"device":'cuda',
"auto_size_minimum": None,
"disable_amp":False,
"disable_fsdp_mixed_precision":False,
"fsdp_cpu_offload":False,
"fsdp_sharding_strategy":"FULL_SHARD",
"backward_prefetch":"PRE",
"save_type": "torch"
}
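With the sample values above, the derived sizes work out as d_model = 14*64 = 896 and dff = floor(896*4*2/3) = 2389, which can be verified quickly (this snippet is only a check, not repository code):

num_heads, dk = 14, 64
d_model = num_heads * dk                    # 14 * 64 = 896
dff = int(np.floor(d_model * 4 * 2 / 3))    # floor(7168 / 3) = 2389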
Hyperparameters such as "warmup_steps" and "learning_rate" are needed for training; they are not used when initializing the model for inference.
To generate continuation text for an input sentence, the generate_text method is used. Temperature, top-k and top-p sampling can be stacked, but it is most common to use a single sampling type. The setting (temperature=1, top_k=0, top_p=1) disables all sampling approaches. For greedy sampling, top_k=1 can be used (with top_p=1 and temperature=1).
sentence="Write a letter requesting that people use language models responsibly."
text, att_weights, kvcache_lst=e2e_obj.generate_text(sentence,
temperature=1.0, top_k=1, top_p=1.0,
enable_kvcache=True, enable_Gcache=True,
Gcachelst_init=None,
max_length=100, save_att=None)
print(text)
An example using only nucleus sampling at top_p=0.6 is as follows:
sentence="Write a letter requesting that people use language models responsibly."
text, att_weights, kvcache_lst=e2e_obj.generate_text(sentence,
temperature=1.0, top_k=0, top_p=0.6,
enable_kvcache=True, enable_Gcache=True,
Gcachelst_init=None,
max_length=100, save_att=None)
print(text)
For pldr_model_v510, it is also possible to provide a predefined GLM, mainly for ablation purposes. This reduces the model to a pldr_model_v510G while the PLGA deep layers are still initialized but bypassed.
Gcachelst_init is a list of tuples of torch tensors [ALM, GLM], one for each decoder layer.
import common as cm
Gcachelst_init=cm.pklload("/path/to/predefined/G_LM/file.pkl")
sentence="Write a letter requesting that people use language models responsibly."
text, _, _=e2e_obj.generate_text(sentence,
temperature=1.0, top_k=0, top_p=0.6,
enable_kvcache=True, enable_Gcache=False,
Gcachelst_init=Gcachelst_init,
max_length=100, save_att=None)
print(text)
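As an illustration of the identity special case, a Gcachelst_init made of identity tensors can also be constructed by hand. The sketch below is an assumption about the expected format (one (ALM, GLM) pair per decoder layer, each with shape (1, num_heads, dk, dk)); the saved files under ablation/predefined-G_LM are the authoritative reference for the exact shapes.

import torch

# Hypothetical identity G-cache list: one (A_LM, G_LM) pair per decoder layer.
# The shapes below are an assumption; check the predefined-G_LM files for the exact format.
num_layers, num_heads, dk = 5, 14, 64
eye = torch.eye(dk).expand(1, num_heads, dk, dk)
Gcachelst_init = [(eye.clone(), eye.clone()) for _ in range(num_layers)]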
Below are the deductive outputs used for monitoring and regularizing the PLDR-LLM. For more on deductive outputs, please see the papers for PLDR-LLM and Power Law Graph Transformer.
| Deductive output | Shape | Accessed with |
|---|---|---|
| Metric tensor ALM | (# decoder layers, # attention heads, dk, dk) | torch.stack([t[0][0] for t in att_weights]) |
| Potential tensor AP | (# decoder layers, # attention heads, dk, dk) | torch.stack([torch.pow(t[0], t[1])[0] for t in att_weights]) |
| Energy-curvature tensor GLM | (# decoder layers, # attention heads, dk, dk) | torch.stack([t[4][0] for t in att_weights]) |
| Output of the residual network A | (# decoder layers, # attention heads, dk, dk) | torch.stack([t[-1][0] for t in kvcache_lst]) |
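As a small, hypothetical monitoring example, the stacked deductive outputs can be inspected after a generate_text call; the per-layer Frobenius norm below is just one possible diagnostic, not a metric prescribed by the repository.

import torch

# Stack the deductive outputs returned by generate_text (see the table above).
A_LM = torch.stack([t[0][0] for t in att_weights])   # metric tensor
G_LM = torch.stack([t[4][0] for t in att_weights])   # energy-curvature tensor

# One possible diagnostic: Frobenius norm of G_LM for each decoder layer and head.
print(G_LM.shape)                                    # (# decoder layers, # heads, dk, dk)
print(torch.linalg.matrix_norm(G_LM))                # shape: (# decoder layers, # heads)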
Please cite this work as:
@misc{gokden2025pldrllmkvgcache,
title={PLDR-LLMs Learn A Generalizable Tensor Operator That Can Replace Its Own Deep Neural Net At Inference},
author={Burc Gokden},
year={2025},
eprint={2502.13502},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2502.13502},
}