-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial Checkin of Text Search using Sentence Transformers
- Loading branch information
Showing
8 changed files
with
2,522 additions
and
0 deletions.
There are no files selected for viewing
64 changes: 64 additions & 0 deletions
64
notebooks/text-search-question-answer/azureCognitiveSearch.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from azure.core.credentials import AzureKeyCredential | ||
from azure.search.documents.indexes import SearchIndexClient | ||
from azure.search.documents import SearchClient | ||
from azure.search.documents.indexes.models import ( | ||
ComplexField, | ||
CorsOptions, | ||
SearchIndex, | ||
ScoringProfile, | ||
SearchFieldDataType, | ||
SimpleField, | ||
SearchableField | ||
) | ||
|
||
# Set the service endpoint and API key from the environment.
# NOTE(review): these are placeholders — replace with real values before running.
# BUG FIX: the original wrapped both placeholders in list literals
# (serviceName = ["..."]), which made the formatted endpoint
# "https://['...'].search.windows.net/" and passed a list to
# AzureKeyCredential. They must be plain strings.
serviceName = "Enter Search Service name -- DO NOT include .search.windows.net"
adminKey = "Search Service Admin API Key"
indexName = "vec2text-msmarco"

# Create an SDK client.
# adminClient manages the index schema (create/delete); searchClient
# handles document upload and querying against the same index.
endpoint = "https://{}.search.windows.net/".format(serviceName)
adminClient = SearchIndexClient(endpoint=endpoint,
                        index_name=indexName,
                        credential=AzureKeyCredential(adminKey))

searchClient = SearchClient(endpoint=endpoint,
                        index_name=indexName,
                        credential=AzureKeyCredential(adminKey))
|
||
# Delete the index if it exists
def deleteIndex():
    """Delete the search index named by module-level ``indexName``.

    Best-effort: any SDK error (e.g. the index does not exist) is
    printed rather than raised, so callers can use this to reset state
    before recreating the index.
    """
    try:
        # delete_index returns None; the original bound it to an
        # unused local `result` — dropped.
        adminClient.delete_index(indexName)
        print('Index', indexName, 'Deleted')
    except Exception as ex:
        print(ex)
|
||
|
||
# Create the index
def createIndex():
    """Create the search index with Id / Content / VecText fields.

    Schema:
      - Id: string key.
      - Content: searchable original text, analyzed with the English
        Microsoft analyzer (used for the BM25 baseline queries).
      - VecText: searchable-only field holding the fake "vector term"
        tokens produced from the sentence-transformer embeddings.

    Errors (e.g. the index already exists) are printed, not raised.
    """
    fields = [
        SimpleField(name="Id", type=SearchFieldDataType.String, key=True),
        SearchableField(name="Content", type=SearchFieldDataType.String,
                        facetable=False, filterable=True, sortable=True,
                        analyzer_name="en.microsoft"),
        SearchableField(name="VecText", type=SearchFieldDataType.String,
                        facetable=False, filterable=False, sortable=False)
    ]
    cors_options = CorsOptions(allowed_origins=["*"], max_age_in_seconds=60)

    # The original aliased indexName through a redundant local `name`;
    # use the module-level constant directly.
    index = SearchIndex(
        name=indexName,
        fields=fields,
        cors_options=cors_options)

    try:
        result = adminClient.create_index(index)
        print('Index', result.name, 'Created')
    except Exception as ex:
        print(ex)
|
||
def uploadDocuments(documents):
    """Upload a batch of documents to the search index.

    documents: list of dicts, each carrying an "@search.action" key
    plus the index fields (Id, Content, VecText).

    Prints the outcome of the first document as a proxy for the batch;
    errors are printed rather than raised so batch loops keep going.
    """
    try:
        result = searchClient.upload_documents(documents=documents)
        # NOTE(review): only result[0] is reported — individual documents
        # later in the batch can still fail; inspect `result` for details.
        print("Upload of new document succeeded: {}".format(result[0].succeeded))
    except Exception as ex:
        # BUG FIX: Python 3 exceptions have no `.message` attribute, so
        # the original `print(ex.message)` raised AttributeError inside
        # the error handler. Print the exception itself instead.
        print(ex)
