#!/usr/bin/perl -w

# scan_psmm.pl: scan a position-specific Markov model across sequences

# Written by Martin C Frith 2006
# Genome Exploration Research Group, RIKEN GSC and
# Institute for Molecular Bioscience, University of Queensland

# This program takes two parameters: a PSMM in the format produced by
# make_psmm.pl, and a list of sequences in multi-fasta format. The
# PSMM is scanned across each sequence: at each position, a score is
# calculated, indicating how strongly that chunk of the sequence
# matches the PSMM. The score is the base 2 logarithm of: Prob(chunk |
# PSMM) / Prob(chunk | null model). The null model is a
# non-position-specific Markov model: it is constructed by summing the
# position-specific counts for each k-mer in the PSMM, to get
# non-position-specific k-mer counts. Alternatively, if option -b is
# selected, a null model of uniform k-mer frequencies is used. In
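# either case, chunks containing letters other than a, c, g, t (more
# generally, letters outside $alphabet) get a score of -1e100; characters
# that are not letters at all (newlines, digits, gaps) are stripped from
# the sequence before scanning. Chunks containing a k-mer that is absent
# from the PSMM, or whose count at that position in the PSMM is zero, get
# a score of -1e10.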

use strict;
use Getopt::Std;
use List::Util qw(sum);
use File::Basename;

my $alphabet = "acgt";  # can change this to e.g. protein alphabet

# Option -b selects a null model of uniform k-mer frequencies instead of
# the one derived from the PSMM's own counts.
my %opts;
my $usage = "Usage: " . basename($0) . " [-b] PSMM_file sequences.fa\n";
# getopts warns about and returns false on an unrecognised option; stop
# with the usage message instead of carrying on with the option ignored.
getopts('b', \%opts) or die $usage;
die $usage unless @ARGV;
my $psmm_file = shift;

# Filled in while reading the PSMM file.
my %freq;  # k-mer => array ref of its counts, one per window
my $word_size;  # length of every k-mer in the PSMM (k)
my $windows;  # number of counts (windows) listed for each k-mer

# Read the PSMM: each non-blank line is a k-mer followed by its count in
# each window. All k-mers must have the same length, every line must list
# the same number of counts, and no k-mer may appear twice.
# Three-argument open so a filename starting with a mode character
# ('<', '>', '|') or containing leading/trailing spaces is taken literally.
open my $psmm_fh, '<', $psmm_file or die "can't open $psmm_file: $!\n";
warn "reading PSMM...\n";
while (<$psmm_fh>) {
    s/#.*//;  # strip comments: everything from '#' to end of line
    next unless /\S/;  # ignore blank lines
    my @fields = split;
    my $kmer = shift @fields;
    my $ws = length $kmer;
    my $win = @fields;
    die "unequal word sizes" if defined $word_size and $ws != $word_size;
    die "unequal fields per line" if defined $windows and $win != $windows;
    die "repeated kmer" if exists $freq{$kmer};
    $word_size = $ws;
    $windows = $win;
    $freq{$kmer} = \@fields;
}
close $psmm_fh or die "can't close $psmm_file: $!\n";

# Everything downstream needs at least one k-mer with at least one count.
die "no k-mers found in $psmm_file\n" unless defined $word_size;
die "no counts found in $psmm_file\n" unless $windows > 0;

warn "alphabet = $alphabet, word size = $word_size, windows = $windows\n";
warn "preprocessing...\n";

# (k-1)-mer counts per position. Slot $i (0 <= $i < $windows) is the
# summed count, in window $i, of the k-mers that start with the
# (k-1)-mer; the scorer uses it as the context for the k-mer at offset $i.
# The extra slot $windows counts the (k-1)-mer as the suffix of k-mers in
# the last window; the scoring loop only reads slots below $windows, so
# this slot only affects the summed (non-position-specific) totals.
my %freq2;  # store (k-1)-mer counts here
for my $kmer (keys %freq) {
    my $prefix = substr $kmer, 0, -1;  # leave off the last letter
    my $suffix = substr $kmer, 1;  # leave off the first letter
    # Zero-fill on first sight so a (k-1)-mer seen only as a suffix (or
    # only as a prefix) has no undefined slots; undef slots would
    # otherwise cause "uninitialized value" warnings under -w later on.
    $freq2{$_} ||= [ (0) x ($windows + 1) ] for $prefix, $suffix;
    $freq2{$prefix}[$_] += $freq{$kmer}[$_] for 0 .. $windows - 1;
    $freq2{$suffix}[$windows] += $freq{$kmer}[$windows - 1];
}

# Non-position-specific totals: each word's count summed over all of its
# slots, and the grand totals of those sums. These define the default
# (non -b) null model.
my %tot_freq;  # k-mer => count summed over every window
my %tot_freq2;  # (k-1)-mer => count summed over every slot
for my $kmer (keys %freq) {
    $tot_freq{$kmer} = sum @{ $freq{$kmer} };
}
for my $jmer (keys %freq2) {
    $tot_freq2{$jmer} = sum @{ $freq2{$jmer} };
}

my $tot_tot = sum values %tot_freq;  # all k-mer counts
my $tot_tot2 = sum values %tot_freq2;  # all (k-1)-mer counts

# Per-window base-2 log-odds tables. For window $i and window total T_i:
#   $scores{$kmer}[$i]  = log2(count / T_i) + null term for the k-mer
#   $scores2{$jmer}[$i] = log2(count / T_i) + null term for the (k-1)-mer
# The scanner adds the first and, for every window after the first,
# subtracts the second, so T_i cancels and the log of the conditional
# probability count(kmer) / count(prefix) remains. The null term is
# -log2 of the word's share of all counts (summed over positions) or,
# with -b, the uniform value log2(|alphabet|^k), i.e. $knorm / $jnorm.
my %scores;  # store k-mer scores here
my %scores2;  # store (k-1)-mer scores here
my $knorm = $word_size * log(length $alphabet) / log(2);
my $jnorm = ($word_size - 1) * log(length $alphabet) / log(2);

{
    my $uniform = exists $opts{'b'};

    # The null term does not depend on the window, so compute it once per
    # word. The totals are zero only if every count is zero; guard the
    # division so that case yields safe_log's sentinel instead of dying.
    my %null;
    my %null2;
    for my $kmer (keys %freq) {
        $null{$kmer} = $uniform ? $knorm
            : -safe_log($tot_tot ? $tot_freq{$kmer} / $tot_tot : 0);
    }
    for my $jmer (keys %freq2) {
        $null2{$jmer} = $uniform ? $jnorm
            : -safe_log($tot_tot2 ? $tot_freq2{$jmer} / $tot_tot2 : 0);
    }

    for my $i (0 .. $windows - 1) {
        my $tot = 0;
        $tot += $freq{$_}[$i] for keys %freq;
        # A window whose counts are all zero would otherwise be a fatal
        # division by zero; its k-mers score as if their count were zero.
        for my $kmer (keys %freq) {
            my $p = $tot ? $freq{$kmer}[$i] / $tot : 0;
            $scores{$kmer}[$i] = safe_log($p) + $null{$kmer};
        }
        for my $jmer (keys %freq2) {
            my $p = $tot ? $freq2{$jmer}[$i] / $tot : 0;
            $scores2{$jmer}[$i] = safe_log($p) + $null2{$jmer};
        }
    }
}

warn "scanning sequences...\n";
# Minimal FASTA reader: sequence lines accumulate in $seq (which do_it
# reads). A header line first scores the record collected so far, then is
# echoed so that the scores printed after it belong to that header.
my $seq = '';
while (my $line = <>) {
    if ($line !~ /^>/) {
        $seq .= $line;
        next;
    }
    do_it();
    $seq = '';
    print $line;
}
do_it();  # score the last record

# Score every chunk of the current sequence (global $seq) against the PSMM
# and print one score per chunk start, in order. A chunk is the stretch
# covered by $windows overlapping k-mers, i.e. $windows + $word_size - 1
# letters. Its score is the sum of the precomputed %scores entries for its
# k-mers, minus the %scores2 entry of each k-mer's (k-1)-mer prefix for
# every window after the first. Sentinel scores:
#   -1e100  the chunk contains a letter outside $alphabet (checked first,
#           so it always takes precedence);
#   -1e10   some k-mer is missing from the PSMM or has a zero count in its
#           window.
sub do_it {
    $seq =~ tr/a-zA-Z//cd;  # remove non-alphabetic characters (e.g. newlines)
    $seq =~ tr/A-Z/a-z/;  # lc doesn't seem to work for big sequences

    my $chunk_len = $windows + $word_size - 1;
    my $big_win = length($seq) - $chunk_len + 1;  # number of chunk starts
    # \Q...\E keeps any regex metacharacters in $alphabet literal.
    my $bad_letter = qr/[^\Q$alphabet\E]/;

    CHUNK:
    for my $start (0 .. $big_win - 1) {
        # Check the whole chunk up front. Checking k-mer by k-mer could stop
        # at an earlier zero-count k-mer and report -1e10 for a chunk that
        # contains an invalid letter further along.
        if (substr($seq, $start, $chunk_len) =~ $bad_letter) {
            printf "%.3g\n", -1e100;
            next CHUNK;
        }
        my $s = 0;  # total score
        for my $i (0 .. $windows - 1) {
            my $kmer = substr $seq, $start + $i, $word_size;
            if (!exists $freq{$kmer} or $freq{$kmer}[$i] == 0) {
                $s = -1e10;
                last;
            }
            $s += $scores{$kmer}[$i];
            next if $i == 0;
            # Markov step: condition on this k-mer's (k-1)-mer prefix.
            my $jmer = substr $kmer, 0, -1;
            $s -= $scores2{$jmer}[$i];
        }
        printf "%.3g\n", $s;
    }
}

# Base-2 logarithm that never dies: returns log2($x) for positive $x and
# the sentinel -1e10 otherwise. The sentinel is far below any real log2
# of a double (log2(DBL_MIN) is about -1022), so it cannot be confused
# with a genuine value.
sub safe_log {
    my ($x) = @_;
    return -1e10 unless $x > 0;
    return log($x) / log(2);
}
