import torch
import faiss

# HF ESM-2 classes
from transformers import AutoTokenizer, EsmModel

# small CPU-friendly model
MODEL = "facebook/esm2_t6_8M_UR50D"
device = "cpu"
tok = AutoTokenizer.from_pretrained(MODEL)
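
# Optional check: the ESM tokenizer emits one id per residue plus special
# tokens (<cls> at the start, <eos> at the end), so a 3-residue peptide
# should tokenize to 5 ids:
print("ids for 'MKT':", len(tok("MKT")["input_ids"]))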

# Load the bare encoder without the pooling head; we mean-pool ourselves below
mdl = EsmModel.from_pretrained(MODEL, add_pooling_layer=False).to(device)
mdl.eval() # inference mode
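
# Sanity check: the t6_8M checkpoint is the smallest ESM-2 release and
# should report a 320-dimensional hidden state, which is the width of the
# vectors we index below:
print("hidden size:", mdl.config.hidden_size)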

# tiny built-in sequence collection
TOY_DB = [
    # Family A: P-loop-like
    ("A001", "p_loop_A",
     "MNNIRRVLIVGPNGAGKSTLLQAIAANAGADVVVVDSQTPAQLEAALERAGVEVVFINDK"),
    ("A002", "p_loop_A",
     "MSNIRRVLIVGPNGAGKSTLLQAVAANAGADIVVVDSQTPAQLEAALERAGVEVIFINDK"),
    ("A003", "p_loop_A",
     "MNNVRRVLIVGPNGAGKSTLLQAIAANSGADVVIVDSQTPAQLEASLERAGVEVVFVNDK"),

    # Family B: helix-turn-helix-like / basic DNA-binding-like
    ("B001", "hth_B",
     "MARRKQLAERLAALEQQNPDVEALAALEQAGYDVKRRRVEQLSRELNEMGVSAAELAQLGVT"),
    ("B002", "hth_B",
     "MARRKQLAERLAALEQQNPDVEALASLEQAGYDVKRRRVEQLSRELNEMGVSAAEIAQLGVT"),
    ("B003", "hth_B",
     "MARRKQLAERLAALEQQNPEVEALAALEQAGYDVKRRRVEQLSRELSEMGVSAAELAQLGVT"),

    # Family C: acidic enzyme-like
    ("C001", "acidic_C",
     "MNKDVAIHFDLSPEDVKRALEAGADVVVVHDELDTPEDLAAIARAGADVVVTLDPEQGK"),
    ("C002", "acidic_C",
     "MNKDVAIHFDLSPEDVKRALEAGADIVVVHDELDTPEDLAAIARAGADVVVTLDPEQGK"),
    ("C003", "acidic_C",
     "MNKDVAIHFDLSPDDVKRALEAGADVVVVHDELDAPEDLAAIARAGADVVVTLDPEQGK"),

    # Family D: ubiquitin-like compact fold
    ("D001", "ubq_like_D",
     "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQ"),
    ("D002", "ubq_like_D",
     "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLVFAGKQLEDGRTLSDYNIQ"),
    ("D003", "ubq_like_D",
     "MQIFVKTLTGKTITLDVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQ"),
    
    # Hydrophobic / signal-peptide-like / membrane-like
    ("X001", "hydrophobic_like",
     "MKRLLPLAVAVAALLAVSCSAQAAAAPAAEAEAAAG"),
    ("X002", "hydrophobic_like",
     "MALLWLLLAVALVCGAQAAAPAPAPVVAEAAAAGGG"),
    ("X003", "hydrophobic_like",
     "MNRRLVVVVVALLLLTGCGCSAAAAPAAPVVAAAAG"),
    ("X004", "hydrophobic_like",
     "MFFSRRLLLLAAGVALAGCGVQAAPAAPAAAGGGSS"),
    ("X005", "tm_like",
     "MSNNRILAVVVIGTAVVAGLIAGWFFGQKKKDEEAA"),
    ("X006", "tm_like",
     "MKTLLAIVLAFVSVGLVLGAYYFKRKQAEGDDDAAA"),
    
    # Low-complexity / compositionally biased
    ("X007", "low_complexity",
     "MSDSEEKKTKKTKKTKKTKKTKKTKKTKKTKKTKK"),
    ("X008", "low_complexity",
     "MGGGGGGGGSSGGGGGSGSGGGGSSGGGGGSGGGGS"),
    ("X009", "low_complexity",
     "MPEPEPEPEPEPEPEPEPEPEPEPEPEPEPEPEP"),
    ("X010", "low_complexity",
     "MQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQQ"),
    ("X011", "low_complexity",
     "MSSSSSSSSSDDDDDDDEEEEEKKKKKKSSSSSSS"),
    ("X012", "low_complexity",
     "MNRNRNRNRNRNRNGSGSGSGSGSGSNRNRNRNRN"),

]
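
# To run the same demo on your own data, a minimal FASTA reader is enough.
# A sketch (the helper name read_fasta is our own; it is not used below):
def read_fasta(path):
    """Return (record_id, sequence) pairs from a FASTA file."""
    records, header, chunks = [], None, []
    with open(path) as fh:
        for line in fh:
            line = line.strip()
            if line.startswith(">"):
                if header is not None:
                    records.append((header, "".join(chunks)))
                header, chunks = line[1:].split()[0], []
            elif line:
                chunks.append(line)
    if header is not None:
        records.append((header, "".join(chunks)))
    return records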

@torch.no_grad()
def embed_batch(seqs):
    # tokenize the batch with padding and truncation
    inp = tok(
        seqs,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    
    # move tensors to the model's device
    inp = {k: v.to(device) for k, v in inp.items()}
    out = mdl(**inp)
    h = out.last_hidden_state
    
    # expand attention mask to (batch, seq, 1) so padded positions are zeroed
    # (note: the <cls>/<eos> special tokens are still included in the mean)
    attn = inp["attention_mask"].unsqueeze(-1).float()
    
    # masked mean pool; clamp guards against division by zero
    v = (h * attn).sum(dim=1) / attn.sum(dim=1).clamp(min=1.0)
    
    # L2-normalize so inner products become cosine similarities
    v = torch.nn.functional.normalize(v, dim=1)
    
    return v.cpu().numpy().astype("float32")
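
# For anything bigger than this toy set, embed in fixed-size chunks so
# peak memory stays bounded. A sketch (batch_size=16 is an arbitrary
# CPU-friendly choice, not a tuned value):
import numpy as np

def embed_all(all_seqs, batch_size=16):
    parts = [embed_batch(all_seqs[i:i + batch_size])
             for i in range(0, len(all_seqs), batch_size)]
    return np.concatenate(parts, axis=0)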

# Unpack the database
ids = [sid for (sid, lbl, seq) in TOY_DB]
labels = [lbl for (sid, lbl, seq) in TOY_DB]
seqs = [seq for (sid, lbl, seq) in TOY_DB]

# Compute database embeddings
X = embed_batch(seqs)
d = X.shape[1] # embedding dimension
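
# Sanity check: embed_batch L2-normalizes, so every row of X should have
# (approximately) unit norm, making the inner products below cosine
# similarities:
print("max |row norm - 1|:", float(abs((X * X).sum(axis=1) ** 0.5 - 1.0).max()))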


# exact inner-product index; on unit vectors this is cosine similarity
index = faiss.IndexFlatIP(d)
index.add(X)
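
# IndexFlatIP is exact brute-force search, which is fine at this scale.
# For large collections an approximate index is the usual swap-in. A
# sketch using FAISS's HNSW index (M=32 and efSearch=64 are illustrative
# defaults, not tuned values; this index is not used below):
ann = faiss.IndexHNSWFlat(d, 32, faiss.METRIC_INNER_PRODUCT)
ann.hnsw.efSearch = 64
ann.add(X)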

def search(query_seq, k=5):
    # embed query
    q = embed_batch([query_seq])
    # search top-k
    scores, nn = index.search(q, k)
    out = []
    # pair ids and scores
    for j, s in zip(nn[0], scores[0]):
        out.append((ids[j], labels[j], float(s), len(seqs[j])))
    return out
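
# Note: when the query is taken from the database itself (as below), the
# top hit is the query's own entry with cosine ~= 1.0. A sketch of a
# variant that drops the self-match (the query_id parameter is our own
# addition, not part of the demo above):
def search_excluding(query_seq, query_id, k=5):
    hits = search(query_seq, k + 1)  # over-fetch by one
    return [h for h in hits if h[0] != query_id][:k]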

query = seqs[10] # example query: D002, a ubiquitin-like sequence from the database
hits = search(query, k=5)

print("Query length:", len(query))

# pretty print results
for r, (sid, lbl, score, L) in enumerate(hits, 1):
    print(f"{r:>2}. {sid} "
          f"label={lbl} "
          f"cosine={score:.3f} "
          f"len={L}")

