{ writers, python3Packages }:
# Thin CLI RAG over a shared Qdrant collection. Embeddings + synthesis go
# through the OpenAI-compatible gateway (LiteLLM on sgx -> halo's llama-server).
# Configure via env: RAG_API_BASE, RAG_API_KEY, RAG_QDRANT_URL, RAG_EMBED_MODEL,
# RAG_CHAT_MODEL, RAG_COLLECTION. Set RAG_DEBUG to any non-empty value to get
# a full Python traceback instead of the one-line "rag: ..." error message.
#
# The script below lives in a Nix indented string: it must never contain two
# consecutive single quotes (ends the string) or a dollar sign directly
# followed by an opening brace (Nix interpolation).
writers.writePython3Bin "rag" {
  libraries = [
    python3Packages.openai
    python3Packages.qdrant-client
  ];
  flakeIgnore = [ "E501" ];
} ''
  import argparse
  import hashlib
  import os
  import pathlib
  import re
  import sys

  from openai import OpenAI
  from qdrant_client import QdrantClient
  from qdrant_client.models import (
      Distance,
      FieldCondition,
      Filter,
      FilterSelector,
      MatchValue,
      PointStruct,
      Range,
      VectorParams,
  )

  API_BASE = os.environ.get("RAG_API_BASE", "http://sgx:4001/v1")
  API_KEY = os.environ.get("RAG_API_KEY", "none")
  QDRANT_URL = os.environ.get("RAG_QDRANT_URL", "http://sgx:6333")
  EMBED_MODEL = os.environ.get("RAG_EMBED_MODEL", "bge-m3")
  CHAT_MODEL = os.environ.get("RAG_CHAT_MODEL", "fast")
  DEFAULT_COLLECTION = os.environ.get("RAG_COLLECTION", "docs")

  # Citation tags in a synthesized answer: single ones like [3] and grouped
  # ones like [1, 3]. The capture group holds the comma-separated numbers.
  CITE_RE = re.compile(r"\[(\d+(?:\s*,\s*\d+)*)\]")

  client = OpenAI(base_url=API_BASE, api_key=API_KEY)
  qdrant = QdrantClient(url=QDRANT_URL)


  def embed(texts, batch_size=64):
      """Return one embedding vector per string in *texts*, in input order.

      Inputs are sent to the gateway in requests of at most *batch_size*
      strings, so ingesting a large file does not become one huge request.

      Raises:
          ValueError: if *batch_size* is smaller than 1.
      """
      if batch_size < 1:
          raise ValueError(f"batch_size must be >= 1, got {batch_size}")
      vectors = []
      for start in range(0, len(texts), batch_size):
          batch = texts[start:start + batch_size]
          # encoding_format is explicit: llama.cpp rejects a null value, and
          # LiteLLM forwards an unset one as JSON null.
          resp = client.embeddings.create(model=EMBED_MODEL, input=batch, encoding_format="float")
          vectors.extend(d.embedding for d in resp.data)
      return vectors


  def chunk_text(text, size=1000, overlap=200):
      """Split *text* into character windows of at most *size* characters.

      Consecutive windows share *overlap* characters. The last window ends
      exactly at the end of *text*; an empty *text* yields an empty list.

      Raises:
          ValueError: unless 0 <= overlap < size; otherwise the window start
              would never advance and the loop would not terminate.
      """
      if not 0 <= overlap < size:
          raise ValueError(f"need 0 <= overlap < size, got size={size}, overlap={overlap}")
      chunks = []
      start = 0
      n = len(text)
      while start < n:
          end = min(start + size, n)
          chunks.append(text[start:end])
          if end == n:
              break
          start = end - overlap
      return chunks


  def iter_files(paths):
      """Yield the files named by *paths*.

      Directories are walked recursively (pathlib rglob, dotfiles included)
      and their files are yielded in sorted order; plain files are yielded
      as-is; paths that are neither (e.g. missing) are skipped silently.
      """
      for p in paths:
          path = pathlib.Path(p)
          if path.is_dir():
              for f in sorted(path.rglob("*")):
                  if f.is_file():
                      yield f
          elif path.is_file():
              yield path


  def read_text(f):
      """Return the UTF-8 contents of *f*, or None if unreadable or not UTF-8.

      Returning None is how binary files get skipped during ingest.
      """
      try:
          return f.read_text(encoding="utf-8")
      except (UnicodeDecodeError, OSError):
          return None


  def ensure_collection(name, dim):
      """Create collection *name* (cosine distance, *dim*-sized vectors) if missing.

      An existing collection is left untouched, even if its vector size
      differs from *dim*.
      """
      if not qdrant.collection_exists(name):
          qdrant.create_collection(
              name,
              vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
          )


  def point_id(source, idx):
      """Return a deterministic integer point id for chunk *idx* of *source*.

      The first 16 hex digits of the SHA-1 digest keep the id within the
      unsigned 64-bit range.
      """
      h = hashlib.sha1(f"{source}:{idx}".encode()).hexdigest()
      return int(h[:16], 16)


  def _delete_stale_chunks(collection, source, keep):
      """Delete the points of *source* whose chunk index is >= *keep*.

      Ids are stable per (source, chunk index), so re-ingesting a file that
      shrank overwrites chunks 0..keep-1 in place, but would otherwise leave
      the old higher-index chunks behind as stale search hits.
      """
      stale = Filter(
          must=[
              FieldCondition(key="source", match=MatchValue(value=source)),
              FieldCondition(key="chunk", range=Range(gte=keep)),
          ],
      )
      qdrant.delete(collection, points_selector=FilterSelector(filter=stale))


  def cmd_ingest(args):
      """Chunk, embed and upsert every readable text file under ``args.paths``.

      Progress goes to stderr. Empty, whitespace-only, unreadable and
      non-UTF-8 files are skipped. Re-ingesting a file replaces its chunks,
      including removing chunks from a longer previous version.
      """
      collection = args.collection
      total = 0
      for f in iter_files(args.paths):
          text = read_text(f)
          if not text or not text.strip():
              continue
          source = str(f)
          chunks = chunk_text(text)
          vectors = embed(chunks)
          # The collection is created lazily so its vector size matches
          # whatever the embedding model actually returns.
          ensure_collection(collection, len(vectors[0]))
          points = [
              PointStruct(
                  id=point_id(source, i),
                  vector=v,
                  payload={"text": c, "source": source, "chunk": i},
              )
              for i, (c, v) in enumerate(zip(chunks, vectors))
          ]
          qdrant.upsert(collection, points)
          # Upsert first, then prune: a failure in between never leaves the
          # source with fewer chunks than before.
          _delete_stale_chunks(collection, source, len(points))
          total += len(points)
          print(f"ingested {len(points):4d} chunks {f}", file=sys.stderr)
      print(f"done: {total} chunks into {collection!r}", file=sys.stderr)


  def _cited_tags(answer, n_hits):
      """Return the sorted distinct [n] tag numbers in *answer* within 1..n_hits.

      Both ``[1][3]`` and grouped ``[1, 3]`` citation styles are recognized;
      numbers outside the range of provided hits are ignored.
      """
      found = {int(n) for group in CITE_RE.findall(answer) for n in group.split(",")}
      return sorted(i for i in found if 1 <= i <= n_hits)


  def cmd_query(args):
      """Search ``args.collection`` for ``args.text`` and print the top ``args.k`` hits.

      Without ``--synthesize`` each hit is printed raw to stdout. With it,
      the hits become numbered context for the chat model, whose answer is
      streamed to stdout, followed on stderr by a legend of the tags it cited.
      """
      collection = args.collection
      qvec = embed([args.text])[0]
      hits = qdrant.query_points(collection, query=qvec, limit=args.k).points
      if not args.synthesize:
          for h in hits:
              src = h.payload.get("source", "?")
              chunk = h.payload.get("chunk", "?")
              txt = h.payload.get("text", "")
              print(f"\n# {h.score:.3f} {src}#{chunk}\n{txt}")
          return
      context = "\n\n".join(
          f"[{i + 1}] {h.payload.get('source')}\n{h.payload.get('text')}"
          for i, h in enumerate(hits)
      )
      messages = [
          {"role": "system",
           "content": "Answer the question using only the context below. Cite sources by their [n] tag. If the answer is not in the context, say so."},
          {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {args.text}"},
      ]
      stream = client.chat.completions.create(model=CHAT_MODEL, messages=messages, stream=True)
      answer = []
      for ev in stream:
          # Some OpenAI-compatible servers emit chunks with an empty choices
          # list; skip them instead of failing with IndexError.
          if not ev.choices:
              continue
          delta = ev.choices[0].delta.content
          if delta:
              answer.append(delta)
              sys.stdout.write(delta)
              sys.stdout.flush()
      sys.stdout.write("\n")
      # Flush so the answer is complete before the stderr legend appears.
      sys.stdout.flush()
      # Legend for only the [n] tags actually cited (stderr keeps stdout pipe-clean).
      cited = _cited_tags("".join(answer), len(hits))
      if cited:
          print("\nQuellen:", file=sys.stderr)
          for i in cited:
              h = hits[i - 1]
              src = h.payload.get("source", "?")
              chunk = h.payload.get("chunk", "?")
              print(f" [{i}] {src}#{chunk} ({h.score:.3f})", file=sys.stderr)


  def main():
      """Parse the command line and dispatch to the ingest or query subcommand."""
      parser = argparse.ArgumentParser(prog="rag", description="CLI RAG over a Qdrant collection using halo embeddings.")
      sub = parser.add_subparsers(dest="cmd", required=True)

      pi = sub.add_parser("ingest", help="embed files/dirs into a collection")
      pi.add_argument("paths", nargs="+")
      pi.add_argument("-c", "--collection", default=DEFAULT_COLLECTION)
      pi.set_defaults(func=cmd_ingest)

      pq = sub.add_parser("query", help="search a collection")
      pq.add_argument("text")
      pq.add_argument("-c", "--collection", default=DEFAULT_COLLECTION)
      pq.add_argument("-k", type=int, default=5)
      pq.add_argument("-s", "--synthesize", action="store_true", help="feed hits to the chat model for a written answer")
      pq.set_defaults(func=cmd_query)

      args = parser.parse_args()
      args.func(args)


  if __name__ == "__main__":
      try:
          main()
      except KeyboardInterrupt:
          sys.exit(130)
      except BrokenPipeError:
          # The reader went away (e.g. `rag query ... | head`). Point stdout at
          # /dev/null so the interpreter's final flush of buffered output does
          # not raise a second BrokenPipeError while exiting.
          devnull = os.open(os.devnull, os.O_WRONLY)
          os.dup2(devnull, sys.stdout.fileno())
          sys.exit(0)
      except Exception as e:
          if os.environ.get("RAG_DEBUG"):
              raise
          print(f"rag: {type(e).__name__}: {e}", file=sys.stderr)
          sys.exit(1)
''