AIGUI/src/api.rs

141 lines
3.5 KiB
Rust

use std::ops::Deref;
use std::sync::Arc;
use iced::futures::StreamExt;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncBufReadExt;
use tokio::sync::RwLock;
use tokio_stream::wrappers::LinesStream;
use tokio_stream::Stream;
use tokio_util::io::StreamReader;
/// Entry point for an Ollama-style HTTP endpoint.
///
/// Stores the base URL used for new chats together with a shared handle to
/// every chat created through [`OllamaAPI::create_chat`].
pub struct OllamaAPI {
    /// Base URL of the server; copied into each chat when it is created.
    pub api_url: String,
    /// Shared, lock-protected handles to all chats created by this instance.
    pub chats: Vec<Arc<RwLock<OllamaChat>>>,
}
/// State of a single conversation.
///
/// Serializing this struct yields the JSON request body: only `model` and
/// `messages` are emitted, every other field is skipped.
#[derive(Serialize)]
pub struct OllamaChat {
    /// HTTP client used to perform requests for this chat.
    #[serde(skip)]
    client: Client,
    /// Base URL of the server; `/api/chat` is appended when sending.
    #[serde(skip)]
    api_url: String,
    /// Never sent to the server.
    // NOTE(review): presumably a cap on retained history, but nothing in the
    // code visible here enforces it — confirm against callers.
    #[serde(skip)]
    pub history_size: u32,
    /// Name of the model, sent as part of the request body.
    pub model: String,
    /// Conversation history, sent as part of the request body.
    pub messages: Vec<ChatMessage>,
    /// Set to `Some(String::new())` when a completion is started; never sent
    /// to the server.
    // NOTE(review): presumably callers accumulate the partial reply here — confirm.
    #[serde(skip)]
    pub generating_message: Option<String>,
}
/// Author of a chat message.
///
/// Variants are (de)serialized as the lowercase strings `"system"`, `"user"`
/// and `"assistant"`.
#[derive(Clone, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
}
/// One entry of a conversation: who wrote it and what was written.
#[derive(Clone, Deserialize, Serialize)]
pub struct ChatMessage {
    /// Author of the message.
    pub role: Role,
    /// Text of the message.
    pub content: String,
}
/// Shape of one streamed response line; only the `message` field is read,
/// any other fields in the JSON object are ignored.
#[derive(Deserialize)]
struct OllamaResponse {
    /// The message fragment carried by this line.
    message: ChatMessage,
}
/// Turns one line of the streamed response into the text it carries.
///
/// # Errors
///
/// Returns the I/O error if reading the line failed, or the JSON error if the
/// line cannot be parsed as an [`OllamaResponse`].
fn deserialize_chunk(chunk: Result<String, std::io::Error>) -> anyhow::Result<String> {
    let line = chunk?;
    let parsed: OllamaResponse = serde_json::from_str(&line)?;
    Ok(parsed.message.content)
}
impl OllamaAPI {
    /// Builds an API handle for the server at `api_url`, with no chats yet.
    ///
    /// The current implementation always returns `Ok`.
    pub fn create(api_url: &str) -> anyhow::Result<Self> {
        let api = Self {
            api_url: api_url.to_owned(),
            chats: Vec::new(),
        };
        Ok(api)
    }

    /// Creates an empty chat for `model`, registers it in `self.chats` and
    /// returns a shared handle to it.
    ///
    /// # Errors
    ///
    /// Fails if the HTTP client for the chat cannot be built.
    pub fn create_chat(&mut self, model: &str) -> anyhow::Result<Arc<RwLock<OllamaChat>>> {
        let state = OllamaChat {
            client: Client::builder().build()?,
            api_url: self.api_url.clone(),
            history_size: 0,
            model: model.to_owned(),
            messages: Vec::new(),
            generating_message: None,
        };
        let handle = Arc::new(RwLock::new(state));
        // Keep one reference in the registry and hand the other to the caller.
        self.chats.push(Arc::clone(&handle));
        Ok(handle)
    }
}
impl OllamaChat {
pub fn send_message(&mut self, role: Role, content: &str) {
self.messages.push(ChatMessage {
role,
content: String::from(content),
})
}
pub fn send_system(&mut self, content: &str) {
self.send_message(Role::System, content);
}
pub fn send_user(&mut self, content: &str) {
self.send_message(Role::User, content);
}
pub fn send_assistant(&mut self, content: &str) {
self.send_message(Role::Assistant, content);
}
pub async fn complete(
chat: Arc<RwLock<OllamaChat>>,
) -> anyhow::Result<impl Stream<Item = anyhow::Result<String>>> {
{
let mut chat = chat.write().await;
let body = serde_json::to_string(chat.deref())?;
println!("Sending: {}", body);
chat.generating_message = Some(String::new());
}
let chat = chat.read().await;
let request = chat
.client
.post(chat.api_url.clone() + "/api/chat")
.body(serde_json::to_string(chat.deref())?)
.build()?;
let stream = chat
.client
.execute(request)
.await?
.bytes_stream()
.map(|x| match x {
Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)),
Ok(bytes) => Ok(bytes),
});
let lines = StreamReader::new(stream).lines();
Ok(LinesStream::new(lines).map(deserialize_chunk))
}
}