#!/usr/bin/python
# Copyright 1999-2017. Parallels IP Holdings GmbH. All Rights Reserved.
import os, sys, re, socket, tempfile
import binascii

# Login and database name used when building the mysql command line.
db_login = 'admin'
db_name = 'psa'

# Values below are placeholders; init() fills in db_pass, psa_conf, ifmng_bin
# and mysql_cmd before any other function relies on them.
db_pass = None
psa_conf = {}
ifmng_bin = None
mysql_cmd = None

# Module-level default; the visible code never rebinds this global.
map_file = None

# Text written at the top of a freshly generated map file template.
mapfile_header = (
	"# You should edit IP addresses, netmasks and interfaces to reflect your\n"
	"# future settings. If you don't want the IP to be changed - leave it untouched,\n"
	"# comment out it's line or remove entire line from the file.\n"
	"\n"
)

# One "<NAME> <value>" pair per line, as parsed by init() from psa.conf.
cfg_line_pattern = re.compile(r"(\w+)\s+(\S+)\s*")

def usage():
	"""Write the command-line help text to standard output.

	The program name shown in the "Usage:" line is taken from sys.argv[0].
	"""
	help_text = """Plesk reconfigurator - utility to change IP addresses used by Plesk

Usage: %s { <map_file> | --autoconfigure | --remap-ips | --help }

If <map_file> doesn't exists - template will be created, otherwise it will be used to map IP addresses.

--autoconfigure option will attempt to create and process IP mapping automatically. Any new excessive 
or old unmapped IP addresses will retain their status and would need to be handled manually either by 
rereading IP addresses or by passing a correct map file to this utility.

--remap-ips is an alias for --autoconfigure option.

--help option displays this help page.
"""
	# The trailing newline reproduces the one a print statement would append.
	sys.stdout.write(help_text % sys.argv[0] + "\n")

def err(msg):
	"""Write `msg` followed by a newline to standard error."""
	sys.stderr.write("%s\n" % (msg,))

def report(*args):
	"""Print all arguments on one line of standard output.

	Each argument is converted with str() and followed by a single space
	(so the line ends with a trailing space when any argument is given).
	"""
	pieces = [str(item) + " " for item in args]
	sys.stdout.write("".join(pieces) + "\n")

def init():
	"""Load Plesk settings and fill in the module-level globals.

	Reads the database password from /etc/psa/.psa.shadow, parses
	"<NAME> <value>" lines of /etc/psa/psa.conf into psa_conf, and derives
	the ifmng binary path and the mysql command prefix from them.

	Raises IOError when either file cannot be opened and KeyError when
	PRODUCT_ROOT_D or MYSQL_BIN_D is missing from psa.conf.
	"""
	global db_pass, psa_conf, ifmng_bin, mysql_cmd

	shadow = open('/etc/psa/.psa.shadow')
	try:
		db_pass = shadow.read().strip()
	finally:
		shadow.close()

	conf = open('/etc/psa/psa.conf')
	try:
		conf_lines = conf.readlines()
	finally:
		conf.close()

	# Lines that do not look like "<NAME> <value>" are skipped silently.
	for line in conf_lines:
		parsed = cfg_line_pattern.match(line)
		if parsed is None:
			continue
		key, value = parsed.group(1, 2)
		psa_conf[key] = value

	product_root = psa_conf["PRODUCT_ROOT_D"]
	mysql_bin = os.path.join(psa_conf["MYSQL_BIN_D"], 'mysql')
	ifmng_bin = os.path.join(product_root, 'admin/bin/ifmng')
	mysql_cmd = """%s '-u%s' '-p%s' '-D%s' """ % (mysql_bin, db_login, db_pass, db_name)
	

def readCmd(cmd):
	"""Run `cmd` through the shell and return its standard output.

	Returns a list of lines, each keeping its trailing newline.
	"""
	pipe = os.popen(cmd)
	try:
		output = pipe.readlines()
	finally:
		pipe.close()
	return output

