Skip to content

Commit

Permalink
fix faiss model
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Aug 4, 2020
1 parent 918030e commit 22b8444
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
21 changes: 19 additions & 2 deletions MedSemanticSearch/medsearch/models/transformer__faiss_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,27 @@ class FaissTransformerModel(TorchModelBase):
def __init__(self, dataset_cls: type, network_fn: Callable = Transformer,
             dataset_args: Dict = None, network_args: Dict = None):
    """Build the model by delegating everything to TorchModelBase.

    No optimizer/loss is configured here (second positional arg is None);
    the base class is responsible for instantiating the dataset and network.
    """
    super().__init__(dataset_cls, None, network_fn, dataset_args, network_args)
def encode(self, document: List[str]) -> torch.Tensor:
    """Embed a document into a single fixed-size vector.

    Tokenizes the input with the underlying transformer network, takes the
    first output (token-level hidden states), detaches it from the graph,
    and mean-pools over the token dimension to get one vector per document.

    NOTE: the diff left a stale duplicate `def encode` signature above the
    real one; this is the single, annotated definition.

    :param document: text to embed (passed straight to network.tokenize;
        presumably a string or list of strings — confirm against Transformer.tokenize)
    :return: 1-D tensor — the average of the token embeddings
    """
    tokens = self.network.tokenize(document)
    # [0] selects the hidden-state tensor from the network's output tuple.
    embed = self.network(**tokens)[0].detach().squeeze()
    return torch.mean(embed, dim=0)  # average over tokens -> single vector

def load_word_vectors(self, input_path: str):
    """Load precomputed word vectors from *input_path*.

    Not implemented yet — currently a stub that does nothing.

    Bug fix: the original definition omitted ``self`` even though it is a
    class method, so any instance call would have bound the instance to
    ``input_path``.

    :param input_path: path to the stored vectors (format TBD)
    """
    pass

def indexing(self):
    """Encode every document and build an id-mapped inner-product FAISS index.

    Side effect: stores the index on ``self.index``. Document ids are their
    positions in ``self.documents``.
    """
    # assumes self.documents is an iterable of texts — TODO confirm against caller
    embeddings = np.array([self.encode(doc).numpy() for doc in self.documents])
    dim = self.network.get_word_embeddings_dim()
    self.index = faiss.IndexIDMap(faiss.IndexFlatIP(dim))
    doc_ids = np.array(range(0, len(self.documents)))
    self.index.add_with_ids(embeddings, doc_ids)

def search(self, query: Union[str, List[str]], topk: int = 5) -> List[str]:
    """Return the *topk* most similar documents to *query*.

    Requires :meth:`indexing` to have been called first (builds self.index).

    :param query: text to search for; embedded via :meth:`encode`
    :param topk: number of results to return
    :return: list of ``(document, score)`` pairs, best match first

    Fixes vs. original:
    - removed a stray ``pass`` no-op left over from the diff
    - ``self.document`` -> ``self.documents`` (the attribute that
      :meth:`indexing` actually uses; the singular form would raise
      ``AttributeError``)
    """
    query_embed = self.encode(query).unsqueeze(0).numpy()
    # faiss returns (scores, ids), each shaped (n_queries, topk); we issue
    # a single query, so take row 0 of both.
    scores, ids = self.index.search(query_embed, k=topk)
    results = [self.documents[_id] for _id in ids[0]]
    return list(zip(results, scores[0]))




2 changes: 1 addition & 1 deletion MedSemanticSearch/medsearch/networks/Transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def forward(self, features):
features.update({'all_layer_embeddings':hidden_states})
return features

def get_word_embeddings_dim(self) -> int:
    """Return the transformer's hidden size (embedding dimensionality).

    Renamed from ``get_word_embeddings`` (this commit's fix) to match the
    call in ``FaissTransformerModel.indexing``; the diff left the stale old
    signature stacked above the new one — only the renamed definition remains.
    """
    return self.network.config.hidden_size

def tokenize(self, text:str)->List[int]:
Expand Down

0 comments on commit 22b8444

Please sign in to comment.