add use model
Ubuntu committed Jul 30, 2020
1 parent a901e56 commit b646912
Showing 2 changed files with 55 additions and 2 deletions.
6 changes: 4 additions & 2 deletions MedSemanticSearch/medsearch/models/base.py
@@ -16,7 +16,8 @@ def __init__(self, dataset_cls:type, network_fn:Callable,
 
         if network_args is None:
             network_args={}
-        self.network = network_fn(**network_args)
+        if network_fn is not None:
+            self.network = network_fn(**network_args)
 
     def load_weights(self, filename):
         pass
@@ -46,7 +47,8 @@ def model(self):
 
 class TensorflowModelBase(ModelBase):
 
-    def __init__(self, dataset_cls:type,
+    def __init__(self,
+                 dataset_cls:type,
                  network_fn:Callable,
                  dataset_args:Dict=None,
                  network_args:Dict=None):
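The new `None` guard above is what makes the commit's second file possible: the Universal Sentence Encoder model loads a pretrained TF-Hub module rather than building a network from a `network_fn`. A minimal sketch of a subclass leaning on that guard (the `HubModel` class is hypothetical, and it assumes `ModelBase.__init__` accepts `dataset_args`, as the diff context suggests):

import tensorflow_hub as hub

class HubModel(ModelBase):
    """Hypothetical subclass: weights come from TF-Hub, so there is no network_fn."""
    def __init__(self, dataset_cls: type, dataset_args: dict = None):
        # network_fn=None no longer crashes, thanks to the guard added above
        super().__init__(dataset_cls, network_fn=None, dataset_args=dataset_args)
        self.network = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")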
51 changes: 51 additions & 0 deletions MedSemanticSearch/medsearch/models/universal_sentence_encoder.py
@@ -0,0 +1,51 @@
import tensorflow as tf
import tensorflow_text  # imported for its side effect of registering TF text ops
import tensorflow_hub as hub

import numpy as np
from dataclasses import dataclass, field
from medsearch.datasets.dataset import SemanticCorpusDataset
from typing import Union, List, Tuple, Callable, Dict, Optional

class UniversalSentenceEncoderModel:
    def __init__(self, dataset_cls:type=SemanticCorpusDataset, dataset_args:Dict=None):
        if dataset_args is None: dataset_args={}
        self.data = dataset_cls(**dataset_args)
        module_url = "https://tfhub.dev/google/universal-sentence-encoder/4"
        self.model = hub.load(module_url)
        self.batch_size = 16

    def embed(self, texts):
        # `texts` rather than `input`, to avoid shadowing the built-in
        return self.model(texts)

    def get_similarity_vecs(self, queries:Union[str, List[str]], corpus:List[str], topk=5):
        n_samples = len(corpus)
        emb = np.zeros([n_samples, 512])
        # ceil so the final partial batch is embedded too, without an empty extra pass
        num_batches = int(np.ceil(n_samples / self.batch_size))
        for i in range(num_batches):
            start = self.batch_size * i
            end = min(start + self.batch_size, n_samples)
            emb[start:end] = self.embed(corpus[start:end])
        emb_query = self.embed(queries)[0]
        ### TODO: move the similarity scoring below into a separate function
        results = np.dot(emb, emb_query)  # one dot product per document, not an n x n matrix
        topk_idx = results.argsort()[-topk:][::-1]  # indices of the top-k scores, best first
        scores = [str(s) for s in results[topk_idx]]
        sentences = [corpus[idx] for idx in topk_idx]
        return sentences, scores

def run_test():
model = UniversalSentenceEncoderModel(dataset_args={'batch':1000})
data = model.data.load_one_batch()
corpus = [(f'{t} <SEP> {a}')[:512] for t,a in zip(data.title, data.paperAbstract)]
queries = ["breast cancer"]
sentences, scores = model.get_similarity_vecs(queries, corpus)

print(f"Queries: {queries}")
for i, (st, sc) in enumerate(zip(sentences,scores)):
print(f"Similar paper {i} Score : {sc}")
print(f"{st}")
print(f"-------------------------------------")

if __name__ == "__main__":
run_test()
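One note the commit leaves implicit: the USE v4 module's embeddings come back approximately unit-normalized (per its TF-Hub documentation), so the plain dot product in `get_similarity_vecs` behaves like cosine similarity. A standalone sketch of the same scoring, with made-up sentences and no dataset dependency:

import numpy as np
import tensorflow_hub as hub
import tensorflow_text  # registers ops; harmless for the English v4 module

model = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
corpus = ["BRCA1 mutations and hereditary breast cancer risk",
          "Deep learning for retinal image segmentation"]
emb = np.asarray(model(corpus))                  # shape (2, 512), ~unit-norm rows
query = np.asarray(model(["breast cancer"]))[0]  # shape (512,)
scores = emb @ query                             # dot product ~ cosine similarity
print(corpus[int(np.argmax(scores))], scores)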
