102 lines
2.6 KiB
Java

package me.mrletsplay.ollama.api;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpRequest.BodyPublishers;
import java.net.http.HttpResponse.BodyHandlers;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import me.mrletsplay.mrcore.json.JSONException;
import me.mrletsplay.mrcore.json.JSONObject;
/**
 * Minimal client for an Ollama server's chat endpoint.
 */
public class OllamaAPI {

	private final String baseURL;
	private final HttpClient client;

	/**
	 * Creates a client that sends requests to the given server.
	 *
	 * @param baseURL base URL of the server; {@code "/api/chat"} is appended verbatim,
	 *                so it should not end with a slash
	 */
	public OllamaAPI(String baseURL) {
		this.baseURL = baseURL;
		this.client = HttpClient.newHttpClient();
	}

	/**
	 * Starts a new chat with a system prompt, sends a single user message and runs a completion.
	 *
	 * @param model        model name passed to {@link Chat}
	 * @param systemPrompt system message sent before the user message
	 * @param prompt       user message
	 * @return the stream produced by {@link Chat#doCompletion()} (presumably the reply text chunks — see Chat)
	 * @throws IOException          if the HTTP request fails
	 * @throws InterruptedException if interrupted while sending the request
	 */
	public Stream<String> prompt(String model, String systemPrompt, String prompt) throws IOException, InterruptedException {
		Chat c = startChat(model, systemPrompt);
		c.sendUser(prompt);
		return c.doCompletion();
	}

	/**
	 * Starts a new chat without a system prompt, sends a single user message and runs a completion.
	 *
	 * @param model  model name passed to {@link Chat}
	 * @param prompt user message
	 * @return the stream produced by {@link Chat#doCompletion()} (presumably the reply text chunks — see Chat)
	 * @throws IOException          if the HTTP request fails
	 * @throws InterruptedException if interrupted while sending the request
	 */
	public Stream<String> prompt(String model, String prompt) throws IOException, InterruptedException {
		Chat c = startChat(model);
		c.sendUser(prompt);
		return c.doCompletion();
	}

	/**
	 * POSTs the chat (serialized via {@code chat.toJSON()}) to {@code baseURL + "/api/chat"} and
	 * returns the response body as a lazy stream of JSON objects, one per non-blank line.
	 *
	 * <p>The returned stream holds the HTTP response body open; close it (e.g. with
	 * try-with-resources) if it is not consumed to the end, so the connection is released.
	 *
	 * @param chat the chat to complete
	 * @return lazily parsed response objects
	 * @throws IOException          if sending the request fails
	 * @throws InterruptedException if interrupted while sending the request
	 */
	public Stream<JSONObject> prompt(Chat chat) throws IOException, InterruptedException {
		JSONObject request = chat.toJSON();
		HttpRequest httpRequest = HttpRequest.newBuilder(URI.create(baseURL + "/api/chat"))
			.POST(BodyPublishers.ofString(request.toFancyString()))
			.build();
		// NOTE(review): the HTTP status code is not checked; a non-2xx response body is
		// streamed like a normal reply. Confirm whether callers rely on that before changing it.
		InputStream in = client.send(httpRequest, BodyHandlers.ofInputStream()).body();
		return createJSONStream(in);
	}

	/**
	 * Creates a new, empty chat bound to this client.
	 *
	 * @param model model name for the chat
	 * @return a new {@link Chat}
	 */
	public Chat startChat(String model) {
		return new Chat(this, model);
	}

	/**
	 * Creates a new chat bound to this client and sends the given system message to it.
	 *
	 * @param model        model name for the chat
	 * @param systemPrompt system message to send first
	 * @return a new {@link Chat}
	 */
	public Chat startChat(String model, String systemPrompt) {
		Chat chat = startChat(model);
		chat.sendSystem(systemPrompt);
		return chat;
	}

	/**
	 * Wraps a newline-delimited JSON input stream in a lazy, ordered, sequential {@link Stream}.
	 * Blank lines are skipped. The underlying reader is closed on end-of-input and when the
	 * returned stream is closed.
	 *
	 * @param in UTF-8 encoded NDJSON input
	 * @return stream of parsed objects; reading/parsing failures surface as unchecked exceptions
	 */
	private static Stream<JSONObject> createJSONStream(InputStream in) {
		BufferedReader reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8));

		Iterator<JSONObject> it = new Iterator<>() {
			// Object read ahead by hasNext() but not yet handed out by next()
			private JSONObject nextObj;
			// Set once EOF was reached, so the closed reader is never read again
			private boolean finished;

			/** Returns the buffered object, reading one line ahead if needed; null at end of input. */
			private JSONObject tryRead() {
				if(nextObj != null) return nextObj;
				if(finished) return null;
				try {
					String line;
					do {
						line = reader.readLine();
					} while(line != null && line.isBlank());

					if(line == null) {
						finished = true;
						reader.close();
						return null;
					}
					return nextObj = new JSONObject(line);
				} catch (IOException e) {
					throw new UncheckedIOException("Failed to read object", e);
				} catch (JSONException e) {
					throw new RuntimeException("Failed to read object", e);
				}
			}

			@Override
			public boolean hasNext() {
				return tryRead() != null;
			}

			@Override
			public JSONObject next() {
				JSONObject obj = tryRead();
				if(obj == null) throw new NoSuchElementException();
				nextObj = null;
				return obj;
			}
		};

		return StreamSupport.stream(Spliterators.spliteratorUnknownSize(it, Spliterator.ORDERED), false)
			.onClose(() -> {
				// Release the HTTP body if the caller stops early; closing twice is a no-op
				try {
					reader.close();
				} catch (IOException e) {
					throw new UncheckedIOException(e);
				}
			});
	}

}