#!/usr/local/bin/python3.12 -u
# -*- coding:Utf-8 -*-


"""
Author :      Vincent <vincent.delft@gmail.com>
Version :     0.4
Licence :     BSD
Require :     python >= 3.6
              use sqlite3 DB embedded with python package
Developed on: OpenBSD
Tested on :   OpenBSD 6.4, Windows 10, osx 10.14
Description : This tool allows you to calculate a checksum for each file in the target folder.
              Those values are stored in an sqlite DB at the root of your targeted folder.
              This program uses the INODE as key instead of the filename, so it can manage hardlinks.
              Because of that, the script never goes outside the targeted filesystem.
              It works on OpenBSD, but should work on any system (OSX, Windows and Linux).

              Typically, you must perform a first scan of the folder you want:
                    yabitrot -p <folder>
              Then, you can re-scan your folder and yabitrot will compare the checksums
              with what it finds in the DB
                    yabitrot -p <folder>


/*
 * Copyright (c) 2018 Vincent Delft <vincent.delft@gmail.com>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */
"""


import zlib
import time
import os.path
import sqlite3
import sys
import argparse
import fnmatch
import errno
import stat

# Number of bytes requested per read() call while checksumming a file.
DEFAULT_CHUNK_SIZE = 16384
# Handed to CRCDB as its commit limit; CRCDB.commit() compares it against the
# number of seconds elapsed since the previous real commit.
COMMIT_LIMIT = 30
# Identifier of the current run (its start time); written in the "timestamp"
# column of every DB record touched during this run.
BATCHID = time.time()
# Verbosity level; values > 0 enable additional log output.
VERBOSE = 0
# Path of the log file; when empty, log() writes to standard output instead.
LOGFILE = ""
# When True, the CRCDB write methods skip their SQL statements.
DRY_RUN = False
# Name of the sqlite file, located in the root of the scanned folder.
DB_FILE_NAME = ".cksum.db"
BIGFILE_TTS = 10  # number of seconds after which get_cksum reports (if verbose) that it is working on a big file


def log(text):
    """Put text in the logfile (if provided), or on std output.

    When the module-level LOGFILE is set, the message is appended to that
    file, prefixed with the current local time. Otherwise the message is
    written to standard output, which is flushed immediately.
    """
    if LOGFILE:
        tts = time.strftime("%c", time.localtime())
        # "with" guarantees the handle is closed even if write() fails
        with open(LOGFILE, 'a') as fid:
            fid.write("%s: %s\n" % (tts, text))
    else:
        sys.stdout.write(text + "\n")
        sys.stdout.flush()


def print_err(text):
    """Put text in the logfile (if provided) and always on std error.

    When the module-level LOGFILE is set, the message is appended to that
    file, prefixed with the current local time. In every case the message is
    also written to standard error, which is flushed immediately.
    """
    if LOGFILE:
        tts = time.strftime("%c", time.localtime())
        # "with" guarantees the handle is closed even if write() fails
        with open(LOGFILE, 'a') as fid:
            fid.write("%s: %s\n" % (tts, text))
    sys.stderr.write(text + "\n")
    sys.stderr.flush()


def get_cksum(path, osstats, status, chunk_size=DEFAULT_CHUNK_SIZE):
    """Compute the values the DB is expecting for a file: cksum and mtime.

    path:       file to read
    osstats:    os.stat() result of that file (st_mtime, st_ino, st_size are used)
    status:     free text used only in the "big file" progress message
    chunk_size: number of bytes requested per read

    Return a dict with keys 'crc' (string) and 'mtime'.
    The 'crc' is the sum of the Adler-32 values of every chunk read, so the
    result depends on chunk_size.
    When the file cannot be read (EACCES or EOPNOTSUPP) the failure is reported
    through print_err and the checksum accumulated so far is returned; any
    other OSError is propagated.
    """
    last_notice = time.time()
    checksum = 0
    result = {'crc': None, 'mtime': osstats.st_mtime}
    try:
        with open(path, 'rb') as handle:
            # iterate chunk by chunk until read() returns an empty value
            for chunk in iter(lambda: handle.read(chunk_size), b''):
                checksum += zlib.adler32(chunk)
                if VERBOSE > 0 and time.time() - last_notice > BIGFILE_TTS:
                    # purely informative: show that we are still busy on a big consumer
                    log("big file: %s, inode: %s, size: %.2f MB, %s" % (status, osstats.st_ino, osstats.st_size / 1024 / 1024, path))
                    last_notice = time.time()
    except OSError as err:
        if err.errno not in [errno.EACCES, errno.EOPNOTSUPP]:
            raise
        print_err("Failed to read:%s" % path)
    result['crc'] = "%s" % checksum
    return result


