A complete implementation of the Word2Vec Skip-Gram model with negative sampling using TensorFlow. This project, developed for an M.S. Machine Learning course, processes the Text8 Wikipedia corpus, trains word embeddings from scratch, and includes utilities for visualizing the resulting vector space.
- Efficient text preprocessing including stopword removal, punctuation tokenization, and frequency-based filtering.
- Implementation of Mikolov et al.'s subsampling heuristic for frequent words.
- Generation of skip-gram pairs with negative sampling using TensorFlow/Keras utilities.
- A custom `tf.keras.Model` class for the Skip-Gram architecture.
- A custom training loop with `tf.GradientTape` for fine-grained control and logging.
- Checkpointing with `tf.train.CheckpointManager` to save model progress.
- Command-line argument parsing with `argparse` to easily configure hyperparameters.
- Utilities to export embeddings (`vecs.tsv`, `meta.tsv`) for visualization in the TensorFlow Embedding Projector.
- Word Embeddings (Word2Vec): Learning dense vector representations of words.
- Skip-Gram Architecture: A model that learns embeddings by predicting context words from a target word.
- Negative Sampling: An efficient optimization strategy that reframes the problem as binary classification.
- TensorFlow Custom Models: Subclassing `tf.keras.Model` to create a bespoke architecture.
- TensorFlow Custom Training: Using `tf.GradientTape` to manage gradients and model optimization.
- Text Preprocessing: Tokenization, filtering, and subsampling with `NLTK` and TensorFlow.
- `tf.data` API: Building efficient and scalable input pipelines for training.
This project trains word embeddings by implementing the Skip-Gram model with Negative Sampling. The core idea is that a word's meaning is defined by the company it keeps. The model learns by trying to distinguish true "context words" from random "negative sample" words.
- Load Data: The `text8` dataset (the first 100MB of Wikipedia) is downloaded.
- Preprocess Text: The raw text goes through a multi-step cleanup:
  - Punctuation is tokenized (e.g., `.` becomes `<PERIOD>`).
  - Text is lowercased and stopwords (like 'the', 'is') are removed.
  - Words appearing fewer than 5 times are filtered out.
  - Subsampling: Frequent words (e.g., 'anarchism' in this corpus) are randomly dropped based on the heuristic formula from Mikolov et al., $P_{\text{drop}}(w_i) = 1 - \sqrt{t / f(w_i)}$, where $f(w_i)$ is the word's relative frequency and $t$ is a small threshold, so that they do not dominate the training.
- Generate Pairs: The `tf.keras.preprocessing.sequence.skipgrams` function is used. For each word in the text:
  - Positive Samples: It pairs the word with words inside its context window (e.g., 5 words before and after), assigning a label of `1`.
  - Negative Samples: It pairs the word with random words from the vocabulary, assigning a label of `0`.
- Create `tf.data` Pipeline: These pairs are converted into a `tf.data.Dataset` and batched, shuffled, and prefetched for efficient GPU training (see the sketch below).
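As a concrete illustration of this pipeline, here is a minimal sketch of the pair-generation and `tf.data` steps, assuming the preprocessed corpus is already encoded as a list of integer word IDs. The hyperparameter names and values below are illustrative and need not match those in `src/config.py`.

```python
import numpy as np
import tensorflow as tf

# Illustrative hyperparameters (not necessarily those in src/config.py).
VOCAB_SIZE = 20_000      # size of the filtered vocabulary
WINDOW_SIZE = 5          # context words on each side
NEG_SAMPLE_RATIO = 4.0   # negatives generated per positive pair
BATCH_SIZE = 1024

def build_dataset(word_ids):
    """word_ids: list[int], one ID per token of the preprocessed corpus."""
    # Mikolov-style subsampling: frequent IDs are sampled with lower probability.
    sampling_table = tf.keras.preprocessing.sequence.make_sampling_table(VOCAB_SIZE)

    # Positive pairs (label 1) from the context window, plus random negative
    # pairs (label 0) drawn from the vocabulary.
    pairs, labels = tf.keras.preprocessing.sequence.skipgrams(
        word_ids,
        vocabulary_size=VOCAB_SIZE,
        window_size=WINDOW_SIZE,
        negative_samples=NEG_SAMPLE_RATIO,
        sampling_table=sampling_table,
    )
    targets, contexts = zip(*pairs)

    dataset = tf.data.Dataset.from_tensor_slices(
        ((np.array(targets), np.array(contexts)),
         np.array(labels, dtype=np.float32))
    )
    # Shuffle, batch, and prefetch for efficient GPU training.
    return (dataset
            .shuffle(10_000)
            .batch(BATCH_SIZE, drop_remainder=True)
            .prefetch(tf.data.AUTOTUNE))
```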
The model (`src/model.py`) is surprisingly simple. It consists of two main components:
- `target_embedding`: An embedding layer (a $V \times D$ matrix) for the target (center) word.
- `context_embedding`: An embedding layer (another $V \times D$ matrix) for the context word.

Here, $V$ is the vocabulary size and $D$ is the embedding dimension. The model's `call` method looks up both embeddings and returns their dot product as a single logit (see the sketch below).
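A minimal sketch of what such a two-embedding model can look like as a `tf.keras.Model` subclass; the constructor arguments and layer names are illustrative and may differ from the actual `src/model.py`.

```python
import tensorflow as tf

class SkipGramModel(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        # V x D matrix for the target (center) words.
        self.target_embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        # V x D matrix for the context words.
        self.context_embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

    def call(self, inputs):
        target_ids, context_ids = inputs            # each: (batch,)
        t = self.target_embedding(target_ids)       # (batch, D)
        c = self.context_embedding(context_ids)     # (batch, D)
        # The dot product of the two vectors is the logit for "real pair?".
        return tf.reduce_sum(t * c, axis=-1)        # (batch,)
```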
The original Skip-Gram objective tries to predict the context given a target word. The probability of observing a context word $w_O$ given a target word $w_I$ is modeled with a softmax over the whole vocabulary:

$$P(w_O \mid w_I) = \frac{\exp\left({v'_{w_O}}^{\top} v_{w_I}\right)}{\sum_{w=1}^{V} \exp\left({v'_{w}}^{\top} v_{w_I}\right)}$$

where $v_{w_I}$ is the target embedding of $w_I$ and $v'_{w}$ is the context embedding of word $w$. Computing the denominator requires a sum over all $V$ words, which is expensive for large vocabularies.
This project uses Negative Sampling as a more efficient objective. Instead of a multiclass prediction, we frame the problem as binary classification.
For each (target, context) pair, the model is trained to answer: "Is this a real context pair or a random, 'negative' pair?"
The objective function for a single positive pair $(w_I, w_O)$ with $k$ negative samples $w_1, \dots, w_k$ drawn from a noise distribution $P_n(w)$ is:

$$\log \sigma\left({v'_{w_O}}^{\top} v_{w_I}\right) + \sum_{i=1}^{k} \log \sigma\left(-{v'_{w_i}}^{\top} v_{w_I}\right)$$

where:
- $\sigma(\cdot)$ is the sigmoid function.
- The first term pushes the dot product of real pairs to be high.
- The second term pushes the dot product of fake (negative) pairs to be low.
This is exactly what our model implements. The `skipgrams` function provides the (target, context, label) triples, and our model's `call` method computes the dot product. We then apply `tf.keras.losses.BinaryCrossentropy(from_logits=True)`, which is a numerically stable implementation of this exact sigmoid-based objective.
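For illustration, a single training step under this objective might look like the following sketch, combining `tf.GradientTape` with the `from_logits` binary cross-entropy loss; the variable names are illustrative and not taken from `src/train.py`.

```python
import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(model, targets, contexts, labels):
    with tf.GradientTape() as tape:
        logits = model((targets, contexts))   # dot-product logits from the model
        loss = bce(labels, logits)            # sigmoid cross-entropy on real/fake labels
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```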
The model is trained using a custom loop in `src/train.py`. Based on the 5-epoch training run from the original notebook:
- Final Training Accuracy: ~88.9%
- Final Validation Accuracy: ~84.1%
The training accuracy steadily increases, while the validation accuracy plateaus around 84-85% after the second epoch. This is expected. The model quickly learns to distinguish true pairs from random ones.
The validation loss (as seen in the original notebook) starts to increase after epoch 2, while validation accuracy stays high. This suggests the model's confidence on the validation set is decreasing (slight overfitting), but its classification ability remains strong. In Word2Vec, the final classification accuracy is less important than the side effect of the training: the learned `target_embedding` vectors. The high validation accuracy confirms that the vectors have successfully encoded the co-occurrence statistics of the corpus.
When the final `vecs.tsv` and `meta.tsv` files are loaded into the TensorFlow Embedding Projector, we would observe clear semantic clustering. For example:
- `anarchism` would be spatially close to `proudhon`, `bakunin`, and `anarchists`.
- `fascism` would be in a different region, perhaps near `fascists` and `republican`.
- `one`, `two`, and `three` would cluster together.
This demonstrates the model successfully learned meaningful semantic relationships from the text.
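For reference, producing the two projector files typically amounts to dumping the target embedding matrix and the vocabulary line by line. The sketch below is one possible implementation and may differ from the actual logic in `src/utils.py`; the `inverse_vocab` list mapping word IDs to words is an assumed input.

```python
import io

def export_embeddings(model, inverse_vocab,
                      vecs_path="data/vecs.tsv", meta_path="data/meta.tsv"):
    """Write one embedding vector per line to vecs.tsv and its word to meta.tsv."""
    weights = model.target_embedding.get_weights()[0]  # (V, D) matrix
    with io.open(vecs_path, "w", encoding="utf-8") as out_v, \
         io.open(meta_path, "w", encoding="utf-8") as out_m:
        for index, word in enumerate(inverse_vocab):
            if index == 0:
                continue  # skip a padding/unknown token at index 0, if present
            out_v.write("\t".join(str(x) for x in weights[index]) + "\n")
            out_m.write(word + "\n")
```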
tf-word2vec-skipgram/
├── .gitignore # Ignores venv, logs, data, and cache
├── LICENSE # MIT License file
├── README.md # This project guide
├── requirements.txt # Python dependencies
├── run_training.ipynb # A simple notebook to run the training
├── checkpoints/
│ └── .gitkeep # Holds model checkpoints
├── data/
│ └── .gitkeep # Default location for projector files
├── logs/
│ └── .gitkeep # Holds the training.log file
└── src/
├── __init__.py # Makes 'src' a Python package
├── config.py # All hyperparameters and paths
├── data_loader.py # Handles data download, preprocessing, and batching
├── main.py # Main executable script with argparse
├── model.py # Defines the SkipGramModel (tf.keras.Model)
├── train.py # Implements the custom training and eval loops
└── utils.py # Utility for logging and saving embeddings
- Clone the Repository:
  - `git clone https://github.com/msmrexe/tf-word2vec-skipgram.git`
  - `cd tf-word2vec-skipgram`
- Set up Environment and Install Dependencies (optional, but recommended):
  - `python -m venv venv`
  - `source venv/bin/activate` (on Windows: `venv\Scripts\activate`)
  - Install requirements: `pip install -r requirements.txt`
- Run the Training Pipeline: You have two options.
  - Option A: Run the Jupyter Notebook (easiest). Open and run all cells in `run_training.ipynb`. This will execute the main script with default settings.
  - Option B: Use the Command-Line Script. Run the main script directly; you can override any default setting from `src/config.py` using command-line arguments.
    - Run with default settings (5 epochs, 128-dimensional embeddings): `python src/main.py`
    - Run for more epochs and with a larger embedding dimension: `python src/main.py --epochs 10 --embedding_dim 150`
- View Results:
  - Logs: All training progress is logged to the console and saved in `logs/training.log`.
  - Checkpoints: Model checkpoints are saved in the `checkpoints/` directory.
  - Embeddings: The final vectors (`vecs.tsv`) and metadata (`meta.tsv`) are saved in the `data/` directory.
- Visualize in the Embedding Projector:
  - Go to http://projector.tensorflow.org/.
  - Click on Load.
  - Upload `data/vecs.tsv` for "Choose tensor file".
  - Upload `data/meta.tsv` for "Choose metadata file".
  - Explore the learned vector space!
Feel free to connect or reach out if you have any questions!
- Maryam Rezaee
- GitHub: @msmrexe
- Email: ms.maryamrezaee@gmail.com
This project is licensed under the MIT License. See the LICENSE file for full details.