Attention-as-graph

An alternative way of calculating self-attention.

Warning

I may or may not work on this further; PRs are welcome though.

Look at main.py; this is a preview:

from __future__ import annotations

import math
from dataclasses import dataclass, field

import torch.nn.functional as F
from torch import Tensor


@dataclass
class Node:
    idx: int
    value: Tensor
    adjacency_list: list[Edge] = field(default_factory=list)

@dataclass
class Edge:
    node: Node
    weight: Tensor



def build_graph(nodes: list[Node], keys: Tensor, queries: Tensor):
    batch, seq_len, d_model = queries.shape
    for idx, curr_node in enumerate(nodes):
        # picking keys 1 to idx+1 (the causal history of the current token)
        keys_history = keys[:, : idx + 1, :]

        # picking the idx-th query
        curr_query = queries[:, idx, :]

        # dot product (similarity score) between the current query and all
        # keys contained in the history of the current node (token)
        similarity_values = curr_query @ keys_history.transpose(-1, -2)

        # if DEBUG: print(f"{keys_history.shape=} {curr_query.shape=} {similarity_values.shape=}")
        similarity_values = similarity_values / math.sqrt(d_model)

        # after softmax you get weights that indicate how much the current
        # node wants to pay attention to each past node
        attn = F.softmax(similarity_values.float(), dim=-1).type_as(keys)

        attn = attn.reshape(-1)  # flatten to keep indexing simple (assumes batch size 1)
        # if DEBUG: print(attn)

        # adding weighted edges to the adjacency list of the current node
        for nidx, node in enumerate(nodes[: idx + 1]):
            edge_weight = attn[nidx]

            # if DEBUG: print(f"{idx} attends to {nidx} with weight {edge_weight:.2f}")
            edge = Edge(node=node, weight=edge_weight)

            # the current node gets a weighted edge to every past node
            curr_node.adjacency_list.append(edge)
    return nodes
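
For context, here is a minimal usage sketch (not from the repo; names like values are assumptions) that builds the graph from random tensors and reads the attention output back off the edges, assuming batch size 1:

import torch

batch, seq_len, d_model = 1, 4, 8
keys = torch.randn(batch, seq_len, d_model)
queries = torch.randn(batch, seq_len, d_model)
values = torch.randn(batch, seq_len, d_model)

# one Node per token, carrying that token's value vector
nodes = [Node(idx=i, value=values[:, i, :]) for i in range(seq_len)]
nodes = build_graph(nodes, keys, queries)

# the usual attention output is the edge-weighted sum of past values
for node in nodes:
    out = sum(edge.weight * edge.node.value for edge in node.adjacency_list)
    print(node.idx, out.shape)  # each output is (1, d_model)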

TODO

  • do inference with a tiny LM as a PoC
  • Add visualization (a minimal sketch follows below)
    • top nodes influencing the current node
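
As a starting point for the visualization item, a hedged sketch (the function name is an assumption, not repo code) that prints the top-k incoming edge weights per node:

def print_top_influences(nodes: list[Node], k: int = 3):
    # for each node, show the k past nodes with the largest edge weights
    for node in nodes:
        top = sorted(node.adjacency_list, key=lambda e: float(e.weight), reverse=True)[:k]
        edges = ", ".join(f"{e.node.idx}:{float(e.weight):.2f}" for e in top)
        print(f"node {node.idx} <- [{edges}]")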

It's for education purposes and has no practical use (unless visualization is added).
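
One educational use it does have: you can check that the graph's edge weights match standard causal attention computed the usual batched way. A sketch, reusing the tensors from the usage example above:

# sanity check (assumed, not in the repo)
scores = queries @ keys.transpose(-1, -2) / math.sqrt(d_model)
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float("-inf"))
attn = torch.softmax(scores, dim=-1)  # (batch, seq_len, seq_len), causal

for node in nodes:
    edge_weights = torch.stack([e.weight for e in node.adjacency_list])
    assert torch.allclose(edge_weights, attn[0, node.idx, : node.idx + 1])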
