#!/usr/bin/python

# Manage OpenLDAP databases
# Copyright (c) 2013 Guilhem Moulin <guilhem@fripost.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import ldap, ldap.sasl
from ldap.filter  import filter_format
from ldap.dn      import dn2str,explode_dn,str2dn
from ldap.modlist import addModlist
from ldif         import LDIFParser
from functools    import partial
import re, pwd


# Dirty hack to check equality between the targeted LDIF and what is
# currently in the directory.  The values of some configuration (olc*)
# attributes are automatically indexed when added; for those we'll add
# explicit indices to what we find in the LDIF.
# olc* attributes whose values get an automatic "{N}" index prefix; the
# LDIF values of these attributes are prefixed the same way before being
# compared with the directory.
indexedAttributes = frozenset((
    'olcAccess',
    'olcAttributeTypes',
    'olcObjectClasses',
    'olcOverlay',
    'olcSyncrepl',
))


# Another hack.  Configuration entries sometimes pollute their DNs with
# indices, so it's not possible to use those DNs directly as a search
# base.  Instead, we use their parent as the base, and search for the
# *unique* child with the same objectClass and matching extra attributes.
# ('%s' in the attribute value is replaced with the value of the source
# entry.)
# objectClass -> list of (attribute, filter template) pairs used by
# flexibleSearch to locate an entry under its parent.
indexedDN = dict(
    olcSchemaConfig  = [('cn',             '{*}%s')],
    olcHdbConfig     = [('olcDbDirectory', '%s')],
    olcOverlayConfig = [('olcOverlay',     '%s')],
)

# Allow for flexible ACLs for users authenticating with SASL's EXTERNAL
# mechanism.  "username=postfix,cn=peercred,cn=external,cn=auth" is
# replaced by "gidNumber=106+uidNumber=102,cn=peercred,cn=external,cn=auth",
# where 102 is postfix's UID and 106 its primary GID.
# (Only "by dn=" and "by dn.exact=" clauses are rewritten; regular
# expressions are not supported.)
# Matches a whitespace-delimited ' by dn[.exact]=[quote]username=USER,
# cn=peercred,cn=external,cn=auth[quote] ' clause.  Named groups: 'start'
# (the leading whitespace and "by dn[.exact]"), 'quote' (optional quote
# character, which must also close the value), 'user' and 'end' (the
# fixed peercred suffix).
sasl_ext_re = re.compile( r"""(?P<start>\sby\s+dn(?:\.exact)?)=
                              (?P<quote>['\"]?)username=(?P<user>[a-z][-a-z0-9_]*),
                              (?P<end>cn=peercred,cn=external,cn=auth)
                              (?P=quote)\s""",
                          flags=re.VERBOSE )

# Cache of pwd.getpwnam() results, keyed by user name.
pwd_dict = {}

def acl_sasl_ext(m):
    """re.sub callback for sasl_ext_re.

    Returns the matched clause rewritten as
    '<start>="gidNumber=GID+uidNumber=UID,<end>" ', with GID/UID taken from
    the passwd entry of the matched user (looked up once, then cached in
    pwd_dict).  Raises KeyError if the user does not exist.
    """
    user = m.group('user')
    try:
        pw = pwd_dict[user]
    except KeyError:
        pw = pwd_dict[user] = pwd.getpwnam(user)
    start, end = m.group('start'), m.group('end')
    return '%s="gidNumber=%d+uidNumber=%d,%s" ' % (start, pw.pw_gid,
                                                   pw.pw_uid, end)


# Run the given callback on each DN seen.  If its return value is not
# None, update the changed variable.
class LDIFCallback(LDIFParser):
    """LDIF parser that hands every (dn, entry) record to a callback.

    Whenever the callback returns something other than None, that value
    is OR-ed into ``self.changed`` (initially False).
    """

    def __init__(self, module, input, callback):
        # 'module' is accepted for signature compatibility; it is unused.
        LDIFParser.__init__(self, input)
        self.callback = callback
        self.changed = False

    def handle(self, dn, entry):
        result = self.callback(dn, entry)
        if result is None:
            return
        self.changed |= result


