Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CS224W ReaRev GNN-RAG #9857

Open
wants to merge 41 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
d6c115a
Create rearev_data_loader.py
natalieshell22 Dec 13, 2024
f216279
Create rearev_data_loader_test.py
natalieshell22 Dec 13, 2024
7188791
Create rearev.py
natalieshell22 Dec 13, 2024
92a097d
Create trainer_kbqa.py
natalieshell22 Dec 13, 2024
218abf2
Create graph_utils.py
natalieshell22 Dec 13, 2024
6f7676e
Create reason.py
natalieshell22 Dec 13, 2024
019dd36
Update CHANGELOG.md
natalieshell22 Dec 13, 2024
03508a0
Update rearev_data_loader_test.py
natalieshell22 Dec 13, 2024
04fe125
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
7edb8ac
Update rearev_data_loader.py
natalieshell22 Dec 13, 2024
a126372
Update trainer_kbqa.py
natalieshell22 Dec 13, 2024
0c30f2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
c788edb
Update rearev.py
natalieshell22 Dec 13, 2024
e411b63
Update reason.py
natalieshell22 Dec 13, 2024
21e3139
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
f59bc67
Update graph_utils.py
natalieshell22 Dec 13, 2024
c7fc6cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
397291f
Update trainer_kbqa.py
natalieshell22 Dec 13, 2024
291d461
Update rearev.py
natalieshell22 Dec 13, 2024
8ed7c76
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
90961e0
Update graph_utils.py
natalieshell22 Dec 13, 2024
f06cab1
Update reason.py
natalieshell22 Dec 13, 2024
e85f635
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
5a090cd
Update trainer_kbqa.py
natalieshell22 Dec 13, 2024
70db241
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
834565a
Update trainer_kbqa.py
natalieshell22 Dec 13, 2024
0a0ace9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
119fd07
Update reason.py
natalieshell22 Dec 13, 2024
7f4e99a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
5f24243
Update graph_utils.py
natalieshell22 Dec 13, 2024
280fae9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
9f49b2f
Update rearev.py
natalieshell22 Dec 13, 2024
fbec144
Update graph_utils.py
natalieshell22 Dec 13, 2024
5952480
Update reason.py
natalieshell22 Dec 13, 2024
17295ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
33c9758
Update graph_utils.py
natalieshell22 Dec 13, 2024
3f43890
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
3a78d97
Update reason.py
natalieshell22 Dec 13, 2024
8485eaf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
320f77e
Update graph_utils.py
natalieshell22 Dec 13, 2024
5ce693e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 13, 2024
commit 21e3139a7eb30bff888e4ad69b06f51d1a7c4ba1
12 changes: 7 additions & 5 deletions torch_geometric/nn/rearev.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
from reason import

from torch_geometric.nn import GlobalAttention
from torch_geometric.utils.reason import Fusion, QueryReform, TypeLayer

from reason import


class ReaRev(torch.nn.Module):
"""Recurrent Attention and Reasoning model (ReaRev) for iterative reasoning
Expand Down Expand Up @@ -47,7 +49,7 @@ def _define_layers(self, args):
self.layers(args)

def layers(self, args):
"""Define layers for embedding transformations
"""Define layers for embedding transformations
and attention encoders."""
self.word_dim
self.kg_dim
Expand Down Expand Up @@ -138,7 +140,7 @@ def init_reason(self, curr_dist, local_entity, kb_adj_mat, q_input,
)

def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
"""Computes the label loss based on current predictions vs.
"""Computes the label loss based on current predictions vs.
teacher distribution.
"""
tp_loss = self.get_loss(pred_dist=curr_dist, answer_dist=teacher_dist,
Expand All @@ -148,10 +150,10 @@ def calc_loss_label(self, curr_dist, teacher_dist, label_valid):
return cur_loss

def forward(self, batch, training=False):
"""Standard forward pass: perform reasoning steps and
"""Standard forward pass: perform reasoning steps and
produce predictions and loss.
"""
local_entity, query_entities, kb_adj_mat, query_text,
local_entity, query_entities, kb_adj_mat, query_text,
seed_dist, _, answer_dist = batch
local_entity = torch.from_numpy(local_entity).long().to(self.device)
query_entities = torch.from_numpy(query_entities).float().to(
Expand Down
Loading