diff --git a/packages/rag/default.nix b/packages/rag/default.nix new file mode 100644 index 0000000..242b7dc --- /dev/null +++ b/packages/rag/default.nix @@ -0,0 +1,159 @@ +{ writers, python3Packages }: +# Thin CLI RAG over a shared Qdrant collection. Embeddings + synthesis go +# through the OpenAI-compatible gateway (LiteLLM on sgx -> halo's llama-server). +# Configure via env: RAG_API_BASE, RAG_API_KEY, RAG_QDRANT_URL, RAG_EMBED_MODEL, +# RAG_CHAT_MODEL, RAG_COLLECTION. +writers.writePython3Bin "rag" + { + libraries = [ + python3Packages.openai + python3Packages.qdrant-client + ]; + flakeIgnore = [ "E501" ]; + } + '' + import argparse + import hashlib + import os + import pathlib + import sys + + from openai import OpenAI + from qdrant_client import QdrantClient + from qdrant_client.models import Distance, PointStruct, VectorParams + + API_BASE = os.environ.get("RAG_API_BASE", "http://localhost:4000/v1") + API_KEY = os.environ.get("RAG_API_KEY", "none") + QDRANT_URL = os.environ.get("RAG_QDRANT_URL", "http://localhost:6333") + EMBED_MODEL = os.environ.get("RAG_EMBED_MODEL", "bge-m3") + CHAT_MODEL = os.environ.get("RAG_CHAT_MODEL", "coder") + DEFAULT_COLLECTION = os.environ.get("RAG_COLLECTION", "docs") + + client = OpenAI(base_url=API_BASE, api_key=API_KEY) + qdrant = QdrantClient(url=QDRANT_URL) + + + def embed(texts): + resp = client.embeddings.create(model=EMBED_MODEL, input=texts) + return [d.embedding for d in resp.data] + + + def chunk_text(text, size=1000, overlap=200): + chunks = [] + start = 0 + n = len(text) + while start < n: + end = min(start + size, n) + chunks.append(text[start:end]) + if end == n: + break + start = end - overlap + return chunks + + + def iter_files(paths): + for p in paths: + path = pathlib.Path(p) + if path.is_dir(): + for f in sorted(path.rglob("*")): + if f.is_file(): + yield f + elif path.is_file(): + yield path + + + def read_text(f): + try: + return f.read_text(encoding="utf-8") + except (UnicodeDecodeError, OSError): + return None + + + def ensure_collection(name, dim): + if not qdrant.collection_exists(name): + qdrant.create_collection( + name, + vectors_config=VectorParams(size=dim, distance=Distance.COSINE), + ) + + + def point_id(source, idx): + h = hashlib.sha1(f"{source}:{idx}".encode()).hexdigest() + return int(h[:16], 16) + + + def cmd_ingest(args): + collection = args.collection + total = 0 + for f in iter_files(args.paths): + text = read_text(f) + if not text or not text.strip(): + continue + chunks = chunk_text(text) + vectors = embed(chunks) + ensure_collection(collection, len(vectors[0])) + points = [ + PointStruct( + id=point_id(str(f), i), + vector=v, + payload={"text": c, "source": str(f), "chunk": i}, + ) + for i, (c, v) in enumerate(zip(chunks, vectors)) + ] + qdrant.upsert(collection, points) + total += len(points) + print(f"ingested {len(points):4d} chunks {f}", file=sys.stderr) + print(f"done: {total} chunks into {collection!r}", file=sys.stderr) + + + def cmd_query(args): + collection = args.collection + qvec = embed([args.text])[0] + hits = qdrant.query_points(collection, query=qvec, limit=args.k).points + if not args.synthesize: + for h in hits: + src = h.payload.get("source", "?") + chunk = h.payload.get("chunk", "?") + txt = h.payload.get("text", "") + print(f"\n# {h.score:.3f} {src}#{chunk}\n{txt}") + return + context = "\n\n".join( + f"[{i + 1}] {h.payload.get('source')}\n{h.payload.get('text')}" + for i, h in enumerate(hits) + ) + messages = [ + {"role": "system", "content": "Answer the question using only the context below. Cite sources by their [n] tag. If the answer is not in the context, say so."}, + {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {args.text}"}, + ] + stream = client.chat.completions.create(model=CHAT_MODEL, messages=messages, stream=True) + for ev in stream: + delta = ev.choices[0].delta.content + if delta: + sys.stdout.write(delta) + sys.stdout.flush() + sys.stdout.write("\n") + + + def main(): + parser = argparse.ArgumentParser(prog="rag", description="CLI RAG over a Qdrant collection using halo embeddings.") + sub = parser.add_subparsers(dest="cmd", required=True) + + pi = sub.add_parser("ingest", help="embed files/dirs into a collection") + pi.add_argument("paths", nargs="+") + pi.add_argument("-c", "--collection", default=DEFAULT_COLLECTION) + pi.set_defaults(func=cmd_ingest) + + pq = sub.add_parser("query", help="search a collection") + pq.add_argument("text") + pq.add_argument("-c", "--collection", default=DEFAULT_COLLECTION) + pq.add_argument("-k", type=int, default=5) + pq.add_argument("-s", "--synthesize", action="store_true", help="feed hits to the chat model for a written answer") + pq.set_defaults(func=cmd_query) + + args = parser.parse_args() + args.func(args) + + + if __name__ == "__main__": + main() + '' diff --git a/systems/x86_64-linux/sgx/default.nix b/systems/x86_64-linux/sgx/default.nix index c4290a0..d7074a8 100644 --- a/systems/x86_64-linux/sgx/default.nix +++ b/systems/x86_64-linux/sgx/default.nix @@ -12,6 +12,7 @@ ./wyoming.nix ./searx.nix ./litellm.nix + ./qdrant.nix ./uptime-kuma.nix ./firefly.nix ./opencode.nix @@ -25,6 +26,7 @@ environment.systemPackages = with pkgs; [ claude-code opencode + metacfg.rag ]; services.tailscale.enable = true; diff --git a/systems/x86_64-linux/sgx/litellm.nix b/systems/x86_64-linux/sgx/litellm.nix index 89095fe..e22412e 100644 --- a/systems/x86_64-linux/sgx/litellm.nix +++ b/systems/x86_64-linux/sgx/litellm.nix @@ -22,6 +22,16 @@ api_key = "none"; # llama-server requires no key; value is ignored }; } + { + # Multilingual embeddings, also served by halo's router (the `[bge-m3]` + # preset). Exposes /v1/embeddings on this gateway for the rag CLI. + model_name = "bge-m3"; + litellm_params = { + model = "openai/bge-m3"; + api_base = "http://halo:8000/v1"; + api_key = "none"; + }; + } ]; general_settings = { diff --git a/systems/x86_64-linux/sgx/qdrant.nix b/systems/x86_64-linux/sgx/qdrant.nix new file mode 100644 index 0000000..dcd184b --- /dev/null +++ b/systems/x86_64-linux/sgx/qdrant.nix @@ -0,0 +1,9 @@ +_: { + # Shared vector store for RAG, queried from any LAN machine by the rag CLI. + services.qdrant = { + enable = true; + settings.service.host = "0.0.0.0"; # default 127.0.0.1; LAN-reachable + }; + + networking.firewall.allowedTCPPorts = [ 6333 ]; # HTTP/REST API +}