nixcfg/packages/rag/default.nix
Harald Hoyer 116d04665d fix(sgx): move LiteLLM off port 4000 to avoid uptime-kuma clash
Uptime Kuma already binds 4000, so the gateway never got the port and
requests hit the wrong service. Move LiteLLM to 4001 and update the rag
CLI default endpoint to match.
2026-05-22 07:08:26 +02:00

159 lines
5.4 KiB
Nix

{ writers, python3Packages }:
# Thin CLI RAG over a shared Qdrant collection. Embeddings + synthesis go
# through the OpenAI-compatible gateway (LiteLLM on sgx -> halo's llama-server).
# Configure via env: RAG_API_BASE, RAG_API_KEY, RAG_QDRANT_URL, RAG_EMBED_MODEL,
# RAG_CHAT_MODEL, RAG_COLLECTION.
writers.writePython3Bin "rag"
{
# Python packages the embedded script imports (openai, qdrant_client).
libraries = [
python3Packages.openai
python3Packages.qdrant-client
];
# The script contains long single-line string literals, so the flake8
# line-length check (E501) is disabled for it.
flakeIgnore = [ "E501" ];
}
''
import argparse
import hashlib
import os
import pathlib
import sys
from openai import OpenAI
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
# Runtime configuration. Every setting is read from the environment once at
# import time; the second argument is the default used when it is unset.
# Base URL of the OpenAI-compatible gateway used for embeddings and chat.
API_BASE = os.environ.get("RAG_API_BASE", "http://localhost:4001/v1")
# API key sent to the gateway; defaults to the placeholder string "none".
API_KEY = os.environ.get("RAG_API_KEY", "none")
# URL of the Qdrant instance that stores the vectors.
QDRANT_URL = os.environ.get("RAG_QDRANT_URL", "http://localhost:6333")
# Model name passed to the embeddings endpoint.
EMBED_MODEL = os.environ.get("RAG_EMBED_MODEL", "bge-m3")
# Model name passed to the chat-completions endpoint for synthesis.
CHAT_MODEL = os.environ.get("RAG_CHAT_MODEL", "coder")
# Collection used when -c/--collection is not given on the command line.
DEFAULT_COLLECTION = os.environ.get("RAG_COLLECTION", "docs")
# Shared clients, constructed at import time and used by all subcommands.
client = OpenAI(base_url=API_BASE, api_key=API_KEY)
qdrant = QdrantClient(url=QDRANT_URL)
def embed(texts):
    """Return one embedding vector per entry of ``texts``.

    All texts are sent to the gateway's embeddings endpoint in a single
    request using ``EMBED_MODEL``; vectors are returned in the order the
    response lists them.
    """
    response = client.embeddings.create(input=texts, model=EMBED_MODEL)
    return [item.embedding for item in response.data]
def chunk_text(text, size=1000, overlap=200):
    """Split ``text`` into overlapping fixed-size character windows.

    Args:
        text: The string to split.
        size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by consecutive chunks; must
            satisfy ``0 <= overlap < size``.

    Returns:
        A list of chunks covering ``text`` in order. Empty input yields an
        empty list; the last chunk may be shorter than ``size``.

    Raises:
        ValueError: If ``size`` or ``overlap`` is out of range. Without this
            check, ``overlap >= size`` would keep the window from ever
            advancing and the loop would never terminate.
    """
    if size <= 0:
        raise ValueError(f"size must be positive, got {size}")
    if not 0 <= overlap < size:
        raise ValueError(f"overlap must be in [0, size), got {overlap}")

    # Each window starts ``size - overlap`` characters after the previous one.
    stride = size - overlap
    length = len(text)
    chunks = []
    for start in range(0, length, stride):
        chunks.append(text[start:start + size])
        # Stop as soon as a window reaches the end, so no trailing chunk that
        # is fully contained in the previous one gets emitted.
        if start + size >= length:
            break
    return chunks
def iter_files(paths):
    """Yield every regular file named by ``paths``.

    A directory is searched recursively and its files are yielded in sorted
    path order; a plain file is yielded as-is. Entries that are neither
    (for example, paths that do not exist) are skipped silently.
    """
    for raw in paths:
        root = pathlib.Path(raw)
        if root.is_dir():
            yield from (child for child in sorted(root.rglob("*")) if child.is_file())
            continue
        if root.is_file():
            yield root
def read_text(f):
    """Return the UTF-8 decoded contents of path ``f``, or None if unreadable.

    Files that are not valid UTF-8 and files that cannot be opened or read
    both produce None, so callers can skip them.
    """
    try:
        content = f.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        content = None
    return content
def ensure_collection(name, dim):
    """Create collection ``name`` with cosine-distance vectors of size ``dim`` if absent.

    An existing collection is left untouched; its vector configuration is
    not compared against ``dim``.
    """
    if qdrant.collection_exists(name):
        return
    vector_settings = VectorParams(size=dim, distance=Distance.COSINE)
    qdrant.create_collection(name, vectors_config=vector_settings)
def point_id(source, idx):
    """Derive a deterministic integer point id from a source and chunk index.

    The id is the first 8 bytes (64 bits) of the SHA-1 digest of
    ``"<source>:<idx>"`` read as a big-endian unsigned integer, so the same
    (source, idx) pair always maps to the same id.
    """
    key = f"{source}:{idx}".encode()
    digest = hashlib.sha1(key).digest()
    return int.from_bytes(digest[:8], "big")