Binary file not shown.
2,000 changes: 2,000 additions & 0 deletions
2,000
notebooks/text-search-question-answer/data/collection-small.tsv
Large diffs are not rendered by default.
Oops, something went wrong.
105 changes: 105 additions & 0 deletions
105
notebooks/text-search-question-answer/vec2Text-msmarco-test.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import azureCognitiveSearch\n", | ||
"import vec2Text\n", | ||
"import vec2TextSentenceTransformer\n", | ||
"\n", | ||
"import pickle, datetime" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def vecToText(query):\n", | ||
" curVec = vec2TextSentenceTransformer.contentToMeanEmbedding(query)\n", | ||
" vecText = ''\n", | ||
" for d in range(len(curVec)):\n", | ||
" vecText += vec2Text.convertFieldNumToString(d) + str(vec2Text.closest(clusterCenters[d], curVec[d])) + ' '\n", | ||
" return vecText\n", | ||
"\n", | ||
"\n", | ||
"def executeQuery(query, searchField):\n", | ||
" return azureCognitiveSearch.searchClient.search(search_text=query, include_total_count=False, search_fields=searchField, select='Content', top=5)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load the cluster centers\n", | ||
"with open(vec2Text.clusterCenterFile, 'rb') as pickle_in:\n", | ||
" clusterCenters = pickle.load(pickle_in)\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Perform a query using the traditional BM25\n", | ||
"query = \"when did the manhattan project begin?\"\n", | ||
"\n", | ||
"results = executeQuery(query, 'Content')\n", | ||
"for result in results:\n", | ||
" print(\"{}\".format(result[\"Content\"]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Perform a query using the sentence transformer embeddings\n", | ||
"query = \"when did the manhattan project begin?\"\n", | ||
"\n", | ||
"# Uncomment to see the fake terms created for this query\n", | ||
"# print (vecToText(query) + '\\n')\n", | ||
"results = executeQuery(vecToText(query), 'VecText')\n", | ||
"for result in results:\n", | ||
" print(\"{}\".format(result[\"Content\"]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "py37_pytorch", | ||
"language": "python", | ||
"name": "conda-env-py37_pytorch-py" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
105 changes: 105 additions & 0 deletions
105
notebooks/text-search-question-answer/vec2Text-msmarco-train.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import vec2Text\n", | ||
"import vec2TextSentenceTransformer\n", | ||
"\n", | ||
"import random\n", | ||
"import pickle" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#Find the number of dimensions in this model\n", | ||
"with open(vec2Text.dataFile) as f:\n", | ||
" fields = f.readline().split('\\t')\n", | ||
" dimensions = vec2TextSentenceTransformer.calculateDimensions(fields[1])\n", | ||
"print ('Dimensions:', dimensions)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Add test vectors to dictionary\n", | ||
"vecDict = vec2Text.initializeVectorDictionary(dimensions)\n", | ||
"\n", | ||
"counter = 0\n", | ||
"with open(vec2Text.dataFile) as f:\n", | ||
"\n", | ||
" for line in f:\n", | ||
" fields = line.split('\\t')\n", | ||
" cur_vec = vec2TextSentenceTransformer.contentToMeanEmbedding(fields[1])\n", | ||
" for d in range(dimensions):\n", | ||
" vecDict[str(d)].append(cur_vec[d])\n", | ||
"\n", | ||
" counter +=1\n", | ||
" if counter % 100 == 0:\n", | ||
" print ('Processed:', counter, 'of', vec2Text.testSamplesToTest)\n", | ||
" if counter == vec2Text.testSamplesToTest:\n", | ||
" print ('Completed:', counter, 'of', vec2Text.testSamplesToTest)\n", | ||
" break\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Find the cluster centers\n", | ||
"clusterCenters = vec2Text.findClusterCenters(dimensions, vecDict)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Save the cluster centers\n", | ||
"with open(vec2Text.clusterCenterFile, 'wb') as pickle_out:\n", | ||
" pickle.dump(clusterCenters, pickle_out)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "azureml_py36_tensorflow", | ||
"language": "python", | ||
"name": "conda-env-azureml_py36_tensorflow-py" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
100 changes: 100 additions & 0 deletions
100
notebooks/text-search-question-answer/vec2Text-msmarco-upload.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import azureCognitiveSearch\n", | ||
"import vec2Text\n", | ||
"import vec2TextSentenceTransformer\n", | ||
"\n", | ||
"import pickle \n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load the cluster centers\n", | ||
"with open(vec2Text.clusterCenterFile, 'rb') as pickle_in:\n", | ||
" clusterCenters = pickle.load(pickle_in)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Reset Index\n", | ||
"azureCognitiveSearch.deleteIndex()\n", | ||
"azureCognitiveSearch.createIndex()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Upload to Cognitive Search\n", | ||
"max_batch_size = 100\n", | ||
"documents = []\n", | ||
"\n", | ||
"counter = 0\n", | ||
"with open(vec2Text.dataFile) as f:\n", | ||
" for line in f:\n", | ||
" fields = line.split('\\t')\n", | ||
" curVec = vec2TextSentenceTransformer.contentToMeanEmbedding(fields[1])\n", | ||
" vecText = ''\n", | ||
" for d in range(len(curVec)):\n", | ||
" vecText += vec2Text.convertFieldNumToString(d) + str(vec2Text.closest(clusterCenters[d], curVec[d])) + ' '\n", | ||
"\n", | ||
" documents.append({\"@search.action\": \"upload\",\"Id\": vec2Text.stringToBase64(fields[0]), \"Content\": fields[1],\"VecText\": vecText })\n", | ||
"\n", | ||
" counter += 1\n", | ||
" if len(documents) == max_batch_size:\n", | ||
" azureCognitiveSearch.uploadDocuments(documents)\n", | ||
" documents = []\n", | ||
" print ('Processed:', counter)\n", | ||
"\n", | ||
"if len(documents) > 0:\n", | ||
" azureCognitiveSearch.uploadDocuments(documents)\n", | ||
" documents = []\n", | ||
" print ('Processed:', counter)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "azureml_py36_tensorflow", | ||
"language": "python", | ||
"name": "conda-env-azureml_py36_tensorflow-py" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |
Oops, something went wrong.