Skip to content

Commit

Permalink
Add imports
Browse files Browse the repository at this point in the history
  • Loading branch information
stsievert committed Jul 31, 2022
1 parent 4bce1a7 commit 06f6bdd
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions docs/source/offline.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,22 +86,28 @@ This Python code will generate an embedding:

.. code-block:: python
import pandas as pd
from sklearn.model_selection import train_test_split
from salmon.triplets.offline import OfflineEmbedding
# Read in data
df = pd.read_csv("responses.csv") # from dashboard
X = df[["head", "winner", "loser"]].to_numpy()
em = pd.read_csv("embedding.csv") # from dashboard; optional
em = pd.read_csv("embedding.csv") # from dashboard
X = df[["head", "winner", "loser"]].to_numpy()
X_train, X_test = train_test_split(X, random_state=42, test_size=0.2)
n = int(X.max() + 1) # number of targets
d = 2 # embed into 2 dimensions
X_train, X_test = train_test_split(X, random_state=42, test_size=0.2)
# Fit the model
model = OfflineEmbedding(n=n, d=d, max_epochs=500_000)
model.initialize(X_train, embedding=em.to_numpy())
model.initialize(X_train, embedding=em.to_numpy()) # (optional)
model.fit(X_train, X_test)
# Inspect the model
model.embedding_ # embedding
model.history_ # to view information on how well train/test performed
Expand Down

0 comments on commit 06f6bdd

Please sign in to comment.