diff --git a/packages/rag/default.nix b/packages/rag/default.nix index 09d82c5..1335461 100644 --- a/packages/rag/default.nix +++ b/packages/rag/default.nix @@ -16,6 +16,7 @@ writers.writePython3Bin "rag" import hashlib import os import pathlib + import re import sys from openai import OpenAI @@ -128,12 +129,24 @@ writers.writePython3Bin "rag" {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {args.text}"}, ] stream = client.chat.completions.create(model=CHAT_MODEL, messages=messages, stream=True) + answer = [] for ev in stream: delta = ev.choices[0].delta.content if delta: + answer.append(delta) sys.stdout.write(delta) sys.stdout.flush() sys.stdout.write("\n") + # Legend for only the [n] tags actually cited (stderr keeps stdout pipe-clean). + cited = sorted({int(n) for n in re.findall(r"\[(\d+)\]", "".join(answer))}) + cited = [i for i in cited if 1 <= i <= len(hits)] + if cited: + print("\nQuellen:", file=sys.stderr) + for i in cited: + h = hits[i - 1] + src = h.payload.get("source", "?") + chunk = h.payload.get("chunk", "?") + print(f" [{i}] {src}#{chunk} ({h.score:.3f})", file=sys.stderr) def main(): @@ -157,5 +170,15 @@ writers.writePython3Bin "rag" if __name__ == "__main__": - main() + try: + main() + except KeyboardInterrupt: + sys.exit(130) + except BrokenPipeError: + sys.exit(0) + except Exception as e: + if os.environ.get("RAG_DEBUG"): + raise + print(f"rag: {type(e).__name__}: {e}", file=sys.stderr) + sys.exit(1) ''