# Run slapcat(8) on the given suffix or DB number (suffix takes
# precedence) with an optional filter.  (This is useful for offline
# searches, or when one needs to bypass ACLs.)  Returns the subprocess
# (a Popen object) whose stdout is an open pipe.
def slapcat(filter=None, suffix=None, idx=0):
    """Spawn /usr/sbin/slapcat and return its Popen object.

    filter -- optional LDAP filter, passed as '-a FILTER'.
    suffix -- database suffix ('-b'); a non-str value is converted with
              dn2str.  When given, idx is ignored.
    idx    -- database number ('-nIDX'), used only when suffix is None.

    stdout is a pipe to be read by the caller; stderr is discarded.
    """
    args = [os.path.join(os.sep, 'usr', 'sbin', 'slapcat')]
    if filter is not None:
        args += ['-a', filter]
    if suffix is None:
        args.append('-n%d' % idx)
    else:
        if type(suffix) is not str:
            suffix = dn2str(suffix)
        args += ['-b', suffix]
    return subprocess.Popen(args, stdout=subprocess.PIPE,
                            stderr=open(os.devnull, 'wb'))


# Start / stop / whatever a service.
def service(name, state):
    """Run '/usr/sbin/service NAME STATE', discarding all of its output.

    Raises subprocess.CalledProcessError on a non-zero exit status.
    """
    service_bin = os.path.join(os.sep, 'usr', 'sbin', 'service')
    with open(os.devnull, 'wb') as devnull:
        subprocess.check_call([service_bin, name, state],
                              stdout=devnull, stderr=subprocess.STDOUT)


# Check if the given dn is already present in the directory.
# Returns None if it doesn't exist, and the (dn, entry) pair otherwise.
def flexibleSearch(module, l, dn, entry):
    """Look up the directory entry corresponding to an LDIF record.

    When none of the entry's objectClasses is listed in indexedDN, the DN
    is read directly (base scope).  Otherwise the DN may carry a
    server-assigned "{N}" index, so its parent is searched (one level)
    with a filter built from the indexedDN templates: the RDN's own value
    is taken from the DN, other attributes from the (single-valued) LDIF
    entry.

    Returns the unique matching (dn, entry) pair, or None when nothing
    matches -- including when the search base itself does not exist.
    Calls module.fail_json when the lookup is ambiguous or a required
    attribute is missing.
    """
    idxClasses = set(entry['objectClass']).intersection(indexedDN.keys())
    if not idxClasses:
        base = dn
        scope = ldap.SCOPE_BASE
        f = 'objectClass=*'
    else:
        # Search on the parent instead, and try to use a precise filter
        rdns = str2dn(dn)
        h, t, _ = rdns.pop(0)[0]
        base = dn2str(rdns)
        scope = ldap.SCOPE_ONELEVEL
        f = []
        for c in idxClasses:
            f.append(filter_format('objectClass=%s', [c]))
            for a, v in indexedDN[c]:
                if a == h:
                    # the RDN's attribute: use the value from the DN
                    v2 = t
                elif a not in entry:
                    module.fail_json(msg="Missing attribute %s in %s" % (a, dn))
                elif len(entry[a]) > 1:
                    module.fail_json(msg="Multiple values found! This is a bug. Please report.")
                else:
                    v2 = entry[a][0]
                # The template may contain filter wildcards (eg, '{*}%s');
                # only the substituted value is escaped.
                f.append(filter_format(a + '=' + v, [v2]))
        if len(f) == 1:
            f = f[0]
        else:
            f = '(&(' + ')('.join(f) + '))'

    try:
        r = l.search_s(base, scope, filterstr=f)
    except ldap.NO_SUCH_OBJECT:
        # A base-scope search on a DN that isn't there (or a one-level
        # search under a missing parent) raises rather than returning an
        # empty result: the entry doesn't exist.
        return None

    if len(r) > 1:
        module.fail_json(msg="Multiple results found! This is a bug. Please report.")
    elif r:
        return r.pop()
    return None


