#!/usr/bin/python

# Manage OpenLDAP databases
# Copyright (c) 2013 Guilhem Moulin <guilhem@fripost.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import ldap, ldap.sasl
from ldap.filter  import filter_format
from ldap.dn      import dn2str,explode_dn,str2dn
from ldap.modlist import addModlist
from ldif         import LDIFParser
from functools    import partial
import re, pwd
import tempfile, atexit


# Dirty hack to check equality between the targeted LDIF and what is
# currently in the directory.  The values of some configuration (olc*)
# attributes are automatically indexed when added; for those we add
# explicit indices to what we find in the LDIF.
# Names of the olc* attributes that get explicit '{N}' indices prepended
# to their LDIF values before being compared with the directory.
indexedAttributes = frozenset((
    'olcAccess',
    'olcAttributeTypes',
    'olcAuthzRegexp',
    'olcDbConfig',
    'olcLimits',
    'olcObjectClasses',
    'olcOverlay',
    'olcSyncrepl',
))


# Another hack. Configuration entries sometimes pollute their DNs with
# indices, so it is not possible to use those DNs directly as a search
# base.  Instead, we use their parent as the base, and search for the
# *unique* match with the same ObjectClass and the matching extra
# attributes.  ('%s' in the attribute value is replaced with the value
# of the source entry.)
# Maps an objectClass whose entries carry indexed DNs to the list of
# (attribute, value template) pairs used to build the search filter.
indexedDN = {
    'olcMonitorConfig': [],
    'olcOverlayConfig': [
        ('olcOverlay', '%s'),
    ],
    'olcMdbConfig': [
        ('olcDbDirectory', '%s'),
    ],
    'olcSchemaConfig': [
        ('cn', '{*}%s'),
    ],
}

# Allow for flexible ACLs for users binding with SASL's EXTERNAL mechanism.
# "username=postfix,cn=peercred,cn=external,cn=auth" is replaced by
# "gidNumber=106+uidNumber=102,cn=peercred,cn=external,cn=auth" where
# 102 is postfix's UID and 106 its primary GID.
# (Regular expressions are not allowed.)
# Cache of passwd entries keyed by user name, filled lazily by
# acl_sasl_ext() so each user is looked up at most once.
pwd_dict = dict()

# Collapses any run of whitespace; used to normalise indexed attributes.
multispaces = re.compile(r"\s+")

# Matches a ' by dn[.exact]=[quote]username=USER,cn=peercred,...[quote] '
# clause.  The pattern is compiled in verbose mode, so the layout
# whitespace inside the literal is not part of the expression.
sasl_ext_re = re.compile(
    r"""(?P<start>\sby\s+dn(?:\.exact)?)=
                              (?P<quote>['\"]?)username=(?P<user>[a-z][-a-z0-9_]*),
                              (?P<end>cn=peercred,cn=external,cn=auth)
                              (?P=quote)\s""",
    re.VERBOSE,
)

def acl_sasl_ext(m):
    """Build the replacement text for one match of sasl_ext_re.

    The 'user' group of the match is resolved through pwd.getpwnam()
    (results are memoised in pwd_dict) and the clause is rewritten as
    '<start>="gidNumber=<gid>+uidNumber=<uid>,<end>" '.
    """
    user = m.group('user')
    account = pwd_dict.get(user)
    if account is None:
        # first time we see this user: look it up and remember it
        account = pwd.getpwnam(user)
        pwd_dict[user] = account
    fields = (m.group('start'), account.pw_gid, account.pw_uid, m.group('end'))
    return '%s="gidNumber=%d+uidNumber=%d,%s" ' % fields


# Run the given callback on each DN seen.  If its return value is not
# None, update the changed variable.
class LDIFCallback(LDIFParser):
    """LDIF parser that hands every (dn, entry) record to a callback.

    The 'changed' attribute starts as False and is OR-ed with every
    non-None value returned by the callback.
    """

    def __init__(self, module, input, callback):
        # 'module' is accepted for symmetry with the other helpers but
        # is not used by the parser itself.
        LDIFParser.__init__(self, input)
        self.changed = False
        self.callback = callback

    def handle(self, dn, entry):
        outcome = self.callback(dn, entry)
        if outcome is None:
            return
        self.changed |= outcome


