Skip to content

Commit

Permalink
refine generate (#1562)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?



### Type of change

- [x] Refactoring
  • Loading branch information
KevinHuSh authored Jul 17, 2024
1 parent b06957e commit 9bf6f7c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
5 changes: 5 additions & 0 deletions graph/component/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,11 @@ def get_input(self):
if DEBUG: print(self.component_name, reversed_cpnts[::-1])
for u in reversed_cpnts[::-1]:
if self.get_component_name(u) in ["switch"]: continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
if o is not None:
upstream_outs.append(o)
continue
if u not in self._canvas.get_component(self._id)["upstream"]: continue
if self.component_name.lower().find("switch") < 0 \
and self.get_component_name(u) in ["relevant", "categorize"]:
Expand Down
2 changes: 1 addition & 1 deletion graph/component/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _run(self, history, **kwargs):
prompt = self._param.prompt

retrieval_res = self.get_input()
input = "\n- ".join(retrieval_res["content"])
input = "\n- ".join(retrieval_res["content"]) if "content" in retrieval_res else ""
for para in self._param.parameters:
cpn = self._canvas.get_component(para["component_id"])["obj"]
_, out = cpn.output(allow_partial=False)
Expand Down

0 comments on commit 9bf6f7c

Please sign in to comment.