def execCmd(*args):
	"""Run a shell command built by joining `args` with single spaces.

	Returns None when the command exits with status 0.

	Raises Exception when the command exits with a non-zero code, is
	terminated by a signal, or cannot be started at all.  The message
	carries the decoded exit code or signal number, not the raw wait
	status returned by os.system().
	"""
	cmd = ' '.join(args)
	status = os.system(cmd)
	if status == 0:
		return
	if status < 0:
		# os.system() could not run the command at all.
		raise Exception("%s: could not be executed (status %s)" % (cmd, status))
	# os.system() returns an encoded wait status on Unix: the exit code lives
	# in the high byte and the terminating signal in the low byte.
	if os.WIFSIGNALED(status):
		raise Exception("%s: killed by signal %s" % (cmd, os.WTERMSIG(status)))
	if os.WIFEXITED(status):
		raise Exception("%s: exited with non-zero code %s" % (cmd, os.WEXITSTATUS(status)))
	raise Exception("%s: terminated with wait status %s" % (cmd, status))



def ifmng(*args):
	"""Run the ifmng utility (path held in ifmng_bin) with the given arguments.

	Failures are reported by execCmd() as an Exception.
	"""
	command = (ifmng_bin,) + args
	return execCmd(*command)

def execUtil(util, *args):
	"""Run `util` from <PRODUCT_ROOT_D>/admin/bin with the given arguments.

	Failures are reported by execCmd() as an Exception.
	"""
	util_path = os.path.join(psa_conf["PRODUCT_ROOT_D"], 'admin/bin', util)
	return execCmd(util_path, *args)

def readDBAddresses():
	"""Return the IP addresses registered in the database for the local node.

	Each element is an IPAddr carrying an extra `ident` attribute that holds
	the id column of the matching IP_Addresses row.
	"""
	sqlQuery = """SELECT ip.id, ip.iface, ip.ip_address, ip.mask 
				FROM IP_Addresses ip
				INNER JOIN ServiceNodes sn
					ON ip.serviceNodeId = sn.id
				WHERE sn.ipAddress = 'local'"""
	registered = []
	for row in db_query(sqlQuery):
		ident, iface, ip, mask = row
		entry = IPAddr(ip, mask, iface)
		entry.ident = ident
		registered.append(entry)
	return registered
	
def createMapFile(fn):
	"""Write a map file template to `fn`.

	The template consists of mapfile_header followed by one
	"<address> -> <address>" line per address returned by readDBAddresses(),
	i.e. an identity mapping for the user to edit.
	"""
	# Query the database before touching the file: if the query fails, no
	# header-only file is left behind to be mistaken for a real map file on
	# the next run.
	addresses = readDBAddresses()
	fd = open(fn, "w")
	try:
		fd.write(mapfile_header)
		for addr in addresses:
			fd.write("%s -> %s\n" % (addr, addr))
	finally:
		fd.close()


class IPAddr:
	"""An IPv4 or IPv6 address with an optional netmask and interface name.

	Attributes:
	  addr    - canonical text form of the address (via socket.inet_ntop)
	  binaddr - packed binary form of the address
	  af      - socket.AF_INET or socket.AF_INET6
	  netmask - netmask exactly as passed in (not validated)
	  iface   - interface name exactly as passed in (not validated)

	Equality and hashing use the binary address only; netmask and interface
	are ignored.
	NOTE: no __ne__ is defined, so on Python 2 the `!=` operator does not
	consult __eq__.
	"""

	def __init__(self, addr, netmask = None, iface = None):
		"""Parse `addr`, trying IPv4 first and IPv6 second.

		Raises Exception when `addr` is not a valid address of either family.
		"""
		self.netmask = netmask
		self.iface = iface
		for family in (socket.AF_INET, socket.AF_INET6):
			try:
				self.binaddr = socket.inet_pton(family, addr)
			except (socket.error, TypeError, ValueError):
				continue
			self.af = family
			break
		else:
			# Format the argument itself: self.addr is not assigned yet here.
			raise Exception("%s: is neither ipv6 nor ipv4 address" % (addr,))
		self.addr = socket.inet_ntop(self.af, self.binaddr)

	def __eq__(self, rhs):
		# Anything that is not an IPAddr simply compares unequal.
		return isinstance(rhs, IPAddr) and self.binaddr == rhs.binaddr

	def __hash__(self):
		return hash(self.binaddr)

	def __str__(self):
		"""Return "<iface> <addr> <netmask>", the form used in map files."""
		return "%s %s %s" % (self.iface, self.addr, self.netmask)

	def __repr__(self):
		return "<%s>" % self.addr

	def fullAddr(self):
		"""Return the address with no zero compression.

		IPv6 addresses come back as eight colon-separated groups of four hex
		digits; IPv4 addresses come back unchanged.
		"""
		if self.af == socket.AF_INET6:
			hexstr = binascii.hexlify(self.binaddr)
			return ':'.join(hexstr[i:i+4] for i in range(0, len(hexstr), 4))
		return self.addr

	def isIPv4(self):
		return self.af == socket.AF_INET

	def isIPv6(self):
		return self.af == socket.AF_INET6

