Keep system messages in history

This commit is contained in:
MrLetsplay 2024-01-21 11:23:03 +01:00
parent 215e818f11
commit aa12b28fa7
Signed by: mr
SSH Key Fingerprint: SHA256:92jBH80vpXyaZHjaIl47pjRq+Yt7XGTArqQg1V7hSqg
2 changed files with 24 additions and 5 deletions

View File

@ -8,11 +8,20 @@ public class Ollama {
public static void main(String[] args) throws IOException, InterruptedException { public static void main(String[] args) throws IOException, InterruptedException {
OllamaAPI api = new OllamaAPI("http://localhost:11434"); OllamaAPI api = new OllamaAPI("http://localhost:11434");
// String systemPrompt = """
// You are AmogusBot, a bot based on the video game \"Among Us\". You don't talk about anything other than Among Us and constantly make references to the game Among Us.
// """;
String systemPrompt = """ String systemPrompt = """
You are AmogusBot, a bot based on the video game \"Among Us\". You don't talk about anything other than Among Us and constantly make references to the game Among Us. You are a bot that tells interactive stories. Give the user a choice they can make to advance the story.
Write about 5 sentences of story per choice and then give three choices like
1) Choice 1
2) Choice 2
3) Choice 3
"""; """;
Chat chat = api.startChat("dolphin-mixtral", systemPrompt); Chat chat = api.startChat("dolphin-mixtral", systemPrompt);
chat.setHistorySize(2);
try(Scanner s = new Scanner(System.in)) { try(Scanner s = new Scanner(System.in)) {
while(true) { while(true) {

View File

@ -69,10 +69,20 @@ public class Chat {
JSONObject obj = new JSONObject(); JSONObject obj = new JSONObject();
obj.put("model", model); obj.put("model", model);
JSONArray messages = new JSONArray(); List<JSONObject> messages = new ArrayList<>();
// TODO: keep system messages int count = 0;
this.messages.stream().skip(historySize == -1 ? 0 : Math.max(0, messages.size() - historySize)).forEach(m -> messages.add(m.toJSON())); for(int i = this.messages.size() - 1; i >= 0; i--) {
obj.put("messages", messages); ChatMessage message = this.messages.get(i);
if(message.role() == Role.SYSTEM) {
messages.add(0, message.toJSON());
continue;
}
if(historySize != -1 && count >= historySize) continue;
messages.add(0, message.toJSON());
count++;
}
obj.put("messages", new JSONArray(messages));
// JSONObject options = new JSONObject(); // JSONObject options = new JSONObject();
// options.put("num_thread", 16); // options.put("num_thread", 16);