|
| 1 | +import Control.Monad.IO.Class (liftIO) |
| 2 | +import System.Environment (getArgs, getEnv) |
| 3 | +import qualified Data.Aeson as Aeson |
| 4 | +import Data.Aeson (FromJSON, ToJSON) |
| 5 | +import GHC.Generics |
| 6 | +import Network.HTTP.Client.TLS (tlsManagerSettings) |
| 7 | +import Network.HTTP.Client (newManager, httpLbs, parseRequest, Request(..), RequestBody(..), responseBody, responseStatus) |
| 8 | +import Network.HTTP.Types.Status (statusCode) |
| 9 | +import qualified Data.Text as T |
| 10 | +import Data.Text.Encoding (encodeUtf8) |
| 11 | +import qualified Data.Vector as V |
| 12 | + |
| 13 | +data GeminiRequest = GeminiRequest |
| 14 | + { prompt :: String |
| 15 | + } deriving (Show, Generic, ToJSON) |
| 16 | + |
| 17 | +data GeminiResponse = GeminiResponse |
| 18 | + { candidates :: [Candidate] -- Changed from choices to candidates |
| 19 | + } deriving (Show, Generic, FromJSON) |
| 20 | + |
| 21 | +data Candidate = Candidate |
| 22 | + { content :: Content |
| 23 | + } deriving (Show, Generic, FromJSON) |
| 24 | + |
| 25 | +data Content = Content |
| 26 | + { parts :: [Part] |
| 27 | + } deriving (Show, Generic, FromJSON) |
| 28 | + |
| 29 | +data Part = Part |
| 30 | + { text :: String |
| 31 | + } deriving (Show, Generic, FromJSON, ToJSON) |
| 32 | + |
| 33 | +data PromptFeedback = PromptFeedback |
| 34 | + { blockReason :: Maybe String |
| 35 | + , safetyRatings :: Maybe [SafetyRating] |
| 36 | + } deriving (Show, Generic, FromJSON, ToJSON) |
| 37 | + |
| 38 | +data SafetyRating = SafetyRating |
| 39 | + { category :: String |
| 40 | + , probability :: String |
| 41 | + } deriving (Show, Generic, FromJSON, ToJSON) |
| 42 | + |
| 43 | +main :: IO () |
| 44 | +main = do |
| 45 | + args <- getArgs |
| 46 | + case args of |
| 47 | + [] -> putStrLn "Error: Please provide a prompt as a command-line argument." |
| 48 | + (arg:_) -> do -- Extract the argument directly |
| 49 | + apiKey <- getEnv "GOOGLE_API_KEY" |
| 50 | + |
| 51 | + manager <- newManager tlsManagerSettings |
| 52 | + |
| 53 | + initialRequest <- parseRequest "https://generativelanguage.googleapis.com/v1/models/gemini-1.5-pro:generateContent" |
| 54 | + |
| 55 | + let geminiRequestBody = Aeson.object [ |
| 56 | + ("contents", Aeson.Array $ V.singleton $ Aeson.object [ |
| 57 | + ("parts", Aeson.Array $ V.singleton $ Aeson.object [ |
| 58 | + ("text", Aeson.String $ T.pack arg) |
| 59 | + ]) |
| 60 | + ]), |
| 61 | + ("generationConfig", Aeson.object [ |
| 62 | + ("temperature", Aeson.Number 0.1), |
| 63 | + ("maxOutputTokens", Aeson.Number 800) |
| 64 | + ]) |
| 65 | + ] |
| 66 | + |
| 67 | + let request = initialRequest |
| 68 | + { requestHeaders = |
| 69 | + [ ("Content-Type", "application/json") |
| 70 | + , ("x-goog-api-key", encodeUtf8 $ T.pack apiKey) |
| 71 | + ] |
| 72 | + , method = "POST" |
| 73 | + , requestBody = RequestBodyLBS $ Aeson.encode geminiRequestBody |
| 74 | + } |
| 75 | + |
| 76 | + response <- httpLbs request manager |
| 77 | + |
| 78 | + let responseStatus' = responseStatus response |
| 79 | + |
| 80 | + if statusCode responseStatus' == 200 |
| 81 | + then do |
| 82 | + let maybeGeminiResponse = Aeson.decode (responseBody response) :: Maybe GeminiResponse |
| 83 | + case maybeGeminiResponse of |
| 84 | + Just geminiResponse -> do |
| 85 | + case candidates geminiResponse of |
| 86 | + (candidate:_) -> do |
| 87 | + case parts (content candidate) of |
| 88 | + (part:_) -> do -- Changed text_ to _ since it's unused |
| 89 | + liftIO $ putStrLn $ "Response:\n\n" ++ text part |
| 90 | + [] -> do |
| 91 | + liftIO $ putStrLn "Error: No parts in content" |
| 92 | + [] -> do |
| 93 | + liftIO $ putStrLn "Error: No candidates in response" |
| 94 | + Nothing -> do |
| 95 | + liftIO $ putStrLn "Error: Failed to parse response" |
| 96 | + else do |
| 97 | + putStrLn $ "Error: " ++ show responseStatus' |
0 commit comments