141 lines
3.5 KiB
Rust
141 lines
3.5 KiB
Rust
use std::ops::Deref;
|
|
use std::sync::Arc;
|
|
|
|
use iced::futures::StreamExt;
|
|
use reqwest::Client;
|
|
use serde::{Deserialize, Serialize};
|
|
use tokio::io::AsyncBufReadExt;
|
|
use tokio::sync::RwLock;
|
|
use tokio_stream::wrappers::LinesStream;
|
|
use tokio_stream::Stream;
|
|
use tokio_util::io::StreamReader;
|
|
|
|
pub struct OllamaAPI {
|
|
pub api_url: String,
|
|
pub chats: Vec<Arc<RwLock<OllamaChat>>>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
pub struct OllamaChat {
|
|
#[serde(skip_serializing)]
|
|
client: Client,
|
|
#[serde(skip_serializing)]
|
|
api_url: String,
|
|
|
|
#[serde(skip_serializing)]
|
|
pub history_size: u32,
|
|
|
|
pub model: String,
|
|
pub messages: Vec<ChatMessage>,
|
|
|
|
#[serde(skip_serializing)]
|
|
pub generating_message: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone)]
|
|
pub enum Role {
|
|
#[serde(rename = "system")]
|
|
System,
|
|
#[serde(rename = "user")]
|
|
User,
|
|
#[serde(rename = "assistant")]
|
|
Assistant,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone)]
|
|
pub struct ChatMessage {
|
|
pub role: Role,
|
|
pub content: String,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct OllamaResponse {
|
|
message: ChatMessage,
|
|
}
|
|
|
|
fn deserialize_chunk(chunk: Result<String, std::io::Error>) -> anyhow::Result<String> {
|
|
match chunk {
|
|
Err(err) => Err(anyhow::Error::new(err)),
|
|
Ok(str) => {
|
|
let response = serde_json::from_str::<OllamaResponse>(&str)?;
|
|
Ok(response.message.content)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl OllamaAPI {
|
|
pub fn create(api_url: &str) -> anyhow::Result<Self> {
|
|
Ok(Self {
|
|
api_url: String::from(api_url),
|
|
chats: Vec::new(),
|
|
})
|
|
}
|
|
|
|
pub fn create_chat(&mut self, model: &str) -> anyhow::Result<Arc<RwLock<OllamaChat>>> {
|
|
let client = Client::builder().build()?;
|
|
let chat = Arc::new(RwLock::new(OllamaChat {
|
|
history_size: 0,
|
|
client,
|
|
api_url: self.api_url.clone(),
|
|
model: String::from(model),
|
|
messages: Vec::new(),
|
|
generating_message: None,
|
|
}));
|
|
|
|
self.chats.push(chat.clone());
|
|
Ok(chat)
|
|
}
|
|
}
|
|
|
|
impl OllamaChat {
|
|
pub fn send_message(&mut self, role: Role, content: &str) {
|
|
self.messages.push(ChatMessage {
|
|
role,
|
|
content: String::from(content),
|
|
})
|
|
}
|
|
|
|
pub fn send_system(&mut self, content: &str) {
|
|
self.send_message(Role::System, content);
|
|
}
|
|
|
|
pub fn send_user(&mut self, content: &str) {
|
|
self.send_message(Role::User, content);
|
|
}
|
|
|
|
pub fn send_assistant(&mut self, content: &str) {
|
|
self.send_message(Role::Assistant, content);
|
|
}
|
|
|
|
pub async fn complete(
|
|
chat: Arc<RwLock<OllamaChat>>,
|
|
) -> anyhow::Result<impl Stream<Item = anyhow::Result<String>>> {
|
|
{
|
|
let mut chat = chat.write().await;
|
|
let body = serde_json::to_string(chat.deref())?;
|
|
println!("Sending: {}", body);
|
|
|
|
chat.generating_message = Some(String::new());
|
|
}
|
|
|
|
let chat = chat.read().await;
|
|
|
|
let request = chat
|
|
.client
|
|
.post(chat.api_url.clone() + "/api/chat")
|
|
.body(serde_json::to_string(chat.deref())?)
|
|
.build()?;
|
|
let stream = chat
|
|
.client
|
|
.execute(request)
|
|
.await?
|
|
.bytes_stream()
|
|
.map(|x| match x {
|
|
Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)),
|
|
Ok(bytes) => Ok(bytes),
|
|
});
|
|
let lines = StreamReader::new(stream).lines();
|
|
Ok(LinesStream::new(lines).map(deserialize_chunk))
|
|
}
|
|
}
|