#!/usr/bin/python2.3

from optparse import OptionParser
import socket
import re
import struct

"""
1. What happens if a client matches more than one criterion?
   e.g. jupiter against *() *.umich.edu() *.citi.umich.edu()

   - mountd takes first one that matches, ignoring others
   
2. What happens if the same path appears twice?

   - options are taken in the order they appear
"""

###########################################################

class Entry(object):
    """Holds a client(options) entry from the exports file

    Attributes set by the constructor:
      client  -- client specification ("*" when the entry names no client)
      options -- raw text found between the parentheses
      type    -- "wild", "gss", "netgroup", "netmask" or "name"
      flavors -- list of security flavors the entry accepts
    Depending on type, one of the following is also set:
      regex (wild), addr and mask (netmask), fqdn (name).
    """
    def __init__(self, str):
        """str is of form <client>(<options>); (<options>) may be absent"""
        i = str.find('(')
        if i < 0:
            # No option list at all: the whole token is the client.  Slicing
            # with find()'s -1 would silently drop its last character.
            self.client = str
            self.options = ""
        else:
            self.client = str[:i]
            opts = str[i+1:]
            # Only strip the closing parenthesis when it is really there.
            if opts.endswith(')'):
                opts = opts[:-1]
            self.options = opts
        if self.client == "":
            self.client = "*"
        self.type = self._settype()
        self._setflavors()

    def _settype(self):
        """Classify self.client and set the attributes matching needs.

        Returns the type name.  As a side effect sets self.regex for
        "wild", self.addr/self.mask for "netmask" and self.fqdn for "name".
        """
        c = self.client
        if c == "" or c == "*":
            self.regex = re.compile(r".*")
            return "wild"
        elif c.startswith("gss/"):
            return "gss"
        elif c[0] == "@":
            return "netgroup"
        elif ismask(c):
            fields = c.split('/')
            self.addr = fields[0]
            if len(fields) == 2:
                if isquad(fields[1]):
                    # Mask given in dot-quad form
                    self.mask = socket.inet_aton(fields[1])
                else:
                    # Mask given as a prefix length: set the top 'bits' bits
                    bits = int(fields[1])
                    mask = (2**32 - 1) & ~(2**(32-bits) - 1)
                    # Pack as 64 bit and keep the low 4 bytes
                    self.mask = struct.pack("!Q", mask)[-4:]
            else:
                # Bare address: every bit must match
                self.mask = '\xff\xff\xff\xff'
            return "netmask"
        # Look for a wildcard character that is not escaped by a backslash
        skip = False
        for x in c:
            if skip:
                skip = False
                continue
            if x in "*?[":
                self.regex = self._get_wild_pattern()
                return "wild"
            if x == '\\':
                skip = True
        self.fqdn = socket.getfqdn(self.client)
        return "name"

    def _setflavors(self):
        """Build self.flavors from a gss/ client and any sec= options.

        Falls back to ["sys"] when nothing names a flavor.
        """
        self.flavors = []
        if self.type == "gss":
            self.flavors.append(self.client[4:])
        for o in self.options.split(','):
            if o.startswith("sec="):
                for f in o[4:].split(':'):
                    # append, not extend: extend() would add the flavor
                    # name one character at a time.  Empty names (from
                    # "sec=" or "a::b") are ignored.
                    if f and f not in self.flavors:
                        self.flavors.append(f)
        if not self.flavors:
            self.flavors = ["sys"]

    def __repr__(self):
        return "%s(%s)[%s]" % (self.client, self.options, self.type)

    def matches(self, address, sec):
        """Returns True if self.client matches address using flavor sec

        address must provide 'names' (list of host names) and 'quads'
        (list of dot-quad strings).  Raises TypeError for entry types that
        have no matching rule here (currently "netgroup").
        """
        # XXX Really have to carefully match mountd code
        if sec not in self.flavors:
            return False
        if self.type == "name":
            return self.fqdn in address.names
        elif self.type == "netmask":
            for q in address.quads:
                # address.quads may hold strings that are not dot-quads
                # (e.g. an unresolved host name); skip them rather than
                # let inet_aton raise inside match_mask.
                if isquad(q) and match_mask(self.addr, q, self.mask):
                    return True
            return False
        elif self.type == "wild":
            for name in address.names:
                if self.regex.match(name):
                    return True
            return False
        elif self.type == "gss":
            return self.client[4:] == sec
        else:
            raise TypeError("Unknown type %s" % self.type)

    def wildmatch(self, text, pat):
        """Hand-written shell-style match of text against pat.

        Returns True when text matches, False otherwise.  Handles '*',
        '?', backslash escapes and [...] sets, comparing characters
        case-insensitively.  Nothing else in this class calls it; wild
        entries are matched through the regex built by _get_wild_pattern.
        NOTE(review): logic left untouched; a pattern ending in '[' or a
        backslash looks like it would raise IndexError - confirm before use.
        """
        def domatch(t, p):
            # Returns True/False for match/mismatch.  None is returned when
            # the text runs out first or no tail matches after a '*';
            # presumably an "abort" marker - the caller treats it as no match.
            ti = pi = 0
            while p[pi:]:
                if t[ti:] == '' and p[pi] != '*':
                    return None
                if p[pi] == '*':
                    try:
                        # Collapse a run of '*'
                        while p[pi] == '*':
                            pi += 1
                    except IndexError:
                        # Pattern ended in '*': matches any remainder
                        return True
                    while t[ti:]:
                        matched = domatch(t[ti:], p[pi:])
                        ti += 1
                        if matched != False:
                            return matched
                    return None
                while (True): # Allows 'break' to jump to correct point
                    if p[pi] == '?':
                        break;
                    if p[pi] == '\\':
                        pi += 1
                    if p[pi] == '[':
                        reverse = p[pi+1] == '^'
                        if reverse:
                            pi += 1
                        matched = False
                        if p[pi+1] == ']' or p[pi+1] == '-':
                            pi += 1
                            if p[pi].upper() == t[ti].upper():
                                matched = True
                        last = p[pi]
                        pi += 1
                        while (p[pi:] and p[pi] != ']'):
                            if p[pi] == '-' and p[pi] != ']':
                                pi += 1
                                if last <= t[ti] <= p[pi]:
                                    matched = True
                            elif p[pi].upper() == t[ti].upper():
                                matched = True
                            last = p[pi]
                            pi += 1
                        if matched == reverse:
                            return False
                        break
                    if p[pi].upper() != t[ti].upper():
                        return False
                    break
                ti += 1
                pi += 1
            return t[ti:] == ''
        if pat == '*' or pat == '':
            return True
        return domatch(text, pat) == True

    def _get_wild_pattern(self):
        """Returns compiled regex corresponding to self.client

        Translates shell-style wildcards: '*' becomes '.*', '?' becomes
        '.', a literal '.' is escaped, and the contents of [...] are copied
        unchanged.  A backslash and the character after it are both copied
        unchanged, so the regex sees the same escape.  The result is
        anchored at both ends.
        """
        # XXX BUG? \ handling is confusing
        # XXX should we emulate mountd code, or use python re library?
        pat = "^"
        skip = False
        in_bracket = False
        for c in self.client:
            out = c
            if skip:
                # Character following a backslash: copy as is
                skip = False
            elif c == '\\': # Note this implies active w/in bracket
                skip = True
            elif in_bracket:
                if c == ']':
                    in_bracket = False
            elif c == '[':
                in_bracket = True
            elif c == '*':
                out = ".*"
            elif c == '?':
                out = '.'
            elif c == '.':
                out = r'\.'
            pat += out
        pat += '$'
        return re.compile(pat)
        