def readSystemAddresses():
	"""Return the addresses listed by "ifmng -l" as a list of IPAddr objects.

	Every output line must split into exactly four whitespace-separated
	fields (address, mask, interface and a fourth field that is not used);
	any other shape raises ValueError.
	"""
	listing = readCmd("%s -l" % ifmng_bin)
	found = []
	for line in listing:
		addr, mask, iface, _unused = line.split()
		found.append(IPAddr(addr, mask, iface))
	return found

# A map file line that carries no mapping: blank, or only a "#" comment.
# The trailing "$" matters: the pattern is used with .match(), and without
# the anchor it would match the empty prefix of every line, so malformed
# lines would never be reported.
empty_pattern = re.compile(r"\s*(#.*)?$")
# "<from> -> <to>", optionally followed by a "#" comment.
map_pattern = re.compile(r"([^#]+)->([^#]+)(#.*)?")

# "<iface>[:] <address> <netmask>"; the netmask may only hold digits and dots.
addr_pattern=re.compile(r"\s*(\w+):?\s*(\S+)\s+([0-9.]+)")
def parseAddr(addr):
	"""Split one side of a mapping line into (iface, address, netmask).

	Raises Exception when `addr` does not match addr_pattern.
	"""
	am = addr_pattern.match(addr)
	if am is None:
		# "%" (not ",") so the offending text is formatted into the message.
		raise Exception("%s: cannot parse" % (addr,))
	return am.group(1), am.group(2), am.group(3)

def readMapping(fn):
	"""Parse map file `fn` into a dict of source IPAddr -> target IPAddr.

	Lines matching map_pattern are parsed as "<from> -> <to>"; lines matching
	empty_pattern are skipped.  Every problem found (unparsable line,
	duplicate source, duplicate target, IPv4/IPv6 family change) is written
	to stderr and counted.

	Raises Exception when at least one problem was found.
	"""
	fd = open(fn)
	try:
		lines = fd.readlines()
	finally:
		fd.close()

	errors = 0
	rv = {}
	for ln in lines:
		pmatch = map_pattern.match(ln)
		if not pmatch:
			if not empty_pattern.match(ln):
				err("%s: cannot parse" % ln.strip())
				errors += 1
			continue

		try:
			iface_f, addr_f, mask_f = parseAddr(pmatch.group(1))
			iface_t, addr_t, mask_t = parseAddr(pmatch.group(2))
			ipf = IPAddr(addr_f, mask_f, iface_f)
			ipt = IPAddr(addr_t, mask_t, iface_t)
		except Exception:
			# Only real parse failures count as a bad line; KeyboardInterrupt
			# and SystemExit are no longer swallowed here.
			err("%s: cannot parse" % ln.strip())
			errors += 1
			continue

		if ipf in rv:
			err("%s: is already mapped to %s" % (ipf.addr, rv[ipf].addr))
			errors += 1
		if ipt in rv.values():
			err("%s: is already used as target address" % (ipt.addr))
			errors += 1
		rv[ipf] = ipt

	for addr in rv:
		if addr.af != rv[addr].af:
			err("%s=>%s: it is not allowed to change IPv6 address to IPv4 and vice versa" % (addr.addr, rv[addr].addr))
			errors += 1
	if errors:
		raise Exception("%s: cannot parse, %s errors found." % (fn, errors))
	return rv

def db_query(stmt):
	"""Run SQL statement `stmt` through the mysql client and return its rows.

	Returns a list of rows, each a list of column values as strings.

	`stmt` is placed inside double quotes on a shell command line, so it
	must not contain characters the shell would interpret there.
	"""
	cmd = """%s -s -N -e "%s" """ % (mysql_cmd, stmt)
	rows = []
	for line in readCmd(cmd):
		line = line.rstrip("\r\n")
		if not line:
			continue
		# The client separates columns with tabs in this mode.  Splitting on
		# the tab only keeps empty values and values that contain spaces in
		# their own column.
		rows.append(line.split("\t"))
	return rows


