#!/usr/bin/env python3

import callgraph

from collections import deque
import sys

## Generic graph helpers

def dict_set_add(d, k, v):
    """Put v into the set that d holds under key k.

    If k is not yet a key of d, a fresh empty set is stored there first.
    """
    bucket = d[k] if k in d else set()
    bucket.add(v)
    d[k] = bucket


def setup(graph):
    """Construct a list of initially marked nodes and a reverse graph."""
    marked = set()
    rgraph = {}
    for (k,(edges,color)) in graph.items():
        if color:
            marked.add(k)
        for e in edges:
            if e in graph:
                dict_set_add(rgraph, e, k)
    return marked, rgraph

def def_pred(x):
    """Default exclude predicate for color(): excludes nothing."""
    return False


def color(marked, graph, exclude=def_pred):
    """Color every node reachable from a marked node by following graph's edges.

    The search is breadth-first (FIFO queue), so the recorded links form
    shortest paths back to the marked set. Nodes for which exclude returns
    True are never colored and are not traversed through; this includes
    marked nodes themselves.

    marked: iterable of initially marked nodes.
    graph: mapping from a node to an iterable of its successors; a node
        with no entry has no successors.
    exclude: predicate on a node; the default excludes nothing.

    Returns a dict mapping each colored node to the node it was reached
    from, with None for the initially marked nodes. When graph is the
    reverse graph built by setup(), a node's value is a colored node that
    it has an edge to in the original graph, so the colored set is exactly
    the nodes with a path to a marked node.
    """
    red = {v: None for v in marked if not exclude(v)}
    queue = deque(red)

    while queue:
        v = queue.popleft()
        for e in graph.get(v, set()):
            if e not in red and not exclude(e):
                red[e] = v
                queue.append(e)

    return red

def trail(red, start):
    """Follow predecessor links from start back to an initially marked node.

    red is a mapping as returned by color(): each node maps to the node it
    was reached from, and initially marked nodes map to None. Returns the
    list of nodes visited, beginning with start and ending with a node
    whose entry is None, or [] if start is not in red.
    """
    if start not in red:
        return []
    path = []
    node = start
    # Compare against None explicitly: a falsy node (e.g. "" or 0) is still
    # a real node and must not end the walk early.
    while node is not None:
        path.append(node)
        node = red[node]
    return path
 
## Helpers specific to the JSOP opcode / JOF_NOPC analysis

def strip_args(s):
    """Drop the parenthesized argument signature from a function name.

    Everything from the first "(" onward is removed; a name without "("
    is returned unchanged.
    """
    name, _sep, _rest = s.partition("(")
    return name

def get_ops(funcs):
    """Extract JSOP names from a collection of function names.

    Each entry may carry an argument signature, which is ignored. An entry
    whose bare name starts with "JSOP_" and ends with "_func" contributes
    that name minus the trailing "_func" (e.g. "JSOP_ADD_func(...)" gives
    "JSOP_ADD"); all other entries are ignored.

    Returns a set of the extracted names.
    """
    suffix = "_func"
    names = (strip_args(s) for s in funcs)
    # Slice the suffix off rather than using str.replace, which would also
    # delete any "_func" occurring earlier in the name.
    return {name[:-len(suffix)] for name in names
            if name.startswith("JSOP_") and name.endswith(suffix)}

def fixup_file(file, ops):
    """Write a jsopcode.tbl-style file to stdout with JOF_NOPC flags corrected.

    Every "|JOF_NOPC" is first removed from lines that start with "OPDEF".
    Then, if the opcode name on such a line (the text between the first "("
    and the following ",") is not in ops, "|JOF_NOPC" is inserted before
    each ")" on the line. All other lines are copied through unchanged.

    file: path of the table file to read.
    ops: collection of opcode names (e.g. "JSOP_ADD") that must NOT carry
        JOF_NOPC.
    """
    # A context manager closes the file deterministically instead of
    # leaking the handle until garbage collection.
    with open(file) as f:
        for line in f:
            if line.startswith("OPDEF"):
                line = line.replace("|JOF_NOPC", "")
                op = line.split("(")[1].split(",")[0]
                if op not in ops:
                    # NOTE(review): replaces every ")" on the line; assumes
                    # an OPDEF line contains only its closing parenthesis.
                    line = line.replace(")", "|JOF_NOPC)")
            sys.stdout.write(line)

