343 lines
12 KiB
Rust
343 lines
12 KiB
Rust
mod api;
|
|
|
|
use std::{
|
|
cell::RefCell,
|
|
io::{self, Write},
|
|
sync::{
|
|
mpsc::{Receiver, Sender},
|
|
Arc,
|
|
},
|
|
};
|
|
|
|
use api::{ChatMessage, OllamaAPI};
|
|
use iced::{
|
|
alignment::{Horizontal, Vertical},
|
|
futures::StreamExt,
|
|
subscription,
|
|
widget::{button, column, container, row, scrollable, text, text_input},
|
|
window, Application, Color, Command, Event, Length, Settings, Subscription, Theme,
|
|
};
|
|
use tokio::sync::RwLock;
|
|
|
|
use crate::api::{OllamaChat, Role};
|
|
|
|
enum UI {
|
|
Error(String),
|
|
Chats(UIState),
|
|
Closing,
|
|
}
|
|
|
|
struct UIState {
|
|
ollama_api: Arc<RwLock<OllamaAPI>>,
|
|
active_chat: Option<usize>,
|
|
chat_input: String,
|
|
send: Sender<UIMessage>,
|
|
recv: RefCell<Option<Receiver<UIMessage>>>,
|
|
}
|
|
|
|
impl UIState {
|
|
fn create(
|
|
api_url: &str,
|
|
send: Sender<UIMessage>,
|
|
recv: Receiver<UIMessage>,
|
|
) -> anyhow::Result<Self> {
|
|
Ok(Self {
|
|
ollama_api: Arc::new(RwLock::new(OllamaAPI::create(api_url)?)),
|
|
active_chat: None,
|
|
chat_input: String::new(),
|
|
send,
|
|
recv: RefCell::new(Some(recv)),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
enum UIMessage {
|
|
Nop,
|
|
EventOccurred(Event),
|
|
CreateChat,
|
|
OpenChat(usize),
|
|
ChatInput(String),
|
|
ChatSubmit,
|
|
DoneGenerating,
|
|
}
|
|
|
|
fn main() -> anyhow::Result<()> {
|
|
let (send, recv) = std::sync::mpsc::channel();
|
|
let mut settings = Settings::with_flags(UIFlags { send, recv });
|
|
settings.exit_on_close_request = false;
|
|
settings.window = window::Settings {
|
|
size: (720, 480),
|
|
..window::Settings::default()
|
|
};
|
|
|
|
UI::run(settings)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
struct UIFlags {
|
|
send: Sender<UIMessage>,
|
|
recv: Receiver<UIMessage>,
|
|
}
|
|
|
|
impl Application for UI {
|
|
type Executor = iced::executor::Default;
|
|
type Message = UIMessage;
|
|
type Theme = Theme;
|
|
type Flags = UIFlags;
|
|
|
|
fn new(flags: Self::Flags) -> (Self, iced::Command<Self::Message>) {
|
|
let state: UI;
|
|
let state_result = UIState::create("http://localhost:11434", flags.send, flags.recv);
|
|
if let Ok(s) = state_result {
|
|
state = UI::Chats(s);
|
|
} else {
|
|
state = UI::Error(String::from("Failed to initialize Ollama API"));
|
|
}
|
|
|
|
(state, Command::none())
|
|
}
|
|
|
|
fn title(&self) -> String {
|
|
String::from("AIGUI")
|
|
}
|
|
|
|
fn update(&mut self, message: Self::Message) -> iced::Command<Self::Message> {
|
|
if let UIMessage::EventOccurred(event) = &message {
|
|
if let Event::Window(window::Event::CloseRequested) = event {
|
|
*self = UI::Closing;
|
|
return window::close();
|
|
} else {
|
|
return Command::none();
|
|
}
|
|
}
|
|
|
|
match self {
|
|
UI::Chats(state) => match message {
|
|
UIMessage::CreateChat => {
|
|
let chat = state
|
|
.ollama_api
|
|
.blocking_write()
|
|
.create_chat("dolphin-mixtral");
|
|
if let Ok(chat) = chat {
|
|
chat.blocking_write().send_system("You write stories about various topics.\n\nYou can include images into these stories using the following syntax:\n{IMAGE:description of image}\n\ne.g.:\n{IMAGE:an image of a large church, volumetric lighting, masterpiece, best quality}\n\nMake sure to use at least one image every five sentences.");
|
|
}
|
|
Command::none()
|
|
}
|
|
UIMessage::OpenChat(index) => {
|
|
println!("Open chat {}", index);
|
|
state.active_chat = Some(index);
|
|
Command::none()
|
|
}
|
|
UIMessage::ChatInput(text) => {
|
|
println!("Input: {}", text);
|
|
state.chat_input = text;
|
|
Command::none()
|
|
}
|
|
UIMessage::ChatSubmit => {
|
|
println!("Submit");
|
|
|
|
if let Some(index) = state.active_chat {
|
|
let content = state.chat_input.clone();
|
|
state.chat_input = String::new();
|
|
let send = state.send.clone();
|
|
let ollama_api = state.ollama_api.clone();
|
|
return Command::perform(
|
|
async move {
|
|
/* TODO: Stable diffusion
|
|
let sd =
|
|
StableDiffusionAPI::create("http://localhost:7860").unwrap();
|
|
let req = StableDiffusionRequest {
|
|
prompt: String::from("among us"),
|
|
steps: 10,
|
|
};
|
|
let img = sd.generate_image(&req).await.unwrap();
|
|
println!("{:?}", img);
|
|
std::fs::write("amogus.png", img);*/
|
|
|
|
let stream;
|
|
let chat;
|
|
{
|
|
{
|
|
let api = ollama_api.read().await;
|
|
chat = api.chats[index].clone();
|
|
let mut chat = chat.write().await;
|
|
chat.send_user(&content);
|
|
chat.generating_message = Some(String::new());
|
|
}
|
|
|
|
//let mut chat = chat.write().await;
|
|
stream = OllamaChat::complete(chat.clone()).await;
|
|
}
|
|
|
|
if let Ok(mut stream) = stream {
|
|
while let Some(content) = stream.next().await {
|
|
if content.is_err() {
|
|
break;
|
|
}
|
|
|
|
let content = content.unwrap();
|
|
|
|
let chat_arc = chat.clone();
|
|
let mut chat = chat_arc.write().await;
|
|
let msg = &mut chat.generating_message;
|
|
if let Some(msg) = msg {
|
|
msg.push_str(&content);
|
|
}
|
|
|
|
print!("{}", content);
|
|
let _ = io::stdout().flush();
|
|
|
|
let _ = send.send(UIMessage::Nop);
|
|
}
|
|
|
|
let mut chat = chat.write().await;
|
|
let msg = &chat.generating_message.clone();
|
|
if let Some(msg) = msg {
|
|
chat.send_assistant(&msg);
|
|
chat.generating_message = None;
|
|
}
|
|
}
|
|
},
|
|
|_| UIMessage::DoneGenerating,
|
|
);
|
|
}
|
|
|
|
Command::none()
|
|
}
|
|
UIMessage::DoneGenerating => Command::none(),
|
|
_ => Command::none(),
|
|
},
|
|
UI::Error(_) => Command::none(),
|
|
UI::Closing => Command::none(),
|
|
}
|
|
}
|
|
|
|
fn view(&self) -> iced::Element<'_, Self::Message, iced::Renderer<Self::Theme>> {
|
|
match self {
|
|
UI::Error(message) => text(String::from("Error: ") + message)
|
|
.width(Length::Fill)
|
|
.height(Length::Fill)
|
|
.vertical_alignment(Vertical::Center)
|
|
.horizontal_alignment(Horizontal::Center)
|
|
.into(),
|
|
UI::Chats(state) => {
|
|
let api = state.ollama_api.blocking_read();
|
|
let mut chats: Vec<iced::Element<'_, Self::Message, iced::Renderer<Self::Theme>>> =
|
|
api.chats
|
|
.iter()
|
|
.enumerate()
|
|
.map(|c| {
|
|
button("Chat")
|
|
.on_press(UIMessage::OpenChat(c.0))
|
|
.width(Length::Fill)
|
|
.into()
|
|
})
|
|
.collect();
|
|
|
|
chats.insert(
|
|
0,
|
|
button("Create Chat")
|
|
.on_press(UIMessage::CreateChat)
|
|
.width(Length::Fill)
|
|
.into(),
|
|
);
|
|
|
|
let chats_list = scrollable(container(column(chats).padding(10.0).spacing(10)))
|
|
.height(Length::Fill)
|
|
.width(Length::Fixed(300.0));
|
|
|
|
let mut messages: Vec<
|
|
iced::Element<'_, Self::Message, iced::Renderer<Self::Theme>>,
|
|
>;
|
|
let mut input_box = text_input("Your message", &state.chat_input);
|
|
if let Some(chat_index) = state.active_chat {
|
|
let chat = &api.chats[chat_index].blocking_read();
|
|
messages = chat
|
|
.messages
|
|
.iter()
|
|
.map(|x| self.render_message(x))
|
|
.collect();
|
|
|
|
if let Some(msg) = &chat.generating_message {
|
|
let mut msg = msg.clone();
|
|
msg.push_str("...");
|
|
let in_progress_message = ChatMessage {
|
|
role: Role::Assistant,
|
|
content: msg,
|
|
};
|
|
messages.push(self.render_message(&in_progress_message));
|
|
} else {
|
|
input_box = input_box
|
|
.on_input(|x| UIMessage::ChatInput(x))
|
|
.on_submit(UIMessage::ChatSubmit)
|
|
}
|
|
} else {
|
|
messages = Vec::new();
|
|
}
|
|
|
|
let active_chat = container(
|
|
column![
|
|
scrollable(container(column(messages).spacing(10)))
|
|
.width(Length::Fill)
|
|
.height(Length::Fill),
|
|
input_box
|
|
]
|
|
.padding(10)
|
|
.spacing(10),
|
|
)
|
|
.width(Length::Fill)
|
|
.height(Length::Fill);
|
|
|
|
container(row![chats_list, active_chat]).into()
|
|
}
|
|
UI::Closing => text("Closing").into(),
|
|
}
|
|
}
|
|
|
|
fn theme(&self) -> Self::Theme {
|
|
Theme::Dark
|
|
}
|
|
|
|
fn subscription(&self) -> iced::Subscription<Self::Message> {
|
|
let event_subscription = subscription::events().map(UIMessage::EventOccurred);
|
|
match self {
|
|
UI::Chats(state) => {
|
|
let recv_subscription = iced::subscription::unfold(
|
|
"external messages",
|
|
state.recv.take(),
|
|
move |recv| async move {
|
|
let msg = recv.as_ref().unwrap().recv();
|
|
|
|
if let Err(_) = msg {
|
|
return (UIMessage::Nop, recv);
|
|
}
|
|
|
|
let msg = msg.unwrap();
|
|
(msg, recv)
|
|
},
|
|
);
|
|
|
|
Subscription::batch(vec![event_subscription, recv_subscription])
|
|
}
|
|
_ => event_subscription,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl UI {
|
|
fn render_message(
|
|
&self,
|
|
message: &ChatMessage,
|
|
) -> iced::Element<'_, UIMessage, iced::Renderer<Theme>> {
|
|
let (role_name, role_color) = match &message.role {
|
|
Role::User => ("User", Color::from_rgb8(128, 128, 128)),
|
|
Role::Assistant => ("Assistant", Color::from_rgb8(100, 100, 255)),
|
|
Role::System => ("System", Color::from_rgb8(255, 0, 0)),
|
|
};
|
|
|
|
column![text(role_name).style(role_color), text(&message.content)].into()
|
|
}
|
|
}
|