def generateUpdateSQL(mapping):
	"""Write an SQL script that applies `mapping` to the database.

	`mapping` maps old IPAddr objects (each carrying an `ident` attribute) to
	new IPAddr objects.  The script, wrapped in BEGIN/COMMIT, updates:
	  * IP_Addresses rows identified by the old address's `ident`;
	  * A/AAAA DNS records whose value is a mapped address;
	  * PTR DNS records whose host is a mapped address;
	  * DNS records with no zone (dns_zone_id IS NULL) whose host is a
	    mapped address.

	Returns (path of the script, {domain id: domain name}) where the dict
	lists the domains that own at least one updated zone record.  The caller
	is responsible for deleting the script.

	Any exception from the queries or from address parsing propagates; the
	temporary file is removed first in that case.
	"""
	host_update = "UPDATE dns_recs SET host='%s', displayHost='%s' WHERE id = '%s';\n"

	handle, fn = tempfile.mkstemp('.sql')
	fd = os.fdopen(handle, "w")
	domains = {}
	try:
		try:
			fd.write("BEGIN;\n")

			for addr in mapping:
				naddr = mapping[addr]
				fd.write("UPDATE IP_Addresses SET ip_address='%s', iface='%s', mask='%s' WHERE id='%s';\n" % (naddr.addr, naddr.iface, naddr.netmask, addr.ident))

			for dom_id, dom_name, rec_id, rec_type, rec_val in db_query(r"SELECT dom.id, dom.name, rec.id, rec.type, rec.val FROM dns_recs rec JOIN dns_zone z ON (z.id = rec.dns_zone_id) JOIN domains dom ON (dom.dns_zone_id = z.id) WHERE rec.type IN ('A', 'AAAA')"):
				rec_ip = IPAddr(rec_val)
				if rec_ip in mapping:
					new_addr = mapping[rec_ip]
					# Keep the record type in step with the new address family.
					if rec_type == 'A' and new_addr.af == socket.AF_INET6:
						rec_type = 'AAAA'
					elif rec_type == 'AAAA' and new_addr.af == socket.AF_INET:
						rec_type = 'A'

					domains[dom_id] = dom_name
					fd.write("UPDATE dns_recs SET type='%s', val='%s', displayVal='%s' WHERE id = '%s';\n" % (rec_type, new_addr.addr, new_addr.addr, rec_id))

			for dom_id, dom_name, rec_id, rec_type, rec_host in db_query(r"SELECT dom.id, dom.name, rec.id, rec.type, rec.host FROM dns_recs rec JOIN dns_zone z ON (z.id = rec.dns_zone_id) JOIN domains dom ON (dom.dns_zone_id = z.id) WHERE rec.type IN ('PTR')"):
				rec_ip = IPAddr(rec_host)
				if rec_ip in mapping:
					new_addr = mapping[rec_ip]
					fd.write(host_update % (new_addr.fullAddr(), new_addr.addr, rec_id))
					domains[dom_id] = dom_name

			for rec_id, rec_type, rec_host in db_query(r"SELECT rec.id, rec.type, rec.host FROM dns_recs rec WHERE dns_zone_id IS NULL"):
				rec_ip = IPAddr(rec_host)
				if rec_ip in mapping:
					new_addr = mapping[rec_ip]
					if rec_type == 'PTR':
						fd.write(host_update % (new_addr.fullAddr(), new_addr.addr, rec_id))
					elif rec_type != 'none':
						fd.write(host_update % (new_addr.addr, new_addr.addr, rec_id))

			fd.write("COMMIT;\n")
		finally:
			fd.close()
	except:
		# Do not leave a half-written script behind; re-raise unchanged.
		os.unlink(fn)
		raise
	return fn, domains


