Skip to content

Commit

Permalink
Initial Checkin of Text Search using Sentence Transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
liamca committed Apr 9, 2021
1 parent 4ba8b95 commit e21811b
Show file tree
Hide file tree
Showing 8 changed files with 2,522 additions and 0 deletions.
64 changes: 64 additions & 0 deletions notebooks/text-search-question-answer/azureCognitiveSearch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from azure.core.credentials import AzureKeyCredential
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents import SearchClient
from azure.search.documents.indexes.models import (
ComplexField,
CorsOptions,
SearchIndex,
ScoringProfile,
SearchFieldDataType,
SimpleField,
SearchableField
)

# Service connection settings.
# NOTE: these are placeholders -- replace with your own values before running.
# They must be plain strings: the original wrapped them in one-element lists,
# which made `endpoint` format to "https://['...'].search.windows.net/" and
# passed a list (not a str) to AzureKeyCredential.
serviceName = "Enter Search Service name -- DO NOT include .search.windows.net"
adminKey = "Search Service Admin API Key"
indexName = "vec2text-msmarco"

# Build the service endpoint URL from the service name.
endpoint = "https://{}.search.windows.net/".format(serviceName)

# Admin client: used for index management (create/delete).
adminClient = SearchIndexClient(endpoint=endpoint,
                        index_name=indexName,
                        credential=AzureKeyCredential(adminKey))

# Search client: used for document upload and queries.
searchClient = SearchClient(endpoint=endpoint,
                        index_name=indexName,
                        credential=AzureKeyCredential(adminKey))

def deleteIndex():
    """Delete the index named by the module-level `indexName`.

    Failures (e.g. the index does not exist) are caught and printed
    rather than raised, so this is safe to call as a reset step.
    """
    try:
        adminClient.delete_index(indexName)
        print('Index', indexName, 'Deleted')
    except Exception as ex:
        print(ex)


def createIndex():
    """Create the search index with Id, Content and VecText fields.

    - Id: document key.
    - Content: the original passage text (searchable with the
      en.microsoft analyzer, filterable/sortable).
    - VecText: the "fake terms" derived from the embedding clusters
      (searchable only).
    Errors from the service are caught and printed, not raised.
    """
    cors = CorsOptions(allowed_origins=["*"], max_age_in_seconds=60)

    schema = [
        SimpleField(name="Id", type=SearchFieldDataType.String, key=True),
        SearchableField(name="Content", type=SearchFieldDataType.String,
                        facetable=False, filterable=True, sortable=True,
                        analyzer_name="en.microsoft"),
        SearchableField(name="VecText", type=SearchFieldDataType.String,
                        facetable=False, filterable=False, sortable=False),
    ]

    newIndex = SearchIndex(name=indexName, fields=schema, cors_options=cors)

    try:
        created = adminClient.create_index(newIndex)
        print('Index', created.name, 'Created')
    except Exception as ex:
        print(ex)

def uploadDocuments(documents):
    """Upload a batch of documents to the search index.

    Parameters
    ----------
    documents : list of dict
        Documents in the Azure Search upload format (each dict carries an
        "@search.action" key plus the index fields Id/Content/VecText).
    """
    try:
        result = searchClient.upload_documents(documents=documents)
        # result is a per-document list; only the first item's status is
        # reported here (original behavior preserved).
        print("Upload of new document succeeded: {}".format(result[0].succeeded))
    except Exception as ex:
        # Fix: `ex.message` was removed in Python 3, so the original handler
        # raised AttributeError and hid the real error. Print the exception.
        print(ex)
Binary file not shown.
2,000 changes: 2,000 additions & 0 deletions notebooks/text-search-question-answer/data/collection-small.tsv

Large diffs are not rendered by default.

105 changes: 105 additions & 0 deletions notebooks/text-search-question-answer/vec2Text-msmarco-test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import azureCognitiveSearch\n",
"import vec2Text\n",
"import vec2TextSentenceTransformer\n",
"\n",
"import pickle, datetime"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def vecToText(query):\n",
" curVec = vec2TextSentenceTransformer.contentToMeanEmbedding(query)\n",
" vecText = ''\n",
" for d in range(len(curVec)):\n",
" vecText += vec2Text.convertFieldNumToString(d) + str(vec2Text.closest(clusterCenters[d], curVec[d])) + ' '\n",
" return vecText\n",
"\n",
"\n",
"def executeQuery(query, searchField):\n",
" return azureCognitiveSearch.searchClient.search(search_text=query, include_total_count=False, search_fields=searchField, select='Content', top=5)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the cluster centers\n",
"with open(vec2Text.clusterCenterFile, 'rb') as pickle_in:\n",
" clusterCenters = pickle.load(pickle_in)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Perform a query using the traditional BM25\n",
"query = \"when did the manhattan project begin?\"\n",
"\n",
"results = executeQuery(query, 'Content')\n",
"for result in results:\n",
" print(\"{}\".format(result[\"Content\"]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
    "# Perform a query using the sentence transformer embeddings\n",