# Add or modify (only the attributes that differ from those in the
# directory) the entry for that DN.
# l must be an LDAPObject, and should provide an open connection to the
# directory with disclose/search/write access.
def processEntry(module, l, dn, entry):
    """Add the LDIF entry at dn, or bring the existing entry in line with it.

    l must be a bound LDAPObject with disclose/search/write access.  An
    existing entry (located with flexibleSearch) is only modified where it
    differs from the LDIF:
      * attributes absent from the LDIF are deleted, except the attribute
        of the DN's first RDN, which is implicit;
      * indexedAttributes are compared as ordered "{N}value" lists and
        replaced wholesale when they differ;
      * other attributes are replaced when the old and new value sets are
        disjoint, otherwise only the differing values are deleted/added;
      * attributes only present in the LDIF are added.

    "username=USER,cn=peercred,cn=external,cn=auth" DNs in olcAccess
    values are rewritten with acl_sasl_ext first.  Note that entry is
    modified in place.

    In check mode, module.exit_json is called as soon as a change is
    detected.  Returns True if the entry was added or modified, False
    otherwise.
    """
    if 'olcAccess' in entry:
        # Rewrite peercred ACL clauses up front, so that the rewritten DNs
        # are used when the entry is added as well as when olcAccess is
        # compared with, or added to, an existing entry.
        entry['olcAccess'] = [ sasl_ext_re.sub(acl_sasl_ext, x)
                               for x in entry['olcAccess'] ]

    r = flexibleSearch(module, l, dn, entry)
    if r is None:
        if module.check_mode:
            module.exit_json(changed=True, msg="add DN %s" % dn)
        l.add_s(dn, addModlist(entry))
        return True

    d, e = r
    fst = str2dn(dn).pop(0)[0][0]
    diff = []
    for a, v in e.items():
        if a not in entry:
            if a != fst:
                # delete all values, except for the RDN's attribute
                diff.append((ldap.MOD_DELETE, a, None))
        elif a in indexedAttributes:
            # the directory returns these values with an explicit "{N}"
            # prefix; add the same prefixes to the LDIF values to compare
            entry[a] = ['{%d}%s' % (i, x) for i, x in enumerate(entry[a])]
            if v != entry[a]:
                diff.append((ldap.MOD_REPLACE, a, entry[a]))
        elif v != entry[a]:
            # for non-indexed attributes, only touch the symmetric difference
            s1 = set(v)
            s2 = set(entry[a])
            if s1.isdisjoint(s2):
                # replace the former values with the new ones
                diff.append((ldap.MOD_REPLACE, a, entry[a]))
            else:
                x = list(s1.difference(s2))
                if x:
                    diff.append((ldap.MOD_DELETE, a, x))
                y = list(s2.difference(s1))
                if y:
                    diff.append((ldap.MOD_ADD, a, y))

    # add attributes that weren't in e
    for a in set(entry).difference(e):
        diff.append((ldap.MOD_ADD, a, entry[a]))

    if not diff:
        return False
    if module.check_mode:
        module.exit_json(changed=True, msg="mod DN %s" % dn)
    l.modify_s(d, diff)
    return True


# Load the given module.
def loadModule(module, l, name):
    """Make sure the slapd module 'name' is loaded.

    Searches the olcModuleList entries right under cn=config for an
    olcModuleLoad value equal to name.  If none is found, name is added to
    olcModuleLoad of cn=module{0},cn=config (in check mode,
    module.exit_json is called instead) and True is returned; otherwise
    False is returned.
    """
    f = filter_format( '(&(objectClass=olcModuleList)(olcModuleLoad=%s))', [name] )
    found = l.search_s('cn=config', ldap.SCOPE_ONELEVEL,
                       filterstr=f, attrlist=[''])
    if found:
        return False

    if module.check_mode:
        module.exit_json(changed=True, msg="add module %s" % name)
    l.modify_s('cn=module{0},cn=config',
               [(ldap.MOD_ADD, 'olcModuleLoad', name)])
    return True


