When running the following code:
from torchdrug import data, utils
from torch.nn import functional as F

# valid_set and task are defined earlier in the script (not shown here)
samples = []
categories = set()
for sample in valid_set:
    sample.pop("graph")  # removes "graph" from the sample dict itself
    category = tuple(sample.values())
    if category not in categories:
        categories.add(category)
        samples.append(sample)
samples = data.graph_collate(samples)
samples = utils.cuda(samples)

preds = F.sigmoid(task.predict(samples))
targets = task.target(samples)

titles = []
for pred, target in zip(preds, targets):
    pred = ", ".join(["%.2f" % p for p in pred])
    target = ", ".join(["%d" % t for t in target])
    titles.append("predict: %s\ntarget: %s" % (pred, target))
graph = samples["graph"]
graph.visualize(titles, figure_size=(3, 3.5), num_row=1)
The following error occurred:
Traceback (most recent call last):
  File "/home/ibmc-2/Projects/MNIST/mnist_data/5-2.py", line 46, in <module>
    preds = F.sigmoid(task.predict(samples))
  File "/home/ibmc-2/anaconda3/envs/td/lib/python3.8/site-packages/torchdrug-0.1.0-py3.8.egg/torchdrug/tasks/property_prediction.py", line 104, in predict
    graph = batch["graph"]
KeyError: 'graph'
I'm just getting started, so this may be a silly question. Thank you!
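The KeyError can be reproduced in plain Python. This sketch uses hypothetical stand-in dicts (string placeholders instead of molecule graphs, and made-up label keys task_a and task_b). Because pop() is called on the same dict that is later appended, none of the selected samples keep their "graph" entry, which is why property_prediction.py fails when it reads batch["graph"].

```python
def make_dataset():
    # Hypothetical stand-in for a dataset: each item is a dict holding a
    # "graph" entry plus label entries (placeholders, not real molecules).
    return [
        {"graph": "mol_0", "task_a": 1, "task_b": 0},
        {"graph": "mol_1", "task_a": 1, "task_b": 0},
        {"graph": "mol_2", "task_a": 0, "task_b": 1},
    ]


def pick_one_per_category(dataset):
    samples = []
    categories = set()
    for sample in dataset:
        sample.pop("graph")  # mutates the very dict that is appended below
        category = tuple(sample.values())
        if category not in categories:
            categories.add(category)
            samples.append(sample)
    return samples


samples = pick_one_per_category(make_dataset())
print(samples)

# Any later code that reads sample["graph"] now fails, like task.predict does.
try:
    samples[0]["graph"]
except KeyError as err:
    print("KeyError:", err)
```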
Thank you. The fix is to copy each "sample" dictionary into "sample_1" and call sample_1.pop("graph") on the copy only, to build the category key, while appending the complete "sample", which still holds the graph, to "samples", the list that gets predicted:
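for sample in valid_set:
    sample_1 = sample.copy()  # shallow copy; popping from it leaves "sample" intact
    sample_1.pop("graph")
    category = tuple(sample_1.values())
    if category not in categories:
        categories.add(category)
        samples.append(sample)  # the original dict, which still has "graph"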
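An alternative that avoids the copy altogether is to build the category key from every field except "graph", so the sample dict is never modified. This is a sketch with the same hypothetical stand-in data (placeholder strings and made-up keys task_a and task_b), not TorchDrug code.

```python
def pick_one_per_category(dataset):
    """Keep the first sample seen for each distinct label tuple.

    The key is built from every field except "graph", so each sample
    dict is left unmodified and keeps its graph for prediction.
    """
    samples = []
    categories = set()
    for sample in dataset:
        category = tuple(v for k, v in sample.items() if k != "graph")
        if category not in categories:
            categories.add(category)
            samples.append(sample)
    return samples


# Hypothetical stand-in data: strings instead of molecule graphs.
dataset = [
    {"graph": "mol_0", "task_a": 1, "task_b": 0},
    {"graph": "mol_1", "task_a": 1, "task_b": 0},
    {"graph": "mol_2", "task_a": 0, "task_b": 1},
]
samples = pick_one_per_category(dataset)
print([s["graph"] for s in samples])  # ['mol_0', 'mol_2']
```

Like the copy-based fix, this assumes every sample lists its label keys in the same order and that the label values are hashable scalars (ints or floats), since the tuple is used as a set member.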