"query = \"when did the manhattan project begin?\"\n",
"\n",
"# Uncomment to see the fake terms created for this query\n",
"# print (vecToText(query) + '\\n')\n",
"results = executeQuery(vecToText(query), 'VecText')\n",
"for result in results:\n",
" print(\"{}\".format(result[\"Content\"]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py37_pytorch",
"language": "python",
"name": "conda-env-py37_pytorch-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
105 changes: 105 additions & 0 deletions notebooks/text-search-question-answer/vec2Text-msmarco-train.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import vec2Text\n",
"import vec2TextSentenceTransformer\n",
"\n",
"import random\n",
"import pickle"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#Find the number of dimensions in this model\n",
"with open(vec2Text.dataFile) as f:\n",
" fields = f.readline().split('\\t')\n",
" dimensions = vec2TextSentenceTransformer.calculateDimensions(fields[1])\n",
"print ('Dimensions:', dimensions)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Add test vectors to dictionary\n",
"vecDict = vec2Text.initializeVectorDictionary(dimensions)\n",
"\n",
"counter = 0\n",
"with open(vec2Text.dataFile) as f:\n",
"\n",
" for line in f:\n",
" fields = line.split('\\t')\n",
" cur_vec = vec2TextSentenceTransformer.contentToMeanEmbedding(fields[1])\n",
" for d in range(dimensions):\n",
" vecDict[str(d)].append(cur_vec[d])\n",
"\n",
" counter +=1\n",
" if counter % 100 == 0:\n",
" print ('Processed:', counter, 'of', vec2Text.testSamplesToTest)\n",
" if counter == vec2Text.testSamplesToTest:\n",
" print ('Completed:', counter, 'of', vec2Text.testSamplesToTest)\n",
" break\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Find the cluster centers\n",
"clusterCenters = vec2Text.findClusterCenters(dimensions, vecDict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the cluster centers\n",
"with open(vec2Text.clusterCenterFile, 'wb') as pickle_out:\n",
" pickle.dump(clusterCenters, pickle_out)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "azureml_py36_tensorflow",
"language": "python",
"name": "conda-env-azureml_py36_tensorflow-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
100 changes: 100 additions & 0 deletions notebooks/text-search-question-answer/vec2Text-msmarco-upload.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import azureCognitiveSearch\n",
"import vec2Text\n",
"import vec2TextSentenceTransformer\n",
"\n",
"import pickle \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the cluster centers\n",
"with open(vec2Text.clusterCenterFile, 'rb') as pickle_in:\n",
" clusterCenters = pickle.load(pickle_in)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reset Index\n",
"azureCognitiveSearch.deleteIndex()\n",
"azureCognitiveSearch.createIndex()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Upload to Cognitive Search\n",
"max_batch_size = 100\n",
"documents = []\n",
"\n",
"counter = 0\n",
"with open(vec2Text.dataFile) as f:\n",
" for line in f:\n",
" fields = line.split('\\t')\n",
" curVec = vec2TextSentenceTransformer.contentToMeanEmbedding(fields[1])\n",
" vecText = ''\n",
" for d in range(len(curVec)):\n",
" vecText += vec2Text.convertFieldNumToString(d) + str(vec2Text.closest(clusterCenters[d], curVec[d])) + ' '\n",
"\n",
" documents.append({\"@search.action\": \"upload\",\"Id\": vec2Text.stringToBase64(fields[0]), \"Content\": fields[1],\"VecText\": vecText })\n",
"\n",
" counter += 1\n",
" if len(documents) == max_batch_size:\n",
" azureCognitiveSearch.uploadDocuments(documents)\n",
" documents = []\n",
" print ('Processed:', counter)\n",
"\n",
"if len(documents) > 0:\n",
" azureCognitiveSearch.uploadDocuments(documents)\n",
" documents = []\n",
" print ('Processed:', counter)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "azureml_py36_tensorflow",
"language": "python",
"name": "conda-env-azureml_py36_tensorflow-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading

0 comments on commit e21811b

Please sign in to comment.