Skip to content

Commit

Permalink
Merge pull request #15 from VowpalWabbit/move_things_around
Browse files Browse the repository at this point in the history
Move everything into langchain_experimental
  • Loading branch information
olgavrou authored Sep 11, 2023
2 parents 3d700aa + 32445de commit 3a299b9
Show file tree
Hide file tree
Showing 17 changed files with 897 additions and 82 deletions.
30 changes: 19 additions & 11 deletions docs/extras/modules/chains/how_to/learned_prompt_optimization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,17 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install necessary packages\n",
"# ! pip install langchain langchain-experimental matplotlib"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -33,14 +43,14 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# pick and configure the LLM of your choice\n",
"\n",
"from langchain.llms import OpenAI\n",
"llm = OpenAI(engine=\"text-davinci-003\")\n"
"llm = OpenAI(engine=\"text-davinci-003\")"
]
},
{
Expand Down Expand Up @@ -93,7 +103,7 @@
"metadata": {},
"outputs": [],
"source": [
"import langchain.chains.rl_chain as rl_chain\n",
"import langchain_experimental.rl_chain as rl_chain\n",
"\n",
"chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)\n"
]
Expand Down Expand Up @@ -466,12 +476,10 @@
}
],
"source": [
"# note matplotlib is not a dependency of langchain so you need to install to plot\n",
"\n",
"# from matplotlib import pyplot as plt\n",
"# chain.metrics.to_pandas()['score'].plot(label=\"default learning policy\")\n",
"# random_chain.metrics.to_pandas()['score'].plot(label=\"random selection policy\")\n",
"# plt.legend()\n",
"from matplotlib import pyplot as plt\n",
"chain.metrics.to_pandas()['score'].plot(label=\"default learning policy\")\n",
"random_chain.metrics.to_pandas()['score'].plot(label=\"random selection policy\")\n",
"plt.legend()\n",
"\n",
"print(f\"The final average score for the default policy, calculated over a rolling window, is: {chain.metrics.to_pandas()['score'].iloc[-1]}\")\n",
"print(f\"The final average score for the random policy, calculated over a rolling window, is: {random_chain.metrics.to_pandas()['score'].iloc[-1]}\")"
Expand Down Expand Up @@ -816,7 +824,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.9.17"
},
"orig_nbformat": 4
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from langchain.chains.rl_chain.base import (
from langchain_experimental.rl_chain.base import (
AutoSelectionScorer,
BasedOn,
Embed,
Expand All @@ -12,7 +12,7 @@
embed,
stringify_embedding,
)
from langchain.chains.rl_chain.pick_best_chain import (
from langchain_experimental.rl_chain.pick_best_chain import (
PickBest,
PickBestEvent,
PickBestFeatureEmbedder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.rl_chain.metrics import (
MetricsTrackerAverage,
MetricsTrackerRollingWindow,
)
from langchain.chains.rl_chain.model_repository import ModelRepository
from langchain.chains.rl_chain.vw_logger import VwLogger
from langchain.prompts import (
BasePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.pydantic_v1 import BaseModel, Extra, root_validator

from langchain_experimental.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_experimental.rl_chain.metrics import (
MetricsTrackerAverage,
MetricsTrackerRollingWindow,
)
from langchain_experimental.rl_chain.model_repository import ModelRepository
from langchain_experimental.rl_chain.vw_logger import VwLogger

if TYPE_CHECKING:
import vowpal_wabbit_next as vw
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import logging
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import langchain.chains.rl_chain.base as base
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.llm import LLMChain
from langchain.prompts import BasePromptTemplate

import langchain_experimental.rl_chain.base as base

logger = logging.getLogger(__name__)

# sentinel object used to distinguish between
Expand Down
Loading

0 comments on commit 3a299b9

Please sign in to comment.