use std::ops::Deref; use std::sync::Arc; use iced::futures::StreamExt; use reqwest::Client; use serde::{Deserialize, Serialize}; use tokio::io::AsyncBufReadExt; use tokio::sync::RwLock; use tokio_stream::wrappers::LinesStream; use tokio_stream::Stream; use tokio_util::io::StreamReader; pub struct OllamaAPI { pub api_url: String, pub chats: Vec>>, } #[derive(Serialize)] pub struct OllamaChat { #[serde(skip_serializing)] client: Client, #[serde(skip_serializing)] api_url: String, #[serde(skip_serializing)] pub history_size: u32, pub model: String, pub messages: Vec, #[serde(skip_serializing)] pub generating_message: Option, } #[derive(Serialize, Deserialize, Clone)] pub enum Role { #[serde(rename = "system")] System, #[serde(rename = "user")] User, #[serde(rename = "assistant")] Assistant, } #[derive(Serialize, Deserialize, Clone)] pub struct ChatMessage { pub role: Role, pub content: String, } #[derive(Deserialize)] struct OllamaResponse { message: ChatMessage, } fn deserialize_chunk(chunk: Result) -> anyhow::Result { match chunk { Err(err) => Err(anyhow::Error::new(err)), Ok(str) => { let response = serde_json::from_str::(&str)?; Ok(response.message.content) } } } impl OllamaAPI { pub fn create(api_url: &str) -> anyhow::Result { Ok(Self { api_url: String::from(api_url), chats: Vec::new(), }) } pub fn create_chat(&mut self, model: &str) -> anyhow::Result>> { let client = Client::builder().build()?; let chat = Arc::new(RwLock::new(OllamaChat { history_size: 0, client, api_url: self.api_url.clone(), model: String::from(model), messages: Vec::new(), generating_message: None, })); self.chats.push(chat.clone()); Ok(chat) } } impl OllamaChat { pub fn send_message(&mut self, role: Role, content: &str) { self.messages.push(ChatMessage { role, content: String::from(content), }) } pub fn send_system(&mut self, content: &str) { self.send_message(Role::System, content); } pub fn send_user(&mut self, content: &str) { self.send_message(Role::User, content); } pub fn send_assistant(&mut self, content: &str) { self.send_message(Role::Assistant, content); } pub async fn complete( chat: Arc>, ) -> anyhow::Result>> { { let mut chat = chat.write().await; let body = serde_json::to_string(chat.deref())?; println!("Sending: {}", body); chat.generating_message = Some(String::new()); } let chat = chat.read().await; let request = chat .client .post(chat.api_url.clone() + "/api/chat") .body(serde_json::to_string(chat.deref())?) .build()?; let stream = chat .client .execute(request) .await? .bytes_stream() .map(|x| match x { Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)), Ok(bytes) => Ok(bytes), }); let lines = StreamReader::new(stream).lines(); Ok(LinesStream::new(lines).map(deserialize_chunk)) } }