Uptime Kuma already binds port 4000, so the LiteLLM gateway could not get that port and requests meant for it reached Uptime Kuma instead. Move LiteLLM to port 4001 and update the rag CLI's default endpoint to match.
159 lines
5.4 KiB
Nix
159 lines
5.4 KiB
Nix
{ writers, python3Packages }:
|
|
# Thin CLI RAG over a shared Qdrant collection. Embeddings + synthesis go
|
|
# through the OpenAI-compatible gateway (LiteLLM on sgx -> halo's llama-server).
|
|
# Configure via env: RAG_API_BASE, RAG_API_KEY, RAG_QDRANT_URL, RAG_EMBED_MODEL,
|
|
# RAG_CHAT_MODEL, RAG_COLLECTION.
|
|
writers.writePython3Bin "rag" {
  # Python libraries made available to the embedded script below.
  libraries = with python3Packages; [ openai qdrant-client ];
  # flakeIgnore lists flake8 codes to skip; E501 (line too long) is ignored
  # so long lines such as the chat system prompt are accepted.
  flakeIgnore = [ "E501" ];
} ''
|
|
import argparse
|
|
import hashlib
|
|
import os
|
|
import pathlib
|
|
import sys
|
|
|
|
from openai import OpenAI
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.models import Distance, PointStruct, VectorParams
|
|
|
|
# Settings come from the environment. An empty value (e.g. `export
# RAG_COLLECTION=`) is treated like an unset one, so it falls back to the
# default instead of handing an empty URL, model or collection name on.
API_BASE = os.environ.get("RAG_API_BASE") or "http://localhost:4001/v1"
API_KEY = os.environ.get("RAG_API_KEY") or "none"
QDRANT_URL = os.environ.get("RAG_QDRANT_URL") or "http://localhost:6333"
EMBED_MODEL = os.environ.get("RAG_EMBED_MODEL") or "bge-m3"
CHAT_MODEL = os.environ.get("RAG_CHAT_MODEL") or "coder"
DEFAULT_COLLECTION = os.environ.get("RAG_COLLECTION") or "docs"
|
|
|
|
# Shared API clients used by the functions below: Qdrant for vector
# storage/search, the OpenAI-compatible gateway for embeddings and chat.
qdrant = QdrantClient(url=QDRANT_URL)
client = OpenAI(api_key=API_KEY, base_url=API_BASE)
|
|
|
|
|
|
def embed(texts, batch_size=64):
    """Embed a list of strings with the gateway's ``EMBED_MODEL``.

    Inputs are sent in batches of at most ``batch_size`` items, so a large
    file does not turn into one oversized request (the OpenAI embeddings
    API caps the input array length). Each batch's results are ordered by
    their ``index`` field so the output lines up with ``texts``.

    Returns a list of embedding vectors, one per input, in input order;
    an empty ``texts`` returns ``[]`` without contacting the server.
    Raises ValueError if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    vectors = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        resp = client.embeddings.create(model=EMBED_MODEL, input=batch)
        ordered = sorted(resp.data, key=lambda item: item.index)
        vectors.extend(item.embedding for item in ordered)
    return vectors
|
|
|
|
|
|
def chunk_text(text, size=1000, overlap=200):
    """Split ``text`` into overlapping fixed-size character windows.

    Each chunk is at most ``size`` characters, and consecutive chunks share
    ``overlap`` characters so text near a boundary appears in both
    neighbours. The last chunk ends exactly at the end of ``text``; an
    empty ``text`` yields ``[]``.

    Raises ValueError unless ``size > 0`` and ``0 <= overlap < size``.
    Other values would either never advance the window (an infinite loop)
    or skip characters between chunks.
    """
    if size <= 0:
        raise ValueError(f"size must be positive, got {size}")
    if not 0 <= overlap < size:
        raise ValueError(f"overlap must be in [0, size), got {overlap} for size {size}")
    chunks = []
    start = 0
    n = len(text)
    while start < n:
        end = min(start + size, n)
        chunks.append(text[start:end])
        if end == n:
            break
        # Step back by `overlap` so the next window re-covers the tail.
        start = end - overlap
    return chunks
|
|
|
|
|
|
def iter_files(paths):
    """Yield every regular file named by ``paths``.

    Directory arguments are walked recursively, and their files are yielded
    in sorted path order. File arguments are yielded as given. Anything else
    (e.g. a mistyped path) is reported on stderr and skipped. The warning
    makes the mistake visible instead of the run silently ingesting nothing.
    """
    for p in paths:
        path = pathlib.Path(p)
        if path.is_dir():
            yield from (f for f in sorted(path.rglob("*")) if f.is_file())
        elif path.is_file():
            yield path
        else:
            print(f"warning: skipping {p}: not a file or directory", file=sys.stderr)
|
|
|
|
|
|
def read_text(f):
    """Return the UTF-8 text of path ``f``, or None when it cannot be read.

    Files that fail to decode as UTF-8 and files that raise an OS error
    (missing, unreadable, a directory) both yield None, so callers can
    simply skip them.
    """
    try:
        contents = f.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError):
        contents = None
    return contents
|
|
|
|
|
|
def ensure_collection(name, dim):
    """Create Qdrant collection ``name`` if it does not exist yet.

    A new collection gets one unnamed vector of size ``dim`` with cosine
    distance. An existing collection is left untouched; its vector size is
    not compared against ``dim``.
    """
    if qdrant.collection_exists(name):
        return
    config = VectorParams(size=dim, distance=Distance.COSINE)
    qdrant.create_collection(name, vectors_config=config)
|
|
|
|
|
|
def point_id(source, idx):
    """Return a deterministic unsigned 64-bit point id for chunk ``idx`` of ``source``.

    The id is the leading 8 bytes (big-endian) of SHA-1 over
    ``"<source>:<idx>"``. The same file position therefore always maps to
    the same id, and an upsert replaces the earlier point instead of
    duplicating it.
    """
    key = f"{source}:{idx}".encode()
    digest = hashlib.sha1(key).digest()
    return int.from_bytes(digest[:8], "big")
|
|
|
|
|
|
def cmd_ingest(args):
    """Chunk, embed and upsert every readable text file under ``args.paths``.

    Each file is split with ``chunk_text`` and embedded. Its chunks are
    upserted into ``args.collection`` with payload keys ``text``,
    ``source`` and ``chunk``. The collection is created on first use with
    the embedding dimension. Chunks left over from an earlier, longer
    version of the same file are deleted so stale text is not retrieved.
    Progress is reported on stderr.
    """
    from qdrant_client.models import FieldCondition, Filter, FilterSelector, MatchValue, Range

    collection = args.collection
    total = 0
    collection_ready = False
    for f in iter_files(args.paths):
        text = read_text(f)
        # Unreadable, non-UTF-8 and whitespace-only files are skipped.
        if not text or not text.strip():
            continue
        source = str(f)
        chunks = chunk_text(text)
        vectors = embed(chunks)
        if not collection_ready:
            # The first embedding fixes the vector size of a new collection;
            # checking once avoids an extra round trip per file.
            ensure_collection(collection, len(vectors[0]))
            collection_ready = True
        points = [
            PointStruct(
                id=point_id(source, i),
                vector=v,
                payload={"text": c, "source": source, "chunk": i},
            )
            for i, (c, v) in enumerate(zip(chunks, vectors))
        ]
        qdrant.upsert(collection, points)
        # Ids derive from (source, chunk index), so the upsert overwrote
        # chunks 0..len(points)-1. Drop higher-numbered chunks that a
        # previous, longer version of this file left behind.
        stale = Filter(
            must=[
                FieldCondition(key="source", match=MatchValue(value=source)),
                FieldCondition(key="chunk", range=Range(gte=len(points))),
            ],
        )
        qdrant.delete(collection, points_selector=FilterSelector(filter=stale))
        total += len(points)
        print(f"ingested {len(points):4d} chunks {f}", file=sys.stderr)
    print(f"done: {total} chunks into {collection!r}", file=sys.stderr)
|
|
|
|
|
|
def cmd_query(args):
    """Retrieve the chunks nearest to ``args.text`` and print or synthesize.

    The query text is embedded, and the ``args.k`` closest points in
    ``args.collection`` are fetched. Without ``--synthesize``, each hit is
    printed as a ``# <score> <source>#<chunk>`` header followed by its text.
    With it, the hits are numbered [1]..[n] and sent as context to
    ``CHAT_MODEL``, and the streamed answer is written to stdout.
    """
    collection = args.collection
    qvec = embed([args.text])[0]
    hits = qdrant.query_points(collection, query=qvec, limit=args.k).points
    if not args.synthesize:
        for h in hits:
            src = h.payload.get("source", "?")
            chunk = h.payload.get("chunk", "?")
            txt = h.payload.get("text", "")
            print(f"\n# {h.score:.3f} {src}#{chunk}\n{txt}")
        return
    context = "\n\n".join(
        f"[{i + 1}] {h.payload.get('source')}\n{h.payload.get('text')}"
        for i, h in enumerate(hits)
    )
    messages = [
        {"role": "system", "content": "Answer the question using only the context below. Cite sources by their [n] tag. If the answer is not in the context, say so."},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {args.text}"},
    ]
    stream = client.chat.completions.create(model=CHAT_MODEL, messages=messages, stream=True)
    for ev in stream:
        # OpenAI-compatible servers may emit chunks whose choices list is
        # empty (e.g. a trailing usage-only chunk). Skip them rather than
        # raising IndexError mid-answer.
        if not ev.choices:
            continue
        delta = ev.choices[0].delta.content
        if delta:
            sys.stdout.write(delta)
            sys.stdout.flush()
    sys.stdout.write("\n")
|
|
|
|
|
|
def main():
    """Parse the command line and run the selected subcommand.

    Subcommands:
      ingest PATH... [-c COLLECTION]          embed files/dirs into a collection
      query TEXT [-c COLLECTION] [-k K] [-s]  search a collection, optionally
                                              synthesizing a written answer
    """
    parser = argparse.ArgumentParser(prog="rag", description="CLI RAG over a Qdrant collection using halo embeddings.")
    sub = parser.add_subparsers(dest="cmd", required=True)

    pi = sub.add_parser("ingest", help="embed files/dirs into a collection")
    pi.add_argument("paths", nargs="+")
    pi.add_argument("-c", "--collection", default=DEFAULT_COLLECTION)
    pi.set_defaults(func=cmd_ingest)

    pq = sub.add_parser("query", help="search a collection")
    pq.add_argument("text")
    pq.add_argument("-c", "--collection", default=DEFAULT_COLLECTION)
    pq.add_argument("-k", type=int, default=5, help="number of hits to retrieve (default: 5)")
    pq.add_argument("-s", "--synthesize", action="store_true", help="feed hits to the chat model for a written answer")
    pq.set_defaults(func=cmd_query)

    args = parser.parse_args()
    # type=int alone accepts 0 and negatives. Reject them here as a usage
    # error instead of passing them on as the Qdrant query limit.
    if args.cmd == "query" and args.k <= 0:
        parser.error("-k must be a positive integer")
    args.func(args)
|
|
|
|
|
|
# Run the CLI when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|
''
|