# Check if the given dn is already present in the directory.
# Returns None if it doesn't exist, and the (dn, entry) pair otherwise.
def flexibleSearch(module, l, dn, entry):
    """Look up the directory entry that corresponds to an LDIF entry.

    module -- the AnsibleModule, used to report fatal errors
    l      -- a bound LDAPObject
    dn     -- the DN of the entry, as found in the LDIF
    entry  -- the attribute dictionary of the entry, as found in the LDIF

    For ordinary entries the DN itself is used as a base-scope search
    base.  For entries whose objectClass is listed in indexedDN, the
    parent is searched one level deep with a filter built from the
    objectClass and the attributes listed in indexedDN.

    Returns the (dn, attributes) pair found in the directory, or None
    when there is no such entry.  Calls module.fail_json() if the search
    is ambiguous.
    """
    idxClasses = set(entry['objectClass']).intersection(indexedDN)
    if not idxClasses:
        base = dn
        scope = ldap.SCOPE_BASE
        f = 'objectClass=*'
    else:
        # Search on the parent instead, and try to use a precise filter
        rdns = str2dn(dn)
        h, t, _ = rdns.pop(0)[0]
        base = dn2str(rdns)
        scope = ldap.SCOPE_ONELEVEL
        f = []
        for c in idxClasses:
            f.append(filter_format('objectClass=%s', [c]))
            for a, v in indexedDN[c]:
                if a == h:
                    # the attribute is the RDN: take its value from the DN
                    v2 = t
                elif a not in entry or len(entry[a]) > 1:
                    module.fail_json(msg="Multiple values found! This is a bug. Please report.")
                else:
                    v2 = entry[a][0]
                f.append(filter_format(a + '=' + v, [v2]))
        if len(f) == 1:
            f = f[0]
        else:
            f = '(&(' + ')('.join(f) + '))'

    try:
        r = l.search_s(base, scope, filterstr=f)
    except ldap.NO_SUCH_OBJECT:
        # The search base itself is missing (always the case for a
        # base-scope search on a DN that has yet to be created): report
        # "not found" rather than aborting.
        return None

    if len(r) > 1:
        module.fail_json(msg="Multiple results found! This is a bug. Please report.")
    elif r:
        return r.pop()
    return None


# Add or modify (only the attributes that differ from those in the
# directory) the entry for that DN.
# l must be an LDAPObject, and should provide an open connection to the
# directory with disclose/search/write access.
def processEntry(module, l, dn, entry):
    """Add the entry to the directory, or bring the existing one in sync.

    module -- the AnsibleModule (check mode and exit handling)
    l      -- a bound LDAPObject with disclose/search/write access
    dn     -- the DN of the entry, as found in the LDIF
    entry  -- the attribute dictionary of the entry; it is modified in
              place (whitespace normalisation, SASL EXTERNAL expansion
              and explicit indices)

    Returns True if the directory was modified, False otherwise.  In
    check mode the module exits as soon as a change would be required.
    """
    changed = False

    for x in indexedAttributes.intersection(entry):
        # remove useless extra spaces in ACLs etc
        entry[x] = [multispaces.sub(' ', val) for val in entry[x]]

    r = flexibleSearch(module, l, dn, entry)
    if r is None:
        changed = True
        if module.check_mode:
            module.exit_json(changed=changed, msg="add DN %s" % dn)
        if 'olcAccess' in entry:
            # replace "username=...,cn=peercred,cn=external,cn=auth"
            # by a DN with proper gidNumber and uidNumber
            entry['olcAccess'] = [sasl_ext_re.sub(acl_sasl_ext, val)
                                  for val in entry['olcAccess']]
        l.add_s(dn, addModlist(entry))
        return changed

    d, e = r
    if 'olcAccess' in entry:
        # replace "username=...,cn=peercred,cn=external,cn=auth" by a DN
        # with proper gidNumber and uidNumber.  This is done before
        # diffing so that it also applies when the directory entry has
        # no olcAccess yet and the values end up in a MOD_ADD below.
        entry['olcAccess'] = [sasl_ext_re.sub(acl_sasl_ext, val)
                              for val in entry['olcAccess']]

    fst = str2dn(dn).pop(0)[0][0]
    diff = []
    for a, v in e.items():
        if a not in entry:
            if a != fst:
                # delete all values except for the first attribute,
                # which is implicit
                diff.append((ldap.MOD_DELETE, a, None))
        elif a in indexedAttributes:
            # add explicit indices in the entry from the LDIF
            entry[a] = ['{%d}%s' % (i, val) for i, val in enumerate(entry[a])]
            if v != entry[a]:
                diff.append((ldap.MOD_REPLACE, a, entry[a]))
        elif v != entry[a]:
            # for non-indexed attribute, we update values in the
            # symmetric difference only
            s1 = set(v)
            s2 = set(entry[a])
            if s1.isdisjoint(s2):
                # replace the former values with the new ones
                diff.append((ldap.MOD_REPLACE, a, entry[a]))
            else:
                x = list(s1.difference(s2))
                if x:
                    diff.append((ldap.MOD_DELETE, a, x))
                y = list(s2.difference(s1))
                if y:
                    diff.append((ldap.MOD_ADD, a, y))

    # add attributes that weren't in e
    for a in set(entry).difference(e):
        diff.append((ldap.MOD_ADD, a, entry[a]))

    if diff:
        changed = True
        if module.check_mode:
            module.exit_json(changed=changed, msg="mod DN %s" % dn)
        l.modify_s(d, diff)
    return changed