def get_osstats(fpath, excludes):
    """Return the os.stat of the selected file, or None if it must be ignored.

    None is returned when:
      - fpath matches one of the fnmatch patterns listed in excludes
      - os.stat fails with EACCES, EOPNOTSUPP or ENOENT
      - fpath is not a regular file
    Any other OSError raised by os.stat is propagated.
    """
    if any(fnmatch.fnmatch(fpath, pattern) for pattern in excludes):
        if VERBOSE > 0:
            log("Based on exclude rules, we skip: %s" % fpath)
        return None
    ignorable = (errno.EACCES, errno.EOPNOTSUPP, errno.ENOENT)
    try:
        file_stats = os.stat(fpath)
    except OSError as err:
        if err.errno not in ignorable:
            raise
        log("os.stat fails for: %s" % fpath)
        return None
    if stat.S_ISREG(file_stats.st_mode):
        return file_stats
    if VERBOSE > 0:
        log("Not a regular file: %s" % fpath)
    return None


class CRCDB:
    """The DB class with 2 tables: cksum and params.

    cksum is the main table where we store inodes and associated checksums
    (columns: inode, mtime, hash, timestamp). The timestamp column holds the
    BATCHID of the last run that touched the record.
    params is a key-value table. Currently we store:
            rootpath: the path from where we perform an analysis
            filesystemid: the id of the targeted filesystem

    Methods that modify the DB content skip their SQL statement when the
    module-level DRY_RUN flag is set.
    """

    def __init__(self, fpathname, commitlimit=30):
        """Open (or create) the sqlite file and create the missing tables.

        fpathname:   path of the sqlite file
        commitlimit: number of seconds that must elapse between 2 real commits
        """
        self.counter = 0  # DB actions requested since the last real commit
        self.tts = time.time()  # time of the last real commit
        self.commitlimit = commitlimit
        self.conn = None
        self.cur = None
        # A single connection is opened and shared by the 2 table creators, so
        # no connection is left behind when only one table is missing.
        self._connect(fpathname)
        tables = set(name for name, in self.cur.execute('SELECT name FROM sqlite_master'))
        if 'cksum' not in tables:
            self._create_db(fpathname)
        if 'params' not in tables:
            self._create_params(fpathname)

    def _connect(self, fpathname):
        """Open the connection and the cursor, unless they already exist."""
        if not self.conn:
            self.conn = sqlite3.connect(fpathname)
        if not self.cur:
            self.cur = self.conn.cursor()

    def _create_db(self, fpathname):
        """Create the cksum table."""
        self._connect(fpathname)
        self.cur.execute("""CREATE TABLE cksum (
                           inode INTEGER PRIMARY KEY,
                           mtime REAL,
                           hash TEXT,
                           timestamp REAL)""")
        self.conn.commit()

    def _create_params(self, fpathname):
        """Create the params table."""
        self._connect(fpathname)
        self.cur.execute("""CREATE TABLE params (
                           param TEXT PRIMARY KEY,
                           value TEXT
                           )""")
        self.conn.commit()

    def get_rec(self, inode):
        """Return (mtime, hash, timestamp) stored for this inode, or None.

        The returned timestamp is the one found before this call. Unless we
        are in dry-run mode, the stored timestamp is then set to BATCHID to
        flag the record as seen during this run.
        """
        self.cur.execute('SELECT mtime, hash, timestamp FROM cksum WHERE '
                         'inode=?', (inode,))
        ret = self.cur.fetchone()
        if not ret:
            return None
        # The record must be returned in dry-run mode too, otherwise a dry run
        # sees every file as new and never compares any checksum.
        if not DRY_RUN:
            self.cur.execute('UPDATE cksum SET timestamp=? WHERE inode=?', (BATCHID, inode))
            self.commit()
        return ret

    def update_rec(self, inode, stats):
        """Replace mtime and hash of an existing inode by the values of stats."""
        if not DRY_RUN:
            self.cur.execute('UPDATE cksum SET mtime=?, hash=?, timestamp=? '
                             'WHERE inode=?',
                             (stats['mtime'], stats['crc'], BATCHID, inode))
            self.commit()

    def add_rec(self, inode, stats):
        """Insert a new inode with the mtime and hash found in stats."""
        if not DRY_RUN:
            self.cur.execute('INSERT INTO cksum VALUES (?, ?, ?, ?)',
                             (inode, stats['mtime'], stats['crc'], BATCHID))
            self.commit()

    def remove_rec(self, inode):
        """Delete the record of this inode."""
        if not DRY_RUN:
            self.cur.execute('DELETE FROM cksum WHERE inode=?', (inode,))
            self.commit()

    def commit(self):
        """Count one DB action; really commit only every commitlimit seconds."""
        self.counter += 1
        if time.time() - self.tts > self.commitlimit:
            self.conn.commit()
            if VERBOSE > 0:
                log('commit %s files in %.2f sec' % (self.counter, time.time() - self.tts))
            self.tts = time.time()
            self.counter = 0

    def set_param(self, param, value):
        """Store (or replace) a key-value pair in the params table."""
        if not DRY_RUN:
            self.cur.execute('INSERT OR REPLACE into params VALUES (?, ?)', (param, value))
            self.commit()

    def get_param(self):
        """Return all the rows of the params table as a list of tuples."""
        self.cur.execute('SELECT * from params')
        return self.cur.fetchall()

    def close(self):
        """Commit what is still pending and close the connection."""
        self.conn.commit()
        self.conn.close()

    def cleanup(self):
        """Remove the records that have not been touched during this run.

        In dry-run mode nothing is deleted, we only log how many records have
        a timestamp different from BATCHID. Since a dry run never updates the
        timestamps, that number then covers every record written by a
        previous run.
        """
        self.cur.execute('SELECT inode FROM cksum WHERE timestamp != ?', (BATCHID,))
        ret = self.cur.fetchall()
        if ret:
            if DRY_RUN:
                log("%s files could be removed" % (len(ret)))
            else:
                log("%s files removed from DB" % len(ret))
                # made persistent by the commit performed in close()
                self.cur.execute('DELETE from cksum WHERE timestamp !=?', (BATCHID,))
        else:
            log("No cleanup required")

    def count(self):
        """Return the number of records in cksum, as a 1-element tuple."""
        self.cur.execute("SELECT count(*) from cksum")
        return self.cur.fetchone()