# Find the database associated with a given attribute (eg,
# olcDbDirectory or olcSuffix).
def getDN_DB(module, l, a, v):
    """Return the raw search_s() results for the olcDatabaseConfig entries
    right under cn=config whose attribute a equals v.

    The search is issued with attrlist=['']; 'module' is unused.
    """
    db_filter = filter_format('(&(objectClass=olcDatabaseConfig)(' + a + '=%s))',
                              [v])
    return l.search_s('cn=config', ldap.SCOPE_ONELEVEL,
                      filterstr=db_filter, attrlist=[''])


# Clear the given DB directory and delete the associated database.  Fail
# if the database is non-empty, unless all existing DNs are in skipdns
# (wontRemove is the per-entry check used for that).
def wontRemove(module, skipdns, d, _):
    """LDIFCallback handler: call module.fail_json unless DN d is in skipdns.

    The entry itself (last argument) is ignored.  Returns None.
    """
    if d in skipdns:
        return None
    module.fail_json(msg="won't remove '%s'" % d)
def removeDB(module, dbdir, skipdn=None):
    """Delete the database stored in dbdir and empty that directory.

    Returns False if dbdir doesn't exist, or if no database under
    cn=config has it as olcDbDirectory.  Otherwise the database is dumped
    with slapcat(8), and module.fail_json is called if the dump fails or
    holds any DN other than the suffix and the "<s>,<suffix>" DNs for s in
    skipdn.  In check mode module.exit_json is then called.  Finally slapd
    is stopped, the database's LDIF file (and the directory holding its
    child entries, if any) is removed from /etc/ldap/slapd.d, every file
    in dbdir is deleted, slapd is started again, and True is returned.
    """
    import shutil

    if not os.path.exists(dbdir):
        return False

    l = ldap.initialize('ldapi://')
    l.sasl_interactive_bind_s('', ldap.sasl.external())
    try:
        r = getDN_DB(module, l, 'olcDbDirectory', dbdir)
        if len(r) > 1:
            module.fail_json(msg="Multiple results found! This is a bug. Please report.")
        if not r:
            return False
        dn = r[0][0]
        # getDN_DB is issued with attrlist=[''], which doesn't request
        # olcSuffix; read it explicitly from the database entry.
        _, attrs = l.search_s(dn, ldap.SCOPE_BASE, attrlist=['olcSuffix'])[0]
    finally:
        l.unbind_s()
    suffix = attrs['olcSuffix'][0]

    skipdns = [suffix]
    if skipdn is not None:
        skipdns.extend(["%s,%s" % (s, suffix) for s in skipdn])

    # here we need to use slapcat not search_s, because we may not have
    # read access on the database (even though we're root!).
    p = slapcat(suffix=suffix)
    try:
        parser = LDIFCallback(module, p.stdout,
                              partial(wontRemove, module, skipdns))
        parser.parse()
    finally:
        p.stdout.close()
    # A failing slapcat yields a truncated (possibly empty) dump, which
    # wontRemove would accept: never delete data on that basis.
    if p.wait() != 0:
        module.fail_json(msg="slapcat failed on %s (exit status %d)"
                             % (suffix, p.returncode))

    if module.check_mode:
        module.exit_json(changed=True, msg="remove dir %s" % dbdir)

    # slapd doesn't support database deletion, so we need to turn it off
    # and remove it from slapd.d manually.  Always start it again, even
    # if the cleanup below fails.
    service('slapd', 'stop')
    try:
        path = [os.sep, 'etc', 'ldap', 'slapd.d'] + explode_dn(dn)[::-1]
        entrypath = os.path.join(*path)
        os.unlink(entrypath + '.ldif')
        # slapd.d keeps an entry's children in a directory named like its
        # LDIF file minus '.ldif'; drop them along with the database.
        if os.path.isdir(entrypath):
            shutil.rmtree(entrypath)
        # delete all children in dbdir, but not the directory itself.
        for name in os.listdir(dbdir):
            os.unlink(os.path.join(dbdir, name))
    finally:
        service('slapd', 'start')
    return True