# Load the given module.
def loadModule(module, l, name):
    """Make sure the dynamic module 'name' is listed in olcModuleLoad.

    Searches the children of cn=config for an olcModuleList entry that
    already loads 'name'.  If there is none, the value is added to
    'cn=module{0},cn=config' (or, in check mode, the module exits
    reporting the pending change).

    Returns True if the module had to be added, False otherwise.
    """
    query = filter_format('(&(objectClass=olcModuleList)(olcModuleLoad=%s))', [name])
    found = l.search_s('cn=config', ldap.SCOPE_ONELEVEL, filterstr=query, attrlist=[''])
    if found:
        # already loaded, nothing to do
        return False

    if module.check_mode:
        module.exit_json(changed=True, msg="add module %s" % name)
    modlist = [(ldap.MOD_ADD, 'olcModuleLoad', name)]
    l.modify_s('cn=module{0},cn=config', modlist)
    return True


# Find the database associated with a given attribute (eg,
# olcDbDirectory or olcSuffix).
def getDN_DB(module, l, a, v, attrlist=['']):
    """Search the children of cn=config for database entries matching a=v.

    Returns the raw search_s() result for olcDatabaseConfig entries
    whose attribute 'a' has the value 'v'; 'attrlist' is passed through
    to the search unchanged.  The value is escaped, the attribute name
    is not.
    """
    template = '(&(objectClass=olcDatabaseConfig)(' + a + '=%s))'
    query = filter_format(template, [v])
    return l.search_s('cn=config', ldap.SCOPE_ONELEVEL,
                      filterstr=query, attrlist=attrlist)


# Convert a *.schema file into *.ldif format. The algorithm can be found
# in /etc/ldap/schema/openldap.ldif .
def slapd_to_ldif(src, name):
    """Convert a slapd.conf-style *.schema file into a temporary LDIF file.

    src  -- path of the schema file to read
    name -- value used for the 'cn' RDN of the generated schema entry

    The conversion works line by line:
      * blank lines become '#', since a blank line would end the entry;
      * the objectIdentifier / objectClass / attributeType keywords are
        turned into the olcObjectIdentifier / olcObjectClasses /
        olcAttributeTypes attributes;
      * the leading whitespace of continuation lines is replaced with
        exactly two spaces (LDIF only accepts a space as continuation
        marker, and strips it).

    Returns the path of the generated file, which is removed when the
    interpreter exits.
    """
    # Imported locally so the function does not depend on 'os' leaking
    # in from a wildcard import elsewhere in the file.
    import os

    keywords = (
        (re.compile(r'^objectIdentifier\s(.*)', re.I), 'olcObjectIdentifier'),
        (re.compile(r'^objectClass\s(.*)',      re.I), 'olcObjectClasses'),
        (re.compile(r'^attributeType\s(.*)',    re.I), 'olcAttributeTypes'),
    )
    leading_ws = re.compile(r'^\s+')

    def cleanup(path):
        try:
            os.unlink(path)
        except OSError:
            # already gone; nothing left to clean up
            pass

    with open(src, 'r') as s:
        # text mode, so that plain strings can be written on any version
        d = tempfile.NamedTemporaryFile(mode='w', delete=False)
        atexit.register(cleanup, d.name)
        try:
            d.write('dn: cn=%s,cn=schema,cn=config\n' % name)
            d.write('objectClass: olcSchemaConfig\n')

            for raw in s:
                # work on the line without its terminator and add it back
                # when writing, so rewritten lines keep their newline
                line = raw.rstrip('\r\n')
                if not line.strip():
                    line = '#'
                else:
                    for pattern, attr in keywords:
                        m = pattern.match(line)
                        if m is not None:
                            line = '%s: %s' % (attr, m.group(1))
                            break
                    else:
                        line = leading_ws.sub('  ', line)
                d.write(line + '\n')
        finally:
            d.close()
    return d.name


