From 8933e580576dee575234e9e11aa75a12d4848c9b Mon Sep 17 00:00:00 2001 From: "vejvarm@freya" Date: Mon, 27 Nov 2023 21:17:25 +0900 Subject: [PATCH] refactor: add KeyError handling for DataBatch --- dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataset.py b/dataset.py index 87f351a..6aa44f6 100644 --- a/dataset.py +++ b/dataset.py @@ -66,7 +66,7 @@ def __init__(self, batch: list[list[any]], vocabs: dict, device: str): entity_pointer.append(self._tensor([vocabs[ENTITY].stoi[s] for s in sample[7]])) KeyError: 'Q4622539' """ - entity_pointer.append(self._tensor([vocabs['NA'].stoi[s] for s in sample[7]])) + entity_pointer.append(self._tensor([vocabs['ENTITY'].stoi['NA']])) self.id = self._tensor(id).to(device) self.input = pad_sequence(inp, padding_value=vocabs[INPUT].stoi[PAD_TOKEN],