class Address(object):
    """Holds name and ip address info

    Attributes:
      given_name -- the client string exactly as passed in
      names      -- list of host names (primary name first, then aliases)
      quads      -- list of address strings
    """
    # Errors treated as "lookup failed".  UnicodeError covers names the
    # interpreter refuses to encode before it ever reaches the resolver.
    _lookup_errors = (socket.error, UnicodeError)

    def __init__(self, client):
        self.given_name = client
        triple = self._gethost(client)
        self.names = [triple[0]] + triple[1]
        self.quads = triple[2]

    def _gethost(self, name):
        """Try to duplicate mountd method of getting host

        Returns a (hostname, aliaslist, addresslist) triple.  When name
        cannot be resolved the triple is ("", [], [name]).
        """
        try:
            addr = socket.gethostbyname_ex(name)
        except self._lookup_errors:
            # XXX this creates an empty name, as opposed to empty list
            # should we try harder to create empty list?
            return ("", [], [name])
        try:
            # Reverse-map the first address to get the canonical name
            host = socket.gethostbyaddr(addr[2][0])
        except (socket.error, UnicodeError, IndexError):
            # No reverse mapping (or no address at all): keep forward result
            return addr
        if len(addr[2]) > 1:
            # Multi-homed host: look the canonical name up again so that
            # every address is listed; keep the reverse result on failure.
            try:
                host = socket.gethostbyname_ex(host[0])
            except self._lookup_errors:
                pass
        return host

    def show_access(self, exportdata, sec="sys", debug=False):
        """Reports access self would have to various paths in exports file

        exportdata maps each path to a list of entries; the first entry
        whose matches(self, sec) is true decides the result for that path.
        Prints one ALLOW or DENY line per path.
        """
        # print is written with a single parenthesised argument, which
        # produces the same output as the plain print statement.
        for path in exportdata:
            matched = None
            for entry in exportdata[path]:
                if entry.matches(self, sec):
                    if debug:
                        print("MATCH - %s: %s: %s" % (self, path, entry))
                    matched = entry
                    break
            if matched is not None:
                print("ALLOW: %s %s" % (path, matched))
            else:
                print("DENY : %s" % path)

    def __repr__(self):
        return "%s, %s" % (self.names, self.quads)
    