def _log_bit_error(fpath, db_rec, cksum, osstats):
    """Log the details of a checksum mismatch on a file whose mtime is unchanged."""
    log("bit ERROR for file %s" % (fpath))
    log("   Previous:")
    log("      scan was on %s" % time.strftime("%c", time.localtime(db_rec[2])))
    log("      checksum was: %s" % db_rec[1])
    log("      mtime was: %s" % time.strftime("%c", time.localtime(db_rec[0])))
    log("   Current:")
    log("      scan on %s" % time.strftime("%c", time.localtime(BATCHID)))
    log("      checksum is: %s" % cksum['crc'])
    log("      mtime is: %s" % time.strftime("%c", time.localtime(osstats.st_mtime)))


def analyze(rootpath, excludes=None):
    """Analyze rootpath and all its sub-folders.

    For every regular file located on the same filesystem as rootpath:
      - if its inode is not in the DB, its checksum is stored
      - if its inode is in the DB and the mtime changed, the record is updated
      - if its inode is in the DB, the mtime did not change, but the checksum
        differs, a bit error is logged
    An inode already handled during this run (hardlink) is not read again.
    At the end, the DB records not touched during this run are cleaned up.

    rootpath: folder to scan; the DB file is located in it
    excludes: optional list of fnmatch patterns to skip (the list is not modified)

    Return -1, without scanning, if the DB has been created for another
    rootpath or another filesystem. Otherwise the function does not return:
    it calls sys.exit() with the number of bit errors (capped to 255).
    """
    # Work on a private copy: neither the caller's list nor a shared default
    # may receive our DB related patterns.
    excludes = list(excludes) if excludes else []
    dbpath = os.path.join(rootpath, DB_FILE_NAME)
    db = CRCDB(dbpath, COMMIT_LIMIT)
    log("DB stored on: %s" % (dbpath))
    excludes.append('*/%s' % DB_FILE_NAME)
    excludes.append('*/%s-journal' % DB_FILE_NAME)
    counter = 0
    counter_added = 0
    counter_update = 0
    counter_biterror = 0
    total_size = 0
    parameters = db.get_param()
    filesystemid = os.stat(rootpath).st_dev
    log("Device ID:%s" % filesystemid)
    if parameters:
        for param_name, value in parameters:
            if param_name == 'rootpath' and value != rootpath:
                print_err("We have detected a DB at %s" % dbpath)
                print_err("This DB has been created with the path: %s" % value)
                print_err("But, you have entered the following path: %s" % rootpath)
                db.close()
                return -1
            if param_name == "filesystemid" and value != str(filesystemid):
                print_err("We have detected a DB at %s" % dbpath)
                # "value" is what the DB holds; filesystemid is the current one
                print_err("This DB has been created with the filesystemID:%s" % value)
                print_err("But, currently the filesystem ID is: %s" % filesystemid)
                db.close()
                return -1
    else:
        db.set_param("rootpath", rootpath)
        db.set_param("filesystemid", filesystemid)
    analyze_tts = time.time()
    for path, dummy, files in os.walk(rootpath):
        for elem in files:
            fpath = os.path.join(path, elem)
            osstats = get_osstats(fpath, excludes)
            if not osstats or osstats.st_dev != filesystemid:
                # excluded, unreadable, not a regular file or on another filesystem
                continue
            if VERBOSE > 1 and time.time() - analyze_tts > COMMIT_LIMIT:
                log("working with: %s" % fpath)
                analyze_tts = time.time()
            counter += 1
            db_rec = db.get_rec(osstats.st_ino)
            cksum = None
            if db_rec is None:
                cksum = get_cksum(fpath, osstats, "new")
                db.add_rec(osstats.st_ino, cksum)
                counter_added += 1
            elif db_rec[2] != BATCHID:
                # a timestamp equal to BATCHID means this inode has already
                # been treated during this run (hardlink)
                cksum = get_cksum(fpath, osstats, "update")
                if db_rec[0] != osstats.st_mtime:
                    db.update_rec(osstats.st_ino, cksum)
                    counter_update += 1
                elif db_rec[1] != cksum['crc']:
                    _log_bit_error(fpath, db_rec, cksum, osstats)
                    counter_biterror += 1
            if cksum and cksum['crc']:
                total_size += osstats.st_size
    log("\n")
    db.cleanup()
    records = db.count()
    db.close()
    log("%s files added" % counter_added)
    log("%s files updates" % counter_update)
    log("%s files error" % counter_biterror)
    log("%s files analysed in %.2f sec, %.3f GB" % (counter, time.time() - BATCHID, total_size / 1024 / 1024 / 1024))
    log("%s entries in the DB" % records)
    if os.name == 'posix' and not DRY_RUN:
        os.chmod(dbpath, stat.S_IRUSR | stat.S_IWUSR)
        os.chown(dbpath, os.getuid(), os.getgid())
    if counter_biterror > 0:
        print_err("Several bit error, please check the log file")
    # Exit statuses are truncated to 8 bits on POSIX: without the cap, 256
    # errors would be reported to the calling shell as a success (0).
    sys.exit(min(counter_biterror, 255))


