Expose halo's [fast] MoE preset through the LiteLLM gateway and make it the rag CLI's default chat model (overridable via RAG_CHAT_MODEL), so query synthesis is faster than with the larger coder model.
184 lines
6.5 KiB
Nix
184 lines
6.5 KiB
Nix
{ writers, python3Packages }:

# rag: thin command-line RAG over a shared Qdrant collection.
# Embedding and answer synthesis both go through an OpenAI-compatible
# gateway (LiteLLM on sgx, which fronts halo's llama-server).
#
# Environment (default in parentheses):
#   RAG_API_BASE     gateway base URL           (http://sgx:4001/v1)
#   RAG_API_KEY      gateway API key            ("none")
#   RAG_QDRANT_URL   Qdrant URL                 (http://sgx:6333)
#   RAG_EMBED_MODEL  embedding model            (bge-m3)
#   RAG_CHAT_MODEL   model used by `query -s`   (fast)
#   RAG_COLLECTION   default collection name    (docs)
#   RAG_DEBUG        when non-empty, errors re-raise with a full traceback
writers.writePython3Bin "rag" {
  libraries = with python3Packages; [
    openai
    qdrant-client
  ];
  # The prompt strings run past 79 columns; skip flake8's line-length check.
  flakeIgnore = [ "E501" ];
} ''
|
|
import argparse
|
|
import hashlib
|
|
import os
|
|
import pathlib
|
|
import re
|
|
import sys
|
|
|
|
from openai import OpenAI
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.models import Distance, PointStruct, VectorParams
|
|
|
|
def _env(name, default):
    """Return environment variable *name*, or *default* if unset or empty.

    Treating an empty value as unset means a blank ``export RAG_X=`` falls
    back to the built-in default instead of sending an empty URL or model.
    """
    return os.environ.get(name) or default


API_BASE = _env("RAG_API_BASE", "http://sgx:4001/v1")
API_KEY = _env("RAG_API_KEY", "none")
QDRANT_URL = _env("RAG_QDRANT_URL", "http://sgx:6333")
EMBED_MODEL = _env("RAG_EMBED_MODEL", "bge-m3")
CHAT_MODEL = _env("RAG_CHAT_MODEL", "fast")
DEFAULT_COLLECTION = _env("RAG_COLLECTION", "docs")