###########################################################

def isquad(str):
    """Determines if str is a dot quad address

    Returns True only when str is four '.'-separated fields, each made
    up solely of decimal digits with a value in the range 0-255.
    """
    fields = str.split('.')
    if len(fields) != 4:
        return False
    for q in fields:
        # int() alone would also accept signs and surrounding blanks
        # ("+4", " 4"), which are not valid in an address, so insist on
        # plain digits first.
        if q == "" or q.strip("0123456789") != "":
            return False
        if int(q) > 255:
            return False
    return True

def ismask(str):
    """Determines if str is a netmask (a simple dot-quad will return True)

    Accepted forms are <quad>, <quad>/<quad> and <quad>/<n>, where n is
    an integer from 0 to 32 inclusive.
    """
    # XXXSurely a lot of this mask manipulation must exist in a library already
    parts = str.split('/')
    count = len(parts)
    if count == 1:
        # No slash: a bare address
        return isquad(str)
    if count > 2:
        return False
    # Exactly one slash: address on the left, mask on the right
    address, mask = parts
    if not isquad(address):
        return False
    if isquad(mask):
        return True
    try:
        width = int(mask)
    except ValueError:
        return False
    return (0 <= width <= 32)

def match_mask(ip1, ip2, mask):
    """Returns True if ip1&mask == ip2&mask

    ip1 and ip2 are dot-quad strings; mask is a 4 byte packed value in
    network byte order.
    """
    def as_int(packed):
        """Unpack 4 network-order bytes into an integer"""
        return struct.unpack("!L", packed)[0]

    first = as_int(socket.inet_aton(ip1))
    second = as_int(socket.inet_aton(ip2))
    bits = as_int(mask)
    return (first & bits) == (second & bits)

###########################################################

def get_options():
    """Parse the command line and return the optparse options object.

    Besides 'file' and 'sec', the returned object carries 'clients', a
    list with one Address per positional argument.  Exits through the
    parser's error() when no client is named.
    """
    def_file = "/etc/exports"
    def_flavor = "sys"
    parser = OptionParser(usage="%prog [options] client ...")
    parser.add_option("-f", "--file", metavar="FILE", default=def_file,
                      help="Parse FILE as exports file [%s]" % def_file)
    parser.add_option("--sec", metavar="FLAVOR", default=def_flavor,
                      help="Assume client is using security FLAVOR [%s]"
                      % def_flavor)
    opts, args = parser.parse_args()
    if not args:
        parser.error("No client given")
    opts.clients = [Address(name) for name in args]
    return opts

def readlines(filename):
    """Parse file, concatenating lines and removing whitespace as necessary.

    Returns a list of logical lines: physical lines ending in a backslash
    are joined with the line that follows, blank lines and lines starting
    with '#' are dropped, and runs of whitespace become a single space.
    """
    def full_lines(fd):
        """Iterate over lines, joining any line ending in a backslash
        with the one after it"""
        line = ""
        for partial_line in fd:
            line += partial_line.rstrip()
            if line.endswith('\\'):
                # Continuation: drop the backslash and keep accumulating
                line = line[:-1]
            else:
                yield line.strip()
                line = ""
        if line:
            # The file ended on a continuation line; emit what was
            # gathered instead of silently losing it.
            yield line.strip()

    fd = open(filename)
    try:
        return [' '.join(l.split()) for l in full_lines(fd)
                if l != '' and l[0] != '#']
    finally:
        fd.close()

def parse_lines(lines):
    """Turn cleaned exports lines into a dict of path -> list of Entry.

    The first word of each line is the path and every following word
    becomes an Entry.  Entries for a path that shows up more than once
    are collected in the order they are met.
    """
    table = {}
    for text in lines:
        words = text.split()
        path = words[0]
        entries = [Entry(word) for word in words[1:]]
        table.setdefault(path, []).extend(entries)
    return table

###########################################################

def main(opts, debug=True):
    try:
        lines = readlines(opts.file)
    except StandardError, e:
        print "Error trying to read %s:\n%s" % (opts.file, e)
        return
    d = parse_lines(lines)
    if debug:
        print "Reading from file %s:" % opts.file
        for path in d:
            print "%s %s" % (path, ' '.join([str(e) for e in d[path]]))
        print
    for c in opts.clients:
        print c.given_name
        if debug: print c
        c.show_access(d, opts.sec)
        print

if __name__ == "__main__":
    import sys
    try:
        opts = get_options()
    except Exception, e:
        if e.code:
            print e
            print "Failure reading commandline options"
        sys.exit(1)
    main(opts)