def reconfigure(mapping):
	"""Apply `mapping` (old IPAddr -> new IPAddr) to the database and services.

	Steps, in order: generate and run the SQL update script, refresh DNS for
	the affected domains, reconfigure Apache and ProFTPD, adapt the Watchdog
	module and rebuild the Postfix transport map when those tools exist, and
	refresh the trusted IP list for site preview.

	Raises Exception (from execCmd) as soon as any step fails.
	"""
	report("Generating DB update script... ")
	sqlfile, affected_domains = generateUpdateSQL(mapping)
	report("ok")

	report("Updating database... ")
	try:
		execCmd(mysql_cmd, " < ", sqlfile)
	finally:
		# Remove the temporary script even when the mysql call fails.
		os.unlink(sqlfile)
	report("ok")

	if affected_domains:
		report("Reconfiguring DNS:")
		for domain in affected_domains.values():
			report("domain %s..." % domain)
			execUtil('dnsmng', '--update', domain)
			report("ok")
		report("Restarting DNS service...")
		execUtil('dnsmng', '--restart')
		report("ok")

	# The affected domains come from DNS records, so they cannot be relied
	# upon when DNS is not installed; the services below are therefore
	# always reconfigured in full.
	report("Reconfiguring Apache...")
	execUtil('httpdmng', '--reconfigure-all')
	report("ok")

	report("Reconfiguring Proftpd...")
	execUtil('ftpmng', '--reconfigure-all')
	report("ok")

	wd = os.path.join(psa_conf["PRODUCT_ROOT_D"], 'admin/bin/modules/watchdog/wd')
	if os.path.exists(wd):
		report("Reconfiguring Watchdog module...")
		execCmd(wd, "--adapt")
		report("ok")

	transport_restore = '/usr/lib/plesk-9.0/mail_postfix_transport_restore'
	if os.path.exists(transport_restore):
		report("Rebuilding Postfix transport map...")
		execCmd(transport_restore)
		report('ok')

	report("Refresh trusted IPs for site preview")
	refresh_trusted_ips = os.path.join(
		psa_conf["PRODUCT_ROOT_D"],
		'admin/plib/scripts/refresh-trusted-ips.php')
	execUtil('php', refresh_trusted_ips)
	report('ok')

def fileMapping(mapfile):
	"""Build the address mapping described by `mapfile` and prepare the system.

	Reads the mapping, checks it against the addresses stored in the
	database, drops entries that change nothing, and makes sure every target
	address is configured on the system with the requested interface and
	netmask (calling "ifmng --del" / "ifmng --add" as needed).

	Returns a dict of old IPAddr -> new IPAddr holding only the entries that
	really change; each key carries the `ident` of its database row.

	Raises Exception when the map file cannot be parsed or conflicts with the
	database contents.
	"""
	system_addresses = readSystemAddresses()
	mapping = readMapping(mapfile)
	db_addresses = readDBAddresses()

	errors = 0
	for ip in mapping:
		if ip in db_addresses:
			# Remember the database id of the row being remapped.
			ip.ident = db_addresses[db_addresses.index(ip)].ident
		else:
			err("%s: address is not used by Plesk" % ip.addr)
			errors += 1
	targets = mapping.values()
	for addr in db_addresses:
		if addr not in mapping and addr in targets:
			err("%s: address is already used by Plesk" % addr.addr)
			errors += 1
	if errors:
		raise Exception("%s: %d conflicts found" % (mapfile, errors))

	# Keep only the entries that change something.  The fields are compared
	# explicitly: IPAddr defines __eq__ but not __ne__, so on Python 2
	# "new != old" would compare object identity and never filter anything,
	# and IPAddr equality ignores interface and netmask anyway.
	changed = {}
	for old in mapping:
		new = mapping[old]
		if (new.addr, new.iface, new.netmask) != (old.addr, old.iface, old.netmask):
			changed[old] = new
	mapping = changed

	for addr in mapping.values():
		exists = False
		for eaddr in system_addresses:
			if eaddr == addr:
				if eaddr.iface != addr.iface or eaddr.netmask != addr.netmask:
					# Same address on another interface/netmask: remove it so
					# that it can be added again with the requested settings.
					ifmng("--del", eaddr.addr)
				else:
					exists = True
		if not exists:
			report("Adding %s..." % addr)
			ifmng("--add", addr.iface, addr.addr, addr.netmask)
			report("ok")

	return mapping


