From aa12b28fa7656f8a3e98663a79ad2f8821ecf97b Mon Sep 17 00:00:00 2001 From: MrLetsplay Date: Sun, 21 Jan 2024 11:23:03 +0100 Subject: [PATCH] Keep system messages in history --- src/main/java/Ollama.java | 11 ++++++++++- .../java/me/mrletsplay/ollama/api/Chat.java | 18 ++++++++++++++---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/main/java/Ollama.java b/src/main/java/Ollama.java index d55b1f0..57cc639 100644 --- a/src/main/java/Ollama.java +++ b/src/main/java/Ollama.java @@ -8,11 +8,20 @@ public class Ollama { public static void main(String[] args) throws IOException, InterruptedException { OllamaAPI api = new OllamaAPI("http://localhost:11434"); +// String systemPrompt = """ +// You are AmogusBot, a bot based on the video game \"Among Us\". You don't talk about anything other than Among Us and constantly make references to the game Among Us. +// """; + String systemPrompt = """ - You are AmogusBot, a bot based on the video game \"Among Us\". You don't talk about anything other than Among Us and constantly make references to the game Among Us. + You are a bot that tells interactive stories. Give the user a choice they can make to advance the story. + Write about 5 sentences of story per choice and then give three choices like + 1) Choice 1 + 2) Choice 2 + 3) Choice 3 """; Chat chat = api.startChat("dolphin-mixtral", systemPrompt); + chat.setHistorySize(2); try(Scanner s = new Scanner(System.in)) { while(true) { diff --git a/src/main/java/me/mrletsplay/ollama/api/Chat.java b/src/main/java/me/mrletsplay/ollama/api/Chat.java index a5ed2d7..4f617b8 100644 --- a/src/main/java/me/mrletsplay/ollama/api/Chat.java +++ b/src/main/java/me/mrletsplay/ollama/api/Chat.java @@ -69,10 +69,20 @@ public class Chat { JSONObject obj = new JSONObject(); obj.put("model", model); - JSONArray messages = new JSONArray(); - // TODO: keep system messages - this.messages.stream().skip(historySize == -1 ? 0 : Math.max(0, messages.size() - historySize)).forEach(m -> messages.add(m.toJSON())); - obj.put("messages", messages); + List messages = new ArrayList<>(); + int count = 0; + for(int i = this.messages.size() - 1; i >= 0; i--) { + ChatMessage message = this.messages.get(i); + if(message.role() == Role.SYSTEM) { + messages.add(0, message.toJSON()); + continue; + } + + if(historySize != -1 && count >= historySize) continue; + messages.add(0, message.toJSON()); + count++; + } + obj.put("messages", new JSONArray(messages)); // JSONObject options = new JSONObject(); // options.put("num_thread", 16);