def main():
    """Ansible entry point.

    state=absent:  remove the database stored in 'dbdirectory' (see
                   removeDB); 'ignoredn' is a ':'-separated list of RDNs,
                   relative to the suffix, that may still be present.
    state=present: apply the LDIF file 'target' entry by entry with
                   processEntry.  With 'module', first make sure
                   '<module>.la' is loaded; if 'target' and 'suffix' are
                   also given, each LDIF entry is applied to the
                   olcOverlay=<module> entry of the database serving
                   'suffix'.
    """
    module = AnsibleModule(
        argument_spec   = dict(
            dbdirectory = dict( default=None ),
            ignoredn    = dict( default=None ),
            state       = dict(default="present", choices=["absent", "present"]),
            target      = dict( default=None ),
            module      = dict( default=None ),
            suffix      = dict( default=None ),
        ),
        supports_check_mode=True
    )

    params      = module.params
    state       = params['state']
    dbdirectory = params['dbdirectory']
    ignoredn    = params['ignoredn']
    target      = params['target']
    mod         = params['module']
    suffix      = params['suffix']

    if ignoredn is not None:
        ignoredn = ignoredn.split(':')

    changed = False
    try:
        if state == "absent":
            if dbdirectory is not None:
                changed = removeDB(module, dbdirectory, skipdn=ignoredn)
            # TODO: might be useful to be able remove DNs
            else:
                module.fail_json(msg="missing dbdirectory")

        elif state == "present":
            if target is None and mod is None:
                module.fail_json(msg="missing target or module")
            # bind only once per LDIF file for performance
            l = ldap.initialize('ldapi://')
            l.sasl_interactive_bind_s('', ldap.sasl.external())

            if mod is None:
                callback = partial(processEntry, module, l)
            else:
                changed |= loadModule(module, l, '%s.la' % mod)
                if target is None and suffix is None:
                    l.unbind_s()
                    module.exit_json(changed=changed)
                if target is None or suffix is None:
                    module.fail_json(msg="missing target or suffix")
                r = getDN_DB(module, l, 'olcSuffix', suffix)
                if not r:
                    module.fail_json(msg="No database found for suffix %s" % suffix)
                elif len(r) > 1:
                    module.fail_json(msg="Multiple results found! This is a bug. Please report.")
                else:
                    d = 'olcOverlay=%s,%s' % (mod, r.pop()[0])
                    callback = lambda _, e: processEntry(module, l, d, e)

            with open(target, 'r') as fh:
                parser = LDIFCallback(module, fh, callback)
                parser.parse()
            # OR rather than assign: keep a change already made by loadModule
            changed |= parser.changed
            l.unbind_s()

    except subprocess.CalledProcessError as e:
        # check_call() leaves e.output as None
        out = getattr(e, 'output', None)
        module.fail_json(rv=e.returncode, msg=out.rstrip() if out else str(e))
    except ldap.LDAPError as e:
        # python-ldap usually passes a dict with 'desc' (and maybe 'info')
        info = e.args[0] if e.args else None
        if isinstance(info, dict):
            msg = info.get('info') or info.get('desc') or str(info)
        else:
            msg = str(e)
        module.fail_json(msg=msg)
    except KeyError as e:
        module.fail_json(msg=str(e))
    except (IOError, OSError) as e:
        # eg, unreadable target file, or file removal errors in removeDB
        module.fail_json(msg=str(e))

    module.exit_json(changed=changed)


# this is magic, see lib/ansible/module_common.py
# NOTE(review): Ansible substitutes its module boilerplate for the marker
# below; that boilerplate presumably provides AnsibleModule (and, it
# seems, the os/subprocess modules), none of which are defined or
# imported earlier in this file -- so main() must stay after the marker.
#<<INCLUDE_ANSIBLE_MODULE_COMMON>>
main()