def get_file_ops(file):
    """Return the set of opcode names declared by OPDEF lines in file.

    On each line that starts with "OPDEF", the name is the text between the
    first "(" and the next ","; all other lines are ignored.
    """
    ops = set()
    # A context manager closes the file deterministically instead of
    # leaking the handle until garbage collection.
    with open(file) as f:
        for line in f:
            if line.startswith("OPDEF"):
                ops.add(line.split("(")[1].split(",")[0])
    return ops

def print_ok(file, pc_ops):
    """Print the opcodes declared in file that do not touch the pc.

    Each opcode from file's OPDEF lines that is absent from pc_ops is
    printed on its own line, in sorted order.
    """
    declared = get_file_ops(file)
    untouched = declared - pc_ops
    for op in sorted(untouched):
        print(op)

def get_path(op, red):
    """Print the call path from the function for opcode op to a marked function.

    op may be given with or without its "JSOP_" prefix (e.g. "JSOP_ADD" or
    "ADD"), so names as printed by print_ok can be passed straight back in.
    The path is printed as function names without argument signatures,
    joined by " -> ". If the opcode's function is not in red, an empty line
    is printed.

    red: mapping as returned by color().
    """
    prefix = "JSOP_"
    if op.startswith(prefix):
        op = op[len(prefix):]
    func = prefix + op + "_func(js_state_t*, uintN)"
    print(" -> ".join(strip_args(name) for name in trail(red, func)))

# Bare function names (no argument signature) to leave out of the coloring.
# main() passes color() a predicate that tests strip_args(name) in exclude,
# so listed functions are neither colored nor traversed through.
# Uncomment an entry to exclude it.
# NOTE: with every entry commented out, this literal is an empty dict rather
# than a set; the membership test behaves the same either way.
exclude = {
#        "js_FramePCToLineNumber",
#        "js_LeaveTrace",
#        "js_watch_set",
#        "js_InferFlags",
}

def main(args):
    """Command-line entry point; returns a process exit status.

    args is an argv-style list [program, flag, argument]:
      -f FILE  print FILE with JOF_NOPC corrected on its OPDEF lines
      -o FILE  print the opcodes declared in FILE that are not in the
               computed pc-touching set
      -p OP    print the call path from OP's function to a marked function
    Returns 1 (after writing a message to stderr) on a usage error,
    otherwise 0.
    """
    if len(args) != 3:
        sys.stderr.write("usage: %s (-f|-o|-p) (file|file|op)\n" % args[0])
        return 1

    flag, arg = args[1], args[2]
    # Reject a bad flag before running the whole-callgraph analysis, so a
    # typo fails fast instead of after all the work below.
    if flag not in ("-f", "-o", "-p"):
        sys.stderr.write("invalid flag %s\n" % flag)
        return 1

    def p(x):
        # Exclude nodes whose bare function name appears in `exclude`.
        return strip_args(x) in exclude

    # Initial analysis: propagate the marking backwards through the call
    # graph, then collect the opcodes whose functions were colored.
    marked, rgraph = setup(callgraph.callgraph)
    red = color(marked, rgraph, p)
    pc_ops = get_ops(red)

    if flag == "-f":
        fixup_file(arg, pc_ops)
    elif flag == "-o":
        print_ok(arg, pc_ops)
    else:  # "-p"
        get_path(arg, red)

    return 0

if __name__ == '__main__':
    # Use main()'s return value (0 or 1) as the process exit status.
    sys.exit(main(sys.argv))