def force_db(fpath, rootpath, excludes=None):
    """Recalculate the checksum of one file and store it in the DB of rootpath.

    fpath:    the file for which the checksum must be forced
    rootpath: folder holding the DB file
    excludes: optional list of fnmatch patterns (the list is not modified)

    Return 0 on success, 1 if fpath matches an exclude pattern, 2 if os.stat
    fails on fpath (EACCES, EOPNOTSUPP or ENOENT). Any other OSError raised by
    os.stat is propagated.
    """
    # Work on a private copy: neither the caller's list nor a shared default
    # may receive our DB related patterns.
    excludes = list(excludes) if excludes else []
    dbpath = os.path.join(rootpath, DB_FILE_NAME)
    db = CRCDB(dbpath, COMMIT_LIMIT)
    # try/finally: the DB is closed on every exit path, early returns included
    try:
        log("DB stored on: %s" % (dbpath))
        excludes.append('*/%s' % DB_FILE_NAME)
        excludes.append('*/%s-journal' % DB_FILE_NAME)
        filesystemid = os.stat(rootpath).st_dev
        log("Device ID:%s" % filesystemid)
        if any(fnmatch.fnmatch(fpath, excl_patt) for excl_patt in excludes):
            print_err("The file you want is in the exclude list")
            print_err("File name is: %s" % fpath)
            print_err("Exclude list is: %s" % ",".join(excludes))
            return 1
        try:
            osstats = os.stat(fpath)
        except OSError as ex:
            if ex.errno not in [errno.EACCES, errno.EOPNOTSUPP, errno.ENOENT]:
                raise
            log("os.stat fails for: %s" % fpath)
            return 2
        stats = get_cksum(fpath, osstats, "update")
        # An UPDATE alone silently does nothing for an inode which is not yet
        # in the DB, so insert the record when none is found.
        if db.get_rec(osstats.st_ino) is None:
            db.add_rec(osstats.st_ino, stats)
        else:
            db.update_rec(osstats.st_ino, stats)
        log("checkcum calculated and stored in the DB")
        return 0
    finally:
        db.close()