class IpListHelper:
	"""Hands out addresses from a list, preferring a requested interface.

	The list passed to the constructor is used directly (not copied), so
	every address handed out is removed from the caller's list as well.
	"""

	def __init__(self, addresses):
		self._addresses = addresses

	def getIp(self, iface=None):
		"""Remove and return one address from the list.

		With `iface` given, the first address on that interface is returned.
		When none matches, a notice is written to stderr and the first
		address of the list is returned instead.  Without `iface`, the first
		address is returned.

		Raises IndexError when the list is empty.
		"""
		pool = self._addresses
		if iface is not None:
			for position, candidate in enumerate(pool):
				if candidate.iface == iface:
					return pool.pop(position)
			sys.stderr.write("Address on interface '%s' not found.\n" % iface)
		return pool.pop(0)

	def addresses(self):
		"""Return the list of addresses not handed out yet."""
		return self._addresses


def _pairAddresses(removed, helper, mapping):
	"""Pair addresses from `removed` with new ones taken from `helper`.

	Every paired address is deleted from `removed` and stored in `mapping`
	as old -> new.  Pairing stops as soon as `helper` runs out of addresses.
	"""
	for ipFrom in removed[:]:
		try:
			ipTo = helper.getIp(iface=ipFrom.iface)
		except IndexError:
			# No new addresses left to hand out.
			break
		removed.remove(ipFrom)
		mapping[ipFrom] = ipTo


def autoMapping():
	"""Build an old -> new address mapping automatically.

	Addresses present in the database but missing from the system are
	treated as removed; addresses present on the system but missing from the
	database are treated as added.  Removed addresses are paired with added
	ones of the same family, preferring the same interface.  Progress and
	leftovers are printed with report().

	Returns a dict of old IPAddr (from the database) -> new IPAddr (from the
	system); it may be empty.
	"""
	dbData = readDBAddresses()
	sysData = readSystemAddresses()

	report("Database:", dbData)
	report("Actual:", sysData)

	removed = [ip for ip in dbData if ip not in sysData]
	added = [ip for ip in sysData if ip not in dbData]

	# An IPAddr is always either IPv4 or IPv6, so these two splits cover
	# every address.
	removedIPv4s = [ip for ip in removed if ip.isIPv4()]
	removedIPv6s = [ip for ip in removed if ip.isIPv6()]
	addedIPv4s = [ip for ip in added if ip.isIPv4()]
	addedIPv6s = [ip for ip in added if ip.isIPv6()]

	report("Removed IPs:", removedIPv4s, removedIPv6s)
	report("Added IPs:", addedIPv4s, addedIPv6s)

	mapping = {}
	ip4ListHelper = IpListHelper(addedIPv4s)
	_pairAddresses(removedIPv4s, ip4ListHelper, mapping)
	ip6ListHelper = IpListHelper(addedIPv6s)
	_pairAddresses(removedIPv6s, ip6ListHelper, mapping)

	report("Mapping:", mapping)
	report("Old not remapped:", removedIPv4s + removedIPv6s)
	report("New not used:", ip4ListHelper.addresses() + ip6ListHelper.addresses())
	return mapping
			
def main():
	"""Command-line entry point.

	Takes exactly one argument: a map file path, --autoconfigure (or its
	alias --remap-ips), or --help/-h.  With a map file path that does not
	exist yet, a template is written and nothing else is done; otherwise the
	resulting mapping is applied with reconfigure().

	Prints the usage text and exits with status 1 when the arguments are
	wrong or help is requested.
	"""
	if len(sys.argv) != 2 or sys.argv[1] in ("--help", "-h"):
		usage()
		# sys.exit() instead of the site-provided exit() helper, which is
		# not guaranteed to exist (e.g. when Python runs with -S).
		sys.exit(1)
	init()

	mapping = None

	if sys.argv[1] in ("--autoconfigure", "--remap-ips"):
		mapping = autoMapping()
	else:
		map_file = sys.argv[1]
		if not os.path.exists(map_file):
			createMapFile(map_file)
			# The map file is a plain positional argument: there is no
			# --file option, and passing one would fail the check above.
			report("""IP map file template '%s' is successfully created. 
Edit it to declare desired configuration, and start reconfigurator again with '%s' as the only argument.
""" % (map_file, map_file))
		else:
			mapping = fileMapping(map_file)

	if mapping:
		reconfigure(mapping)
		report("IP addresses are successfully changed.")
	else:
		report("Nothing to do.")
	
if __name__ == "__main__":
	try:
		main()
	except Exception:
		# Print every argument of the exception on its own line of stderr.
		e = sys.exc_info()[1]
		for a in e.args:
			sys.stderr.write("%s\n" % (a,))
		# Report the failure to the caller through the exit status as well.
		sys.exit(1)