# Module-level clients, constructed at import and shared by every subcommand.
client = OpenAI(base_url=API_BASE, api_key=API_KEY)
qdrant = QdrantClient(url=QDRANT_URL)
|
|
|
|
|
|
def embed(texts, batch_size=64):
    """Embed *texts* through the gateway and return one vector per input.

    Inputs are sent in requests of at most *batch_size* texts, so a large
    file does not become a single oversized request. Vectors are returned
    in input order. An empty *texts* returns ``[]`` without calling the API.

    Raises ValueError if *batch_size* < 1, and RuntimeError if the server
    returns a different number of vectors than inputs it was given.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    texts = list(texts)
    vectors = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        # encoding_format is explicit: llama.cpp rejects a null value, and
        # LiteLLM forwards an unset one as JSON null.
        resp = client.embeddings.create(model=EMBED_MODEL, input=batch, encoding_format="float")
        if len(resp.data) != len(batch):
            # A short reply would silently misalign texts and vectors in zip().
            raise RuntimeError(f"embedding server returned {len(resp.data)} vectors for {len(batch)} inputs")
        vectors.extend(d.embedding for d in resp.data)
    return vectors
|
|
|
|
|
|
def chunk_text(text, size=1000, overlap=200):
    """Split *text* into character windows of at most *size* characters.

    Consecutive windows share *overlap* characters; the last window ends
    exactly at the end of *text*. Empty text yields ``[]``.

    Raises ValueError unless ``size > 0`` and ``0 <= overlap < size``:
    otherwise the window start would never advance (an infinite loop) or
    characters between windows would be skipped.
    """
    if size <= 0:
        raise ValueError(f"size must be positive, got {size}")
    if not 0 <= overlap < size:
        raise ValueError(f"overlap must be in [0, size), got {overlap} for size {size}")
    step = size - overlap
    chunks = []
    n = len(text)
    start = 0
    while start < n:
        end = min(start + size, n)
        chunks.append(text[start:end])
        if end == n:
            break
        start += step
    return chunks
|
|
|
|
|
|
def iter_files(paths):
    """Yield every regular file named by *paths*.

    Directories are walked recursively in sorted order; plain files are
    yielded as given. A path that is neither is reported on stderr and
    skipped, so a mistyped path no longer fails silently.
    """
    for p in paths:
        path = pathlib.Path(p)
        if path.is_dir():
            yield from (f for f in sorted(path.rglob("*")) if f.is_file())
        elif path.is_file():
            yield path
        else:
            print(f"rag: skipping {p}: not a file or directory", file=sys.stderr)
|
|
|
|
|
|
def read_text(f):
    """Return the UTF-8 contents of path *f*, or None if it cannot be read.

    Both undecodable content and OS-level read errors yield None; the
    ingest loop treats a falsy result as "skip this file".
    """
    try:
        contents = f.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        return None
    return contents
|
|
|
|
|
|
def ensure_collection(name, dim):
    """Create collection *name* for *dim*-sized cosine vectors if it is absent.

    An existing collection is left as-is; its vector size is not compared
    against *dim*.
    """
    if qdrant.collection_exists(name):
        return
    params = VectorParams(size=dim, distance=Distance.COSINE)
    qdrant.create_collection(name, vectors_config=params)
|
|
|
|
|
|
def point_id(source, idx):
    """Map (source, chunk index) to a stable unsigned 64-bit point id.

    The id is the first 8 bytes (big-endian) of SHA-1 over "<source>:<idx>",
    so upserting the same chunk again overwrites its previous point.
    """
    digest = hashlib.sha1(f"{source}:{idx}".encode()).digest()
    return int.from_bytes(digest[:8], "big")
|
|
|
|
|
|
def cmd_ingest(args):
    """Chunk, embed and upsert every readable text file under ``args.paths``.

    Point ids come from (source path, chunk index), so re-ingesting a file
    overwrites its chunks in place. After each upsert, chunks of the same
    source at an index >= the new chunk count (left over from a longer,
    earlier version of the file) are deleted so stale text cannot be
    retrieved. Progress is reported on stderr.
    """
    # Local import keeps the module-level import line free of names that
    # only this command needs.
    from qdrant_client.models import FieldCondition, Filter, FilterSelector, MatchValue, Range

    collection = args.collection
    total = 0
    have_collection = False
    for f in iter_files(args.paths):
        text = read_text(f)
        if not text or not text.strip():
            continue
        source = str(f)
        chunks = chunk_text(text)
        vectors = embed(chunks)
        if not have_collection:
            # The vector size is only known after the first embedding call.
            ensure_collection(collection, len(vectors[0]))
            have_collection = True
        points = [
            PointStruct(
                id=point_id(source, i),
                vector=v,
                payload={"text": c, "source": source, "chunk": i},
            )
            for i, (c, v) in enumerate(zip(chunks, vectors))
        ]
        qdrant.upsert(collection, points)
        # Upsert first, then trim: the file's chunks are never missing.
        stale = Filter(
            must=[
                FieldCondition(key="source", match=MatchValue(value=source)),
                FieldCondition(key="chunk", range=Range(gte=len(points))),
            ],
        )
        qdrant.delete(collection, points_selector=FilterSelector(filter=stale))
        total += len(points)
        print(f"ingested {len(points):4d} chunks {f}", file=sys.stderr)
    print(f"done: {total} chunks into {collection!r}", file=sys.stderr)
|
|
|
|
|
|
# Citation tags in the model's answer: "[2]", or grouped like "[1, 3]" / "[1; 3]".
_CITATION_RE = re.compile(r"\[(\d+(?:\s*[,;]\s*\d+)*)\]")


def _cited_indices(answer, n_hits):
    """Return the sorted, unique citation numbers in *answer* within 1..n_hits."""
    cited = set()
    for group in _CITATION_RE.findall(answer):
        cited.update(int(num) for num in re.split(r"[,;]", group))
    return sorted(i for i in cited if 1 <= i <= n_hits)


def cmd_query(args):
    """Search ``args.collection`` for ``args.text`` and print the top ``args.k`` hits.

    Without ``--synthesize``, each hit is printed with its score, source and
    chunk number. With it, the hits are numbered [1]..[k] as context for
    CHAT_MODEL, and the streamed answer goes to stdout. A legend of the
    sources the answer actually cites goes to stderr, so stdout stays
    pipe-clean.
    """
    collection = args.collection
    qvec = embed([args.text])[0]
    hits = qdrant.query_points(collection, query=qvec, limit=args.k).points
    if not args.synthesize:
        for h in hits:
            src = h.payload.get("source", "?")
            chunk = h.payload.get("chunk", "?")
            txt = h.payload.get("text", "")
            print(f"\n# {h.score:.3f} {src}#{chunk}\n{txt}")
        return
    # Number the hits so the model can cite them; missing payload fields
    # use the same placeholders as the plain listing above.
    blocks = []
    for i, h in enumerate(hits, start=1):
        src = h.payload.get("source", "?")
        txt = h.payload.get("text", "")
        blocks.append(f"[{i}] {src}\n{txt}")
    context = "\n\n".join(blocks)
    messages = [
        {"role": "system", "content": "Answer the question using only the context below. Cite sources by their [n] tag. If the answer is not in the context, say so."},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {args.text}"},
    ]
    stream = client.chat.completions.create(model=CHAT_MODEL, messages=messages, stream=True)
    answer = []
    for ev in stream:
        # OpenAI-compatible servers may send chunks with an empty choices
        # list (e.g. a final usage-only chunk); they carry no text.
        if not ev.choices:
            continue
        delta = ev.choices[0].delta.content
        if delta:
            answer.append(delta)
            sys.stdout.write(delta)
            sys.stdout.flush()
    sys.stdout.write("\n")
    # Legend for only the [n] tags actually cited (stderr keeps stdout pipe-clean).
    cited = _cited_indices("".join(answer), len(hits))
    if cited:
        print("\nQuellen:", file=sys.stderr)
        for i in cited:
            h = hits[i - 1]
            src = h.payload.get("source", "?")
            chunk = h.payload.get("chunk", "?")
            print(f" [{i}] {src}#{chunk} ({h.score:.3f})", file=sys.stderr)
|
|
|
|
|
|
def _positive_int(value):
    """argparse ``type=`` callable: parse *value* as an integer >= 1."""
    try:
        n = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer: {value!r}") from None
    if n < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {n}")
    return n


def main():
    """Parse the command line and dispatch to the chosen subcommand."""
    parser = argparse.ArgumentParser(prog="rag", description="CLI RAG over a Qdrant collection using halo embeddings.")
    sub = parser.add_subparsers(dest="cmd", required=True)

    pi = sub.add_parser("ingest", help="embed files/dirs into a collection")
    pi.add_argument("paths", nargs="+")
    pi.add_argument("-c", "--collection", default=DEFAULT_COLLECTION)
    pi.set_defaults(func=cmd_ingest)

    pq = sub.add_parser("query", help="search a collection")
    pq.add_argument("text")
    pq.add_argument("-c", "--collection", default=DEFAULT_COLLECTION)
    # Validated locally so -k 0 / -k -3 get a clear usage error rather than
    # an opaque failure from the Qdrant server.
    pq.add_argument("-k", type=_positive_int, default=5, help="number of hits to retrieve (default: 5)")
    pq.add_argument("-s", "--synthesize", action="store_true", help="feed hits to the chat model for a written answer")
    pq.set_defaults(func=cmd_query)

    args = parser.parse_args()
    args.func(args)
|
|
|
|
|
|
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        sys.exit(130)
    except BrokenPipeError:
        # The reader went away (e.g. piped into `head`). Point stdout at
        # /dev/null so the interpreter's final flush does not raise a second
        # BrokenPipeError ("Exception ignored ...") on the way out.
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        sys.exit(0)
    except Exception as e:
        # RAG_DEBUG (any non-empty value) shows the full traceback.
        if os.environ.get("RAG_DEBUG"):
            raise
        print(f"rag: {type(e).__name__}: {e}", file=sys.stderr)
        sys.exit(1)
|
|
''
|