def main():
    """Entry point of the Ansible module.

    Depending on the parameters, either
      * deletes an entry (delete=entry) or some of its attributes
        (delete=attr1,attr2,...) from the entry named 'name'; or
      * loads the overlay module 'module' and/or applies the LDIF (or
        slapd.conf-style schema, see 'format') found in 'target'.

    The connection is made over ldapi:// with SASL EXTERNAL.  The
    function always terminates through module.exit_json() or
    module.fail_json().
    """
    module = AnsibleModule(
        argument_spec   = dict(
            target      = dict( default=None ),
            module      = dict( default=None ),
            suffix      = dict( default=None ),
            format      = dict( default="ldif", choices=["ldif","slapd.conf"] ),
            name        = dict( default=None ),
            local       = dict( default="no", choices=["no","file","template"] ),
            delete      = dict( default=None ),
        ),
        supports_check_mode=True
    )

    params      = module.params
    target      = params['target']
    mod         = params['module']
    suffix      = params['suffix']
    form        = params['format']
    name        = params['name']
    delete      = params['delete']

    changed = False
    try:
        if delete is not None:
            if name is None:
                module.fail_json(msg="missing name")
            l = ldap.initialize( 'ldapi://' )
            l.sasl_interactive_bind_s('', ldap.sasl.external())
            if delete == 'entry':
                filterStr = '(objectClass=*)'
            else:
                filterStr = [ '(%s=*)' % x for x in delete.split(',') ]
                if len(filterStr) > 1:
                    filterStr = '(|' + ''.join(filterStr) + ')'
                else:
                    filterStr = filterStr[0]

            try:
                r = l.search_s( name, ldap.SCOPE_BASE, filterStr, attrsonly=1 )
            except ldap.NO_SUCH_OBJECT:
                # Only a missing entry means "nothing to delete"; any
                # other LDAP error is reported by the handler below.
                r = None

            if r:
                changed = True
                if module.check_mode:
                    module.exit_json(changed=changed)
                if delete == 'entry':
                    l.delete_s(r[0][0])
                else:
                    attrlist = list(set(r[0][1].keys()) & set(delete.split(',')))
                    l.modify_s(r[0][0], [ (ldap.MOD_DELETE, x, None) for x in attrlist ])
            l.unbind_s()

        else:
            if form == 'slapd.conf':
                if name is None:
                    module.fail_json(msg="missing name")
                if target is not None:
                    # without a target there is nothing to convert; the
                    # checks below report what is missing
                    target = slapd_to_ldif(target, name)

            if target is None and mod is None:
                module.fail_json(msg="missing target or module")
            # bind only once per LDIF file for performance
            l = ldap.initialize( 'ldapi://' )
            l.sasl_interactive_bind_s('', ldap.sasl.external())

            if mod is None:
                callback = partial(processEntry,module,l)
            else:
                changed |= loadModule (module, l, '%s.la' % mod)
                if target is None and suffix is None:
                    l.unbind_s()
                    module.exit_json(changed=changed)
                if target is None or suffix is None:
                    module.fail_json(msg="missing target or suffix")
                r = getDN_DB(module, l, 'olcSuffix', suffix)
                if not r:
                    module.fail_json(msg="No database found for suffix %s" % suffix)
                elif len(r) > 1:
                    module.fail_json(msg="Multiple results found! This is a bug. Please report.")
                else:
                    # the overlay entry lives under the database entry,
                    # whatever DN the LDIF itself uses
                    d = 'olcOverlay=%s,%s' % (mod, r.pop()[0])
                    callback = lambda _,e: processEntry(module,l,d,e)

            with open(target, 'r') as ldif:
                parser = LDIFCallback( module, ldif, callback )
                parser.parse()
            # accumulate, so that a freshly loaded module is still
            # reported as a change when the LDIF itself changed nothing
            changed |= parser.changed
            l.unbind_s()

    except subprocess.CalledProcessError as e:
        module.fail_json(rv=e.returncode, msg=e.output.rstrip())
    except ldap.LDAPError as e:
        info = e.args[0] if e.args else {}
        if isinstance(info, dict):
            # prefer the server's diagnostic message over the generic
            # description of the result code
            msg = info.get('info') or info.get('desc') or str(info)
        else:
            msg = str(info)
        module.fail_json(msg=msg)
    except KeyError as e:
        module.fail_json(msg=str(e))

    module.exit_json(changed=changed)

# import module snippets
from ansible.module_utils.basic import *

# Run the module logic as soon as this file is executed.
main()