# Command line entry point: parse the options, set the module-level settings,
# then either force the checksum of one file (-f) or scan the whole path.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-s', '--chunk-size', type=int, default=DEFAULT_CHUNK_SIZE,
        help='read files this many bytes at a time. Default is %s' % DEFAULT_CHUNK_SIZE)
    parser.add_argument(
        '-c', '--commit-limit', type=int, default=COMMIT_LIMIT,
        help='number of seconds between two DB commits. Default is %s' % COMMIT_LIMIT)
    parser.add_argument(
        '-p', '--path', type=str, default='.',
        help='Path to analyse. Default is "."')
    parser.add_argument(
        '-e', '--exclude', type=str, default='',
        help='file types to exclude with the fnmath format. For example *.core,*.tmp. Default is ""')
    parser.add_argument(
        '-v', '--verbose', type=int, default=0,
        help='verbosity level, currently from 0 to 2. Default is 0')
    parser.add_argument(
        '-n', '--dry-run', action='store_true',
        help='perform the task, but do not update the DB')
    parser.add_argument(
        '-L', '--log', type=str, default='',
        help='put mesage in the log instead to stdout')
    parser.add_argument(
        '-f', '--force', type=str, default='',
        help='Force checksum for a specific file')
    args = parser.parse_args()
    path = args.path
    if args.log:
        LOGFILE = args.log
    if args.verbose:
        VERBOSE = args.verbose
    if args.chunk_size:
        # NOTE: get_cksum() bound its chunk_size default when it was defined,
        # so rebinding this global does not change the size it reads with.
        DEFAULT_CHUNK_SIZE = args.chunk_size
    if args.commit_limit:
        COMMIT_LIMIT = args.commit_limit
    if args.dry_run:
        DRY_RUN = True
    to_exclude = []
    if args.exclude:
        to_exclude = args.exclude.split(",")
    if args.force:
        ret = force_db(args.force, path, to_exclude)
        sys.exit(ret)
    # analyze() exits by itself after a complete scan. It only returns (-1)
    # when the DB does not match the requested path or filesystem: use that
    # value as exit status instead of finishing with a misleading 0.
    sys.exit(analyze(path, to_exclude))