def cmd_ingest(args):
    """Embed the files under ``args.paths`` and upsert them into ``args.collection``.

    Each readable, non-blank UTF-8 file is split into chunks, embedded, and
    stored as one point per chunk with payload keys ``text``, ``source`` and
    ``chunk``. Point ids are derived from the file path and chunk index, so
    ingesting the same path again reuses the same ids. Per-file progress and
    the final total are written to stderr.

    Raises:
        RuntimeError: If the gateway returns a different number of vectors
            than the number of chunks sent for a file.
    """
    collection = args.collection
    # The vector size is only known after the first embedding call, so the
    # collection is created lazily -- but only checked once, not per file.
    collection_ready = False
    total = 0
    for f in iter_files(args.paths):
        text = read_text(f)
        if not text or not text.strip():
            continue
        source = str(f)
        chunks = chunk_text(text)
        vectors = embed(chunks)
        # zip() below would silently drop chunks on a short response.
        if len(vectors) != len(chunks):
            raise RuntimeError(
                f"embedding count mismatch for {source}: "
                f"sent {len(chunks)} chunks, got {len(vectors)} vectors"
            )
        if not collection_ready:
            ensure_collection(collection, len(vectors[0]))
            collection_ready = True
        points = [
            PointStruct(
                id=point_id(source, i),
                vector=v,
                payload={"text": c, "source": source, "chunk": i},
            )
            for i, (c, v) in enumerate(zip(chunks, vectors))
        ]
        qdrant.upsert(collection, points)
        total += len(points)
        print(f"ingested {len(points):4d} chunks {f}", file=sys.stderr)
    print(f"done: {total} chunks into {collection!r}", file=sys.stderr)
def cmd_query(args):
    """Search ``args.collection`` for ``args.text`` and print the results.

    Without ``args.synthesize`` the top ``args.k`` hits are printed to stdout
    with their score, source and chunk index. With it, the hits are numbered
    and passed as context to ``CHAT_MODEL``, whose streamed answer is written
    to stdout as it arrives, followed by a newline.
    """
    collection = args.collection
    qvec = embed([args.text])[0]
    hits = qdrant.query_points(collection, query=qvec, limit=args.k).points
    # A point may carry no payload at all (payload is None); treat that as
    # an empty payload instead of failing on ``.get``.
    payloads = [h.payload or {} for h in hits]

    if not args.synthesize:
        for h, payload in zip(hits, payloads):
            src = payload.get("source", "?")
            chunk = payload.get("chunk", "?")
            txt = payload.get("text", "")
            print(f"\n# {h.score:.3f} {src}#{chunk}\n{txt}")
        return

    # Use the same fallbacks as the plain listing so a missing key never
    # shows up in the prompt as the literal string "None".
    context_parts = []
    for n, payload in enumerate(payloads, start=1):
        src = payload.get("source", "?")
        txt = payload.get("text", "")
        context_parts.append(f"[{n}] {src}\n{txt}")
    context = "\n\n".join(context_parts)

    messages = [
        {"role": "system", "content": "Answer the question using only the context below. Cite sources by their [n] tag. If the answer is not in the context, say so."},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {args.text}"},
    ]
    stream = client.chat.completions.create(model=CHAT_MODEL, messages=messages, stream=True)
    for ev in stream:
        # OpenAI-compatible backends may emit chunks whose ``choices`` list
        # is empty (for example a usage-only chunk); there is nothing to
        # print for those.
        if not ev.choices:
            continue
        delta = ev.choices[0].delta.content
        if delta:
            sys.stdout.write(delta)
            sys.stdout.flush()
    sys.stdout.write("\n")
def main():
    """Parse the command line and dispatch to the chosen subcommand.

    Two subcommands are registered, ``ingest`` and ``query``. Each stores its
    handler under ``func``, which is then called with the parsed namespace.
    """
    parser = argparse.ArgumentParser(
        prog="rag",
        description="CLI RAG over a Qdrant collection using halo embeddings.",
    )
    subcommands = parser.add_subparsers(dest="cmd", required=True)

    # rag ingest PATH [PATH ...] [-c COLLECTION]
    ingest = subcommands.add_parser("ingest", help="embed files/dirs into a collection")
    ingest.set_defaults(func=cmd_ingest)
    ingest.add_argument("paths", nargs="+")
    ingest.add_argument("-c", "--collection", default=DEFAULT_COLLECTION)

    # rag query TEXT [-c COLLECTION] [-k K] [-s]
    query = subcommands.add_parser("query", help="search a collection")
    query.set_defaults(func=cmd_query)
    query.add_argument("text")
    query.add_argument("-c", "--collection", default=DEFAULT_COLLECTION)
    query.add_argument("-k", type=int, default=5)
    query.add_argument(
        "-s",
        "--synthesize",
        action="store_true",
        help="feed hits to the chat model for a written answer",
    )

    namespace = parser.parse_args()
    namespace.func(namespace)


# Run the CLI only when executed as a script.
if __name__ == "__main__":
    main()
''