Contents

function OutputStructure = BayesianNetworkClassifier (InputStructure)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% BayesianNetworkClassifier takes a data set and performs feature selection
%
% DESCRIPTION
%       BayesianNetworkClassifier takes a data array, class vector, and
%       other information and builds and assesses a Bayesian network after
%       selecting features from within the data array.  It is called from
%       the user interface "orca.m."
%
%       This is the umbrella script that loops a specified number of times
%       (see "repeats" below), each time doing a full n-fold cross
%       validation and recording the results.  All input and output data
%       are stored in a single data structure, described below.
%
% USAGE
%       OutputStructure = BayesianNetworkClassifier (InputStructure)
%
% INPUTS
%       InputStructure: Data repository with fields:
%       Intensities: Array of intensity values of size #cases x #variables
%       Class: Vector of length "#cases", with discrete values identifying
%          class of each case (may be integer)
%       ID: Patient ID array of length #cases, with one or more cols
%       MZ: Vector of length "#variables" holding labels for variables
%       Options: Logical 6x1 array. Options are:
%          1. Normalize on population total ion count (sum across rows)
%          2. Remove negative data values by setting them to zero
%          3. After normalizing, before binning, average cases with same ID
%          4. Find the MI threshold by randomization
%          5. Take log(data) prior to binning.  Negative values set to 1.
%          6. Remove Low Signal cases
%              NOT DONE: 3 Bin (2 Bin if False)
%       n: the "n" in n-fold cross validation
%       repeats: Times to repeat the whole process (e.g. re-crossvalidate)
%       threshold: Factor by which the maximum "random" MI is multiplied to
%           find the minimum "significant" MI (double, 1.0-5.0).
%
% OUTPUTS
%       OutputDataStructure: all the fields of InputStructure, plus:
%       ErrorRate: Vector containing misclassification rate for each repeat
%       KeyFeatures: Index to vector MZ that identifies features selected
%
% CALLED FUNCTIONS
%
%       InitialProcessing: Applies the options listed above
%       BuildBayesNet: Learns a Bayesian network from the training data
%       ChooseMetaVars: Combines variables that may not be physically
%           separate molecules.
%       TestCases: Given the BayesNet, tests the "test group" to determine
%           the probability of being in each class.
%       opt3bin: Discretizes continuous data into 3 bins, optimizing MI
%       FindProbTables: Learns the values P(C,V) for each variable
%       cvpartition and training are MATLAB Statistics Toolbox functions.
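%
% EXAMPLE
%       A minimal calling sketch; the field values below are hypothetical
%       stand-ins, not values from the study.
%       S.Intensities = 100*rand(100, 50);  % 100 cases x 50 peaks
%       S.Class = ceil(2*rand(100,1));      % two classes, coded 1 and 2
%       S.ID = (1:100)';                    % one ID per case
%       S.MZ = linspace(1000, 10000, 50);   % peak labels
%       S.Options = logical([1 1 0 1 0 0]); % normalize, clip, auto MI
%       S.n = 10; S.repeats = 30; S.threshold = 3;
%       Result = BayesianNetworkClassifier(S);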


% Initialize: set up (for now) hard-coded values
tic; % start the timer used by the progress displays below (toc)
drop=0.75; % MI loss percentage threshold for testing independence, see
% clipclassconnections

% Initial Processing According to options, remove negative values,
% normalize and/or take logarithm of data, replicate average. Store in
% output data structure.

%  display('Starting Initial Processing of Data');
OutputStructure = InitialProcessing( InputStructure);
display('Initial processing complete.')
display (' ');


% Get values out of Data structure to be used later
ff=OutputStructure.threshold;
n=double(OutputStructure.n); % for n-fold cross validation; default is 10
repeats=OutputStructure.repeats; % Number of times to repeat CV, default 30
numtrials=repeats*n;
cverrorrate=zeros(numtrials,1);
errorrate=zeros(repeats,1);
data=OutputStructure.Intensities;
class=OutputStructure.Class;

% Find some sizes and initialize variables
[rows cols]=size(data);
% OutputStructure.varlist=zeros(cols,1);
class_predict=zeros(rows,repeats);
class_prob=zeros(rows,repeats);
trial=0; % counter of how many times we perform Bayes Analysis (n*repeats)

% "Repeat Entire Process" Loop

% Repeat all processes the number of times requested
for r=1:repeats
    display (' ');
    display(['Working on repetition number ', num2str(r),' at ',...
        num2str(toc/60),' mins']);

    % Cross Validation Loop. This section selects a training and testing
    % group out of the data by dividing it into n groups, and using n-1 of
    % those for training and 1 for testing. MATLAB (ver. 2008a or later)
    % has a built-in class for this. See MATLAB documentation for
    % "cvpartition" and "training."
    cvgroups = cvpartition ( class, 'kfold', n );

    for cv = 1:n % for each of n test groups, together spanning all cases
        trial=trial+1; % Keep track of each trial
        display(['     Working on cross-validation number ',num2str(cv),...
            ' of ',num2str(n)])

        % The next line uses a function inside "cvpartition" called
        % "training" that returns a logical vector identifying which cases
        % to use as the training group in cross validation.
        traingrpindex=training(cvgroups,cv);

        % Use that vector to extract the training data and the class of
        % the training cases
        traingrp=data(traingrpindex,:);
        traingrpclass=class(traingrpindex,:);

        % The test cases are cases NOT in the training group
        testgrp=data(~traingrpindex,:);
        testgrpclass= class(~traingrpindex,:);

        % Discretize the groups into hi-med-low by optimizing MI(V,C) for
        % each V (feature) in the training data.

        [leftbndry,rightbndry,traingrpbin, maxMI]=opt3bin(traingrp,...
            traingrpclass); %#ok<NASGU>

        % Build an augmented naive Bayesian network with the training data.
        % The adjacency matrix is a logical with true values meaning "there
        % is an arc from row index to column index." The last row
        % represents the class variable.

        adjmat = BuildBayesNet( traingrpbin, traingrpclass, ff, drop ); % adjacency matrix


        % Find MetaVariables, rebuild data. Depending on the option set,
        % reduce the V->V links by removing them, or combining them into a
        % single variable. The result is a naive Bayesian network with only
        % connections C->V.

        meta_option=1; % Hard coded for now
        classrow=cols+1;
        listvec=1:cols; % just a list of numbers
        varlist=unique(listvec(adjmat(classrow,:))); % top level vars

        if meta_option==1
            [finaldata metas leftbndry rightbndry] = ...
                ChooseMetaVars (traingrp, traingrpclass, adjmat);
        end

        % Bin up the test group using these final results, combining
        % variables per the instructions encoded in the "metas" logical
        % matrix.

        testdata=zeros(size(testgrp));

        if isempty(varlist) % in case no links are found
            disp ('Not finding any links yet...');
            cverrorrate(trial) = 1; % no classifier; count this trial as wrong
        else % if we do find links
            for var = varlist % each of the parents of metavariables
                metavar=[var listvec(metas(var,:))]; % concatenate children
                testdata(:,var)=sum(testgrp(:,metavar),2); % sum parent/child
            end


            % Now keep only the columns of class-connected variables
            finaltestdata=testdata(:,varlist);

            % And bin the result
            testgrpbin=zeros(size(finaltestdata)); % will be stored here
            % Build boundary arrays to test against
            testcases=size(testgrp,1);
            lb=repmat(leftbndry,testcases,1);
            rb=repmat(rightbndry,testcases,1);
            % test each value and record the bin
            testgrpbin(finaltestdata<lb)=1;
            testgrpbin(finaltestdata>=lb)=2;
            testgrpbin(finaltestdata>rb)=3;

            % Populate Bayesian Network

            % With the final set of data and the adjacency matrix, build
            % the probability tables and test each of the test group cases,
            % to see if we can determine the class.

            % Build the probability tables empirically with the training
            % group results
            ptable=FindProbTables(finaldata, traingrpclass);
            % empirical class priors from the training group
            prior=histc(traingrpclass, unique(traingrpclass))/max(size(traingrpclass));

            % find out the probability of each case being in class 1,2,etc.
            % Cases are in rows, class in columns.
            classprobtable = TestCases (ptable, prior, testgrpbin);
            [P_C predclass]=max(classprobtable,[],2);
            class_prob(~traingrpindex,r)=P_C;
            class_predict(~traingrpindex,r)=predclass;

            % Get the per-trial error rate (fraction misclassified)
            cverrorrate(trial)= sum(predclass~=testgrpclass)/testcases;

            % Store some "per trial" data
            OutputStructure.Adjacency(trial,:,:)=adjmat;
            OutputStructure.MetaVariablesFound(trial,1:cols,1:cols)=metas;
            ProbTables(trial).TrialTable=ptable; %#ok<AGROW>

        end % of finding metavariables

    end % of Cross Validation loop

    wrong=sum(~(class==class_predict(:,r)));
    errorrate(r)=wrong/rows;

end % of repeating entire process loop

% Record the results in the output structure
OutputStructure.ErrorRate=errorrate; % one for each repeat
OutputStructure.CvErrorRate=cverrorrate; % one for each of n*repeats trials
OutputStructure.PredictedClass=class_predict;

% Find out the error for each case
classrep=repmat(class,1,r);
WasIright=classrep==OutputStructure.PredictedClass;
OutputStructure.CasePredictionRate=sum(WasIright, 2)/double(r);

OutputStructure.ClassProbability=class_prob;
OutputStructure.ProbTables=ProbTables;
OutputStructure.SumAdj=squeeze(sum(OutputStructure.Adjacency,1));
OutputStructure.SumMV=squeeze(sum(OutputStructure.MetaVariablesFound,1));
% Save the results as a .mat data file and alert the user.
save results -struct  OutputStructure
disp('Results are saved in the current directory as results.mat')

end % of the function

Initial Processing of Input Structure

function StructOut = InitialProcessing( StructIn)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% INITIALPROCESSING Initial prep of data from signal processing
%
% DESCRIPTION
%       Takes peaklists that have been imported into MATLAB and prepares
%       them for Bayesian Analysis.
%
% USAGE
%       StructOut = InitialProcessing( StructIn)
%
% INPUTS
%       Structure with the following double-typed arrays:
%       Intensities: n x m real-valued array with variables (peaks) in
%           columns, cases (samples) in rows.
%       MZ: List of the labels (m/z value) for each of the variables.
%           Must be the same size as the number of variables in Intensities
%       Class: Classification of each sample (disease state)-- 1 or 2--must
%           be the same size as the number of cases in Intensities
%       ID: Case or patient ID number, same size as class.  May have second
%           column, so each row is [ID1 ID2] where ID2 is replicate number.
%       Options (logical):  Array of processing options with elements:
%           1. Normalize
%           2. Clip Data (remove negatives)
%           3. Replicate Average
%           4. Auto threshold MI
%           5. Use Log of Data
%           6. Remove Low Signal cases
%           NOT DONE: 3 Bin (2 Bin if False)
%
% OUTPUTS
%       DataStructure: MATLAB data structure with the following components:
%           RawData: Intensities as input
%           ClipData: RawData where all values less than 1 are set to 1
%           NormData: ClipData normalized by total ion count, i.e.
%               divided by the sum of all variables for each case
%           LogData: Natural logarithm of NormData
%           Class, MZ: Same as input
%           ID: Single column. If replicates are not averaged, the entries
%               are now ID1.ID2. If replicates averaged, then just ID1
%           DeltaMZ: difference in peak m/z values to look for adducts
%           RatioMZ: ratios of m/z values to look for satellites
%
% CALLED FUNCTIONS
%       None. (cdfplot is a MATLAB Statistics Toolbox function)
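%
% EXAMPLE
%       A minimal sketch, assuming the fields described above; the values
%       are hypothetical.
%       In.Intensities = 100*rand(20, 5);
%       In.Class = ceil(2*rand(20,1));
%       In.ID = [(1:20)' ones(20,1)];        % [ID1 ID2] with replicate no.
%       In.MZ = (1:5)*100;
%       In.Options = logical([1 1 0 0 0 0]); % normalize and clip only
%       Out = InitialProcessing(In);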


% Initialize Data
%  find the size, create the output structure, and transfer info

[rows cols]=size (StructIn.Intensities);
StructOut = StructIn;
StructOut.RawData = StructIn.Intensities;

% Option 2: Clip Negatives from data
%  set values below 1 (including all negatives) to be 1 because negative
%   molecule counts are not physically reasonable.
% 1 is chosen rather than 0 in case log(data) is used. Note: the decision to
% do this before normalization was based on discussions with Dr. William
% Cooke, who created the data set.

if StructOut.Options(2)
    StructOut.Intensities(find(StructOut.Intensities<1))=1; %#ok<FNDSB>
end

%  Option 6: Removal of Cases with Low Signal
%   find the sum of all values for each row, then normalize each row to
%   account for the effects of signal strength over time and other
%   instrumental variations in total strength of the signal

% Find the total ion count for each case, then the global average.
% Determine a correction factor for each case (NormFactor)
if StructOut.Options(1) ||  StructOut.Options(6)
    RowTotalIonCount=sum(StructOut.Intensities, 2);
    AvgTotalIonCount=mean(RowTotalIonCount); %Population average
    NormFactor=AvgTotalIonCount./RowTotalIonCount; %Vector of norm factors
    StructOut.NormFactor=NormFactor;  %save this in the structure
end
% If Remove Low Signal is desired, interact with user to determine
% threshold, then remove all cases that are below the threshold.

if StructOut.Options(6)
    figure(999);
    cdfplot(NormFactor);
    title('Cumulative Distribution of Normalization Factors');

% Request cutoff

text(1.3,0.5,['Click on the graph where you want';...
'the normalization threshold      ';...
'Cases with high norm factor (or  ';...
'low signal) will be removed.     ']);
    [NormThreshold, Fraction] = ginput(1);
    display([num2str(floor((1-Fraction)*100)),'% of cases removed']);
    close(999);
    TossMe=find (NormFactor>NormThreshold); %Low signal cases

% Now record, then remove, those cases with low signal

StructOut.LowSignalRemovedCases=StructOut.ID(TossMe,:);
    StructOut.LowSignalRemovedCasesNormFactors=NormFactor(TossMe);
    StructOut.Intensities(TossMe,:)=[];
    StructOut.ID(TossMe,:)=[];
    StructOut.Class(TossMe,:)=[];

end


% Option 3: Replicate Average. This option causes cases with same ID numbers
% to be averaged, peak by peak.

if StructOut.Options(3) %Replicate Average
    % Collapse to unique IDs only, throw out replicate ID column
    StructOut.Replicate_ID=StructOut.ID; %Save old data
    StructOut.Replicate_Class=StructOut.Class;

    newID=unique(StructOut.ID(:,1)); % List of unique IDs
    num=size(newID,1); %how many are there?
    newClass=zeros(num,1); % Holders for extracted class, data
    newData=zeros(num,cols);
    for i=1:num % for each unique ID
        id=newID(i); % work on this one
        cases=find(StructOut.ID(:,1)==id); % Get a list of cases with this ID
        newClass(i)=StructOut.Class(cases(1)); % save their class
        casedata=StructOut.Intensities(cases, :); % get their data
        newData(i,:)=mean(casedata, 1); % and save the average
    end
    StructOut.Intensities=newData;
    StructOut.Class=newClass;
    StructOut.ID=newID;
    clear newID newClass newData
else % If replicates exist, combine the 2 column ID into a single ID
    ID= StructOut.ID;
    if min(size(ID))==2
        shortID=ID(:,1)+(ID(:,2)*.001); % Now single entry is ID1.ID2
        StructOut.OldID=StructOut.ID;
        StructOut.ID=shortID;
        clear ID shortID
    end

end

% Option 1: Normalize total ion count. Apply the normalization factor to
% each row to normalize total ion count. We'll recalc norm factors in case
% data was replicate averaged.
if StructOut.Options(1)
    RowTotalIonCount=sum(StructOut.Intensities, 2);
    AvgTotalIonCount=mean(RowTotalIonCount); %Population average
    NormFactor=AvgTotalIonCount./RowTotalIonCount; %Vector of norm factors
    StructOut.NormFactor=NormFactor;  %save this in the structure
    NFmat=repmat(NormFactor, 1, cols); % match size of Intensities
    StructOut.Intensities=StructOut.Intensities.*NFmat;
    clear NFmat RowTotalIonCount AvgTotalIonCount NormFactor;
end


%  Option 5: Work with log (data)

if StructOut.Options(5)
    StructOut.Intensities=log(StructOut.Intensities);
end


% end function

end

Build network links

function adjacency = BuildBayesNet( data, class, ffactor, drop )
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% BuildBayesNet selects features and metafeatures based on mutual info.
%
%
% DESCRIPTION
%       This function takes a set of training data and an additional
%       variable called "class" and tries to learn a Bayesian Network
%       Structure by examining Mutual Information.  The class variable C is
%       assumed to be the ancestor of all other variables V.  Arcs from C
%       to V are declared if MI(C;V)>>z, where z is a maximum expected MI
%       of similar, but random data...multiplied by a "fudge factor."  Arcs
%       from Vi to Vj are similarly declared. Then various tests are
%       performed to prune the network structure and combine variables that
%       exhibit high correlations. Finally the network is pruned to be a
%       Naive Bayesian Classifier, with only C->V arcs remaining.
%
% USAGE
%       adjacency = BuildBayesNet( data, class, ffactor, drop )
%
% INPUTS
%       training_data: cases in rows, variables in cols, integer array
%               containing the data used to build the Bayes net
%       class: the known class variable for each case (1:c col vector)
%       ffactor: multiple of auto MI to use to threshold C->V connections
%       drop: MI loss percentage threshold for testing independence (see
%           clipclassconnections)
%
% OUTPUTS
%
%       adjacency: a matrix of zeros and ones, where one in row i, column j
%               denotes a directed link in a Bayesian network between
%               variable i and variable j. The class variable is the last
%               row/column.
%
% CALLED FUNCTIONS
%
%       automi: finds an MI threshold based on data
%       findmutualinfos: finds all values MI(V;C), MI(V;V) and MI(V;C|V)
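%
% EXAMPLE
%       A sketch, assuming "binned" holds training data already
%       discretized (e.g. by opt3bin) and "class" is the 1/2 class vector:
%       adj = BuildBayesNet(binned, class, 3, 0.75);
%       linked = find(adj(end,:)); % variables directly linked to the class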

% Initialize

% Initialize the network object and some constants
network.data=data;
network.class=class;

automireps=10; %times to repeat the auto MI thresholding to find avg.

% Check the sizes of various things
[rows cols]=size(data); %#ok<NASGU>
cases=max(size(class));
if rows==cases
    clear cases
else
disp('# of rows in the data and class must be equal.')
return
end

% network.adjmat=zeros(cols+1); % all variables plus class as last row/col
dataalphabet=max(size(unique(data))); % number of possible values of data
classalphabet=max(size(unique(class))); % Number of values of class

% Step 0: Find all the necessary mutual information values, thresholds The
% function below finds all values MI(V;C|V) and other combos needed and
% stores them in the network structure.

[ network.mi_vc, network.mi_vv, network.mi_vc_v ]...
= findmutualinfos( data, class );

% Find a threshold MI by examining MI under randomization
network.vcthreshold = automi( data, class, automireps )*ffactor ;
network.vvthreshold = network.vcthreshold *...
log(dataalphabet)/log(classalphabet);


% Step 1: Find all the possible arcs. Find the variables with high MI with
% the class, i.e. MI(V,C)>>0 and connect a link in the adjacency matrix
% C->V.  Also connect variable Vi,Vj if MI(Vi;Vj)>>0

network.adjmat1=getarcs(network.mi_vc,network.vcthreshold,network.mi_vv,...
network.vvthreshold);

% Step 2: Prune the variable set by clearing irrelevant features. If there
% is no path from V to the class, clear all entries V<->Vi (all i)
network.adjmat2 = clearirrarcs( network.adjmat1 );

% Step 3: Cut connections to class. Where two variables are connected to
% each other and also to the class, attempt to select one as the child of
% the other and disconnect it from the class. Use MI(Vi;C|Vj)<<MI(Vi;C) as
% a test.

temp = clipclassconnections (network.adjmat2, ...
network.mi_vc,network.mi_vc_v, drop);

% and once again clear features no longer near class and end function
adjacency= clearirrarcs( temp );

end

Attempt to find metavariables

function  [finaldata metamatrix leftbound rightbound] =...
ChooseMetaVars ( data, class, adj)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% ChooseMetaVars attempts to combine variables into better variables
%
% DESCRIPTION
%       Finds the V-V pairs in the adjacency matrix, and attempts to
%       combine them into a metavariable with a higher mutual information
%       than either variable alone. If it is possible to do this, it
%       returns a new data matrix with the variables combined.
%
% USAGE
%       [finaldata metamatrix leftbound rightbound] =
%                        ChooseMetaVars ( data, class, adj)
%
% INPUTS
%       data: double array of discrete integer (1:n) values, cases in rows
%           and variables in columns.
%       class: double column vector, also 1:n. Classification of each case.
%       adj: Adjacency matrix, #variables+1 by #variables. Last row is
%           class node. Logical meaning "there is an arc from i to j."
%
% OUTPUTS
%       metamatrix: logical whose (i,j) means "variable j was combined into
%           variable i (and erased)"
%       finaldata: The data matrix with the variables combined and rebinned
%       leftbound: The new left boundary (vector) for binning.
%       rightbound: The new right boundary (vector) for binning.
%
% CALLED FUNCTIONS
%       opt3bin: rebins combined variables to determine highest MI.

% Initialize
[rows cols]=size(data);
[classrow numvars]=size(adj);
bindata=zeros(rows,cols);
metamatrix=false(cols);

% Create a list of all the variables V to check by examining the adjacency
% matrix's last row, i.e. those with C->V connections
listvec=1:numvars;
varstocheck=unique(listvec(adj(classrow,:)));
l=zeros(1,numvars);
r=zeros(1,numvars);

% Now go through that list, testing each V->W connection to see if adding V
% and W creates a new variable Z that has a higher MI with the class than V
% alone.  V is the list above, W is the list of variables connected to a V.

for v=varstocheck % Pull out the W variables connected to V and test
    wlist=unique(listvec(adj(v,:)));
    [l(v), r(v), binned, mitobeat] = opt3bin(data(:,v), class);
    bindata(:,v)=binned;
    if ~isempty(wlist)
        for w=wlist
            newdata=data(:,v)+data(:,w);
            [left, right, binned, newmi] = opt3bin(newdata, class);
            if newmi>mitobeat
                mitobeat=newmi;
                data(:,v)=data(:,v)+data(:,w);
                metamatrix(v,w)=true; % record the combination
                bindata(:,v)=binned;
                l(v)=left;
                r(v)=right;
            end
        end
    end
end

%pull out just the V->C columns from the data matrix.
finaldata=bindata(:,adj(classrow,:));
leftbound=l(adj(classrow,:));
rightbound=r(adj(classrow,:));
end

Bayes classification

function classprobs = TestCases( p, prior, data)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% TestCases uses Bayes' rule to classify cases
%
% DESCRIPTION
%       Tests each of a set of data vectors by looking up P(data|class) in
%       a probability table, then finding P(case|class) by multiplying each
%       of those values in a product.  Then uses Bayes' rule to calculate
%       P(class|data) for each possible value of class.  Reports this as an
%       array of class probabilities for each case.
%
% USAGE
%       classprobs = TestCases( p, prior, data)
%
% INPUTS
%       data: double array of discrete integer (1:n) values, cases in rows
%           and variables in columns.
%       p: 3-D double array of probabilities (c,d,v).  The first dimension
%           is the class, the second is the data value, the third is the
%           variable number. The entry is P(var v=value d | class=value c).
%       prior: column vector whose cth entry is the prior probability
%           P(class=c).
%
% OUTPUTS
%
%       classprobs: 2-D double array whose value is P(class=c|data) for
%           each case. Cases are in rows, class in cols.
%
% CALLED FUNCTIONS
%
%       None.
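%
% EXAMPLE
%       A worked sketch with one variable and two classes; the numbers are
%       hypothetical.
%       p(:,:,1) = [.7 .2 .1; .2 .3 .5]; % rows: class; cols: data value
%       prior = [.5; .5];
%       cp = TestCases(p, prior, 3);     % one case, V1 in the "high" bin
%       % cp = [0.167 0.833]: Bayes' rule favors class 2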

% Initialize

% Find the sizes of the inputs and the number of possible values
[cases numvars]=size(data);
classvals=size(p,1);
pvec=zeros(classvals,numvars);
classprobs = zeros(cases, classvals); % holds the classification results

% Find the probabilities

% Create pvec, an array whose first row is P(V=v|c=1) for each V
for casenum=1:cases
    casedata=data(casenum,:); % The case to be checked
    for c=1:classvals
        for v=1:numvars
            pvec(c,v)=max(p(c,casedata(v),v),.01); % Don't want any zeros
        end
    end
    % Now find P(case|class) for each class by multiplying each individual
    % P(V|C) together, assuming they are independent.

    Pdc=prod(pvec,2);

    % Use Bayes' rule

    classprobs(casenum,:) =(Pdc.*prior)/sum(Pdc.*prior);

end

end

Optimized discretization

function [l, r, binned, mi] = opt3bin (data, class)
% by Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% opt3bin finds the 3-bin discretization of each variable that maximizes MI
%
% DESCRIPTION
%       This function takes an array of continuous sample data of size
%       cases (rows) by variables (columns), along with a class vector of
%       integers 1:c, each integer specifying the class. The class vector
%       has the same number of cases as the data.  The function outputs the
%       position of the 2 bin boundaries (3 bins) that optimize the mutual
%       information of each variable's data vector with the class vector.
%
% USAGE
%       [l,r,binned, mi]=opt3bin(data,class)
%
% INPUTS
%       data: double array of continuous values, cases in rows and
%           variables in columns. Distribution is unknown.
%       class: double column vector, values 1:c representing classification
%           of each case.
%
% OUTPUTS
%
%       l     - row vector of left boundary position for each var.
%       r     - row vector of right boundary position for each var.
%       binned- data array discretized using boundaries in l and r
%       mi    - row vector of mutual info between each discr. variable
%                  and class
%
% CALLED FUNCTIONS
%
%       opt2bin: Similar function that finds a single boundary. This is
%           used as a seed for the 3 bin optimization.
%       looklr: See below.
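%
% EXAMPLE
%       A quick sketch on random data (values hypothetical):
%       data = randn(100, 5); class = ceil(2*rand(100,1));
%       [l, r, binned, mi] = opt3bin(data, class);
%       % binned holds values 1-3; l(j) and r(j) bound variable j's bins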


% Initialize
%
%  Variable Prep: find sizes of arrays and create placeholders for locals

steps=150;
[rows cols]=size(data);
boundary=zeros(2,cols);

% Method: Find a starting point by finding the maximum value of a 2 bin MI.
% Next, go left and right from that position, finding the position of the
% next boundary that maximizes MI.

[mi boundary(1,:)] = opt2bin (data, class, steps, 2);

% We've located a good starting (center) bin boundary.  Search L/R for a
% second boundary to do a 3 bin discretization.
[mi boundary(2,:)] = looklr (data, class, boundary(1,:), steps);

% We've now found the optimum SECOND boundary position given the best 2 bin
% center boundary.  Now re-search using that SECOND boundary position,
% dropping the original (2 bin).  The result should be at, or near, the
% optimal 3 bin position.
[mi boundary(1,:) binned] = looklr (data, class, boundary(2,:), steps);

% from the two boundaries found above, sort the left and right
r=max(boundary);
l=min(boundary);

% Now return the vector of left and right boundaries, the disc. data, and
% max MI found.
end % of function

Optimized binning search

function [miout nextboundary binned] = looklr (data, class, startbd, steps)
% given a start position, finds another boundary (to create 3 bins) that
% maximizes MI with the class
[rows cols]=size(data);
farleft=min(data,[],1);
farright=max(data,[],1);
miout=zeros(1,cols);
binned=zeros(rows,cols);
nextboundary=zeros(1,cols);

for peak=1:cols % for each peak/variable separately...

    % discretize this variable's values. Sweep through the possible bin
    % boundaries from the startbd to the furthest value of the data,
    % creating 2 boundaries for 3 bins. Record the binned values in a
    % "cases x steps" array, where "steps" is the granularity of the sweep.
    % The data vector starts off as a column...

    testmat=repmat(data(:,peak),1,steps); % and is replicated to an array.

    % Create same size array of bin boundaries. Each row is the same.
    checkptsL=repmat(linspace(farleft(peak),startbd(peak),steps),rows,1);
    checkptsR=repmat(linspace(startbd(peak),farright(peak),steps),rows,1);

    % Create a place to hold the discrete info, starting with all ones. The
    % "left" array will represent data binned holding the center boundary
    % fixed and sweeping out a second boundary to the left; similarly the
    % right boundary starts at "startbd" and sweeps higher.
    binarrayL=ones(rows,steps);
    binarrayR=ones(rows,steps);

    % Those in the L test array that are higher than the left boundary -> 2
    binarrayL(testmat>checkptsL)=2;
    binarrayL(testmat>startbd(peak))=3; % >center boundary -> 3

    % Similarly using center and right boundaries
    binarrayR(testmat>startbd(peak))=2;
    binarrayR(testmat>checkptsR)=3;

    % Now at each of those step positions, check MI (var;class).
    miout(peak) = 0;

    % These vectors hold the MI with each step used to discretize.
    miL=MIarray(binarrayL,class); % MI(V;C) using left/center
    miR=MIarray(binarrayR,class); % MI(V;C) using center/right

    if max(miL)>max(miR)  % See which one is the largest
        [miout(peak) index]=max(miL); % record the max MI found
        nextboundary(peak)=checkptsL(1,index); % and record the boundary
        binned(:,peak)=binarrayL(:,index); % and record the discrete data
    else
        [miout(peak) index]=max(miR); % record the max MI found
        nextboundary(peak)=checkptsR(1,index); % and record the boundary
        binned(:,peak)=binarrayR(:,index); % and record the discrete data
    end

end % of that variable's search.  Go to next variable.

end % of the search.  Return the best boundary and the associated MI and data

Automated MI thresholding

function threshold = automi( data, class, repeats )
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% automi finds a threshold for randomized MI(V; C)
%
% DESCRIPTION
%       Finds the threshold of a data set's mutual information with a class
%       vector, above which a variable's MI(class, variable) can be
%       expected to be significant. The threshold for MI (significance
%       level) is found by taking the data set and randomizing the class
%       vector, then calculating MI(C;V) for all the variables. This is
%       repeated a number of times. The resulting list of length (#repeats
%       * #variables) is sorted, and the 99th percentile max MI is taken
%       as the threshold.

% USAGE
%       threshold = automi( data, class )
%
% INPUTS
%       data: double array of discrete integer (1:n) values, cases in rows
%           and variables in columns.
%       class: double column vector, also 1:n. Classification of each case.
%       repeats: the number of times to repeat the randomization
%
% OUTPUTS
%
%       threshold: the significance level for MI(C;V)
%
% CALLED FUNCTIONS
%
%       MIarray(data,class): returns a vector with MI(Vi;Class) for each V
%           in the data set
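%
% EXAMPLE
%       A sketch on already-discretized data (values hypothetical):
%       data = ceil(3*rand(100, 20)); class = ceil(2*rand(100,1));
%       t = automi(data, class, 10);
%       % values of MI(Vi;C) above t are considered significant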

% Initialize

% Find the size of the data (cases x variables) and check against class
[rows cols]=size(data);
cases=max(size(class));
if rows==cases
    clear cases
else
disp('# of rows in the data and class must be equal.')
return
end


% Repeat a number of times

mifound=zeros(cols,repeats); % stores the results of the randomized MI
for i=1:repeats
    c=class(randperm(rows)); % creates a randomized class vector
    mifound(:,i)=MIarray(data,c); % record MI(Ci;V) in an array
end

% pull off the 99th percentile highest MI
mi_in_a_vector=reshape(mifound,repeats*cols,1); % prctile needs vector
threshold=prctile(mi_in_a_vector,99);

end

Find arcs, build adjacency matrix

function adjacency = getarcs( mvc, vcthreshold, mvv, vvthreshold )
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% GETARCS builds the adjacency matrix for a set of variables
%
% DESCRIPTION
%       By comparing mutual information between two variables to thresholds
%       determined separately, this function declares there to be an arc in
%       a Bayesian network. Arcs are stored in an adjacency matrix,
%       described below.
%
%       The primary tests are:
%       MI(Vi;C)>>vcthreshold : tests for links between Vi and the class
%       MI(Vi;Vj)>>vvthreshold : tests the links between variables
%
% USAGE
%       adjacency = getarcs( mvc, vcthreshold, mvv, vvthreshold )
%
% INPUTS
%       mvc [mvv]: double vector [array] with mutual information between
%           variables and the class [variables and other variables]. The
%           (i,j) entries of mvv are MI(Vi,Vj).
%       vc/vvthreshold: scalar threshold used to test for existence of links
%
% OUTPUTS
%
%       adjacency: logical matrix whose entries "1" at (i,j) mean "an arc
%            exists from the Bayesian network node Vi to Vj." The class
%            variable C is added at row (number of V's + 1). "0" values
%            mean no arc.
%
% CALLED FUNCTIONS
%
%       None.
%
% For more information on the tests and the links, see my dissertation.
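%
% EXAMPLE
%       A small sketch with three variables (numbers hypothetical):
%       mvc = [.30 .02 .25];           vcthreshold = .1;
%       mvv = [0 .5 0; .5 0 0; 0 0 0]; vvthreshold = .3;
%       adj = getarcs(mvc, vcthreshold, mvv, vvthreshold);
%       % adj(4,:) = [1 0 1]: V1 and V3 link to the class (last row)
%       % adj(1,2) = adj(2,1) = 1: V1 and V2 are linked to each other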


% Initialize
numvars=max(size(mvc)); %the number of variables
classrow=numvars+1; %row to store links C->V
adjacency= false(classrow,numvars); %the blank adjacency matrix

% Test for adjacency to class
adjacency ( classrow , : )= mvc > vcthreshold;

% Test for links between variables. This test results in a symmetric logical
% matrix since MI(X;Y) is symmetric. To create a directed graph, these arcs
% will need to be pruned.
adjacency ( 1:numvars, 1:numvars ) = mvv > vvthreshold;

end

Clear irrelevant arcs

function adjout = clearirrarcs( adjin )
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% CLEARIRRARCS clears arcs that are not C->V or C->V<->V
%
% DESCRIPTION
%       Given an adjacency matrix with V<->V arcs in a square matrix and an
%       additional row representing C->V (class to variable), this function
%       clears out all V1->V2 arcs where V1 is not a member of the set of
%       V's that are class-connected, i.e. have arcs in the final row.
%
% USAGE
%       adjout = clearirrarcs( adjin )
%
% INPUTS
%       adjin: a logical array where a true value at position (i,j) means
%           that there is an arc in a directed acyclic graph between
%           (variable) i and variable j.
%
% OUTPUTS
%       adjout: copy of adjin with unneeded arcs cleared
%
% CALLED FUNCTIONS
%       None.

% Initialize. Find the sizes of the input
[classrow, numvars]=size(adjin);

% Main processing. Find out which variables are connected to class
conntocls=(adjin(classrow,:));

% Remove all arcs that don't have at least one variable in this list, i.e.
% all Vi<->Vj such that ~(Vi->C or Vj->C). These are all the entries in the
% adjacency matrix whose i and j are NOT in the list above.

% Make a matrix with ones where neither variable is in the list above
noconnmat=repmat(~conntocls,numvars,1) & repmat(~conntocls',1,numvars);

% Use that to erase all the irrelevant entries in the square adj matrix, at
% the same time remove the diagonal (arcs Vi<->Vi)
adjout=adjin (1:numvars, 1:numvars)& ~noconnmat & ~eye(numvars);

% Bidirectional arcs are temporarily permitted between nodes connected
% directly to the class, but not between nodes where only one is connected
% to the class- those are assumed to flow C->V1->V2 only.  Remove V2->V1.

% Get a matrix of ones in rows that are class connected. V->V arcs are only
% allowed to be in these rows:
parents=repmat(conntocls',1,numvars);
% Remove anything else
adjout=adjout & parents;

% Now add back in the class row at the bottom of the square matrix
adjout(classrow,:)=adjin(classrow,:);

end

Attempt to prune relevant arcs

function adjout = clipclassconnections( adj, mivc_vec,mivcv,dropthreshold )
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% clipclassconnections delinks variables from class
%
% DESCRIPTION
%        Where two variables are connected to each other and also to the
%        class, attempt to select one as the child of the other and
%        disconnect it from the class. Use MI(Vi;C|Vj)<<MI(Vi;C) as a test.
%
% USAGE
%       adjout = clipclassconnections( adj, mivc_vec, mivcv, dropthreshold )
%
% INPUTS
%       adj: (logical) matrix where "true" entries at (i,j) mean "an arc
%            exists from the Bayesian network node Vi to Vj." The class
%            variable C is added at row (number of V's + 1). "0" values
%            mean no arc.
%       mivc_vec: (double) row vector containing MI(C;Vi) for each variable
%       mivcv: (double) array whose (i,j) entry is MI(Vi,C|Vj).
%       dropthreshold: percentage drop from MI(Vj;C) to MI(Vj;C|Vi) before
%           declaring that Vi is between C and Vj.
%
% OUTPUTS
%       adjout: copy of adj with the appropriate arcs removed.
%
% CALLED FUNCTIONS
%       None.

% Initialize
[classrow, numvars]=size(adj);
classconnect=adj(classrow, :); % the last row of adj stores arcs C->V
adjout=false(classrow, numvars); % placeholder for output array

% Identify triply connected arcs

% First look for pairs that are connected to each other and connected to
% the class.

% Connected to each other: build logical array with (i,j) true if Vi<->Vj
vv_conn=adj(1:numvars, 1:numvars);

% Connected to the class: logical array with (i,j) true if C->Vi and C->Vj
vcv_conn=repmat(classconnect, numvars,1) & repmat(classconnect',1,numvars);

% Find all (i,j) with both true
triple_conn = vv_conn & vcv_conn;

% Determine preferred direction on V<->V arcs

% Determine the Vi<->Vj direction by finding the greater of MI(C;i|j) or
% (C;j|i).  Greater MI means less effect of the instantiation of i or j.
arcdirection=mivcv > mivcv'; %Only the larger survive
dag_triple_conn=arcdirection & triple_conn; % Wipes out the smaller ->

% find links that should NOT be kept under the test above,
linkstoremove=(~arcdirection) & triple_conn;
% and if they are in the connection list, remove them
adjout(1:numvars, 1:numvars)=xor(vv_conn,linkstoremove);

% Now we need to test whether we can remove the link between C and which
% ever V (i or j) is the child of the other. We look for a "significant"
% drop in MI(Vj;C) when instantiating Vi, e.g. MI(Vj;C|Vi)<<MI(Vj;C).
%
% dropthreshold of .7, for example, means link breaks if 1st term is less
% than 30% of the second term.
%
% If there is a big drop in MI(C;Vj) when Vi is given, and Vi->Vj exists in
% the DAG, then we can remove the link C->Vj and leave C->Vi->Vj.

% Build an array out of the mivc_vec vector
mivc=repmat(mivc_vec',1,numvars);
% Test for the large drop described above
bigdrop=((mivc-mivcv)./mivc) > dropthreshold;
% Test for the big drop and the V-V connection
breakconn = bigdrop' & dag_triple_conn;
% If any of the elements in a column of the result are true, remove that
% variable's C->V link, since it is a child.
linkstokeep=~any(breakconn);
adjout(classrow,:)= adj(classrow,:) & linkstokeep;

% With V->V links now only one way, and C->V removed where needed, we can
% return the pruned adjacency matrix.
end

Find all necessary mutual information values

function [ mi_vc, mi_vv, mi_vc_v ] = findmutualinfos( data, class )
% by Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% FINDMUTUALINFOS finds the various mutual info combos among variables.
%
% DESCRIPTION
%       Given a set of data (many cases, each with values for many
%       variables) and an additional value stored in the vector class, it
%       finds various combinations of MI described below in "OUTPUTS."
%
% USAGE
%       [ mi_vc, mi_vv, mi_vc_v ] = findmutualinfos( data, class )
%
% INPUTS
%       data: A number of cases (in rows), each with a measurement for a
%           group of variables (in columns). The data should be discretized
%           into integers 1 through k. The columns are considered variables
%           V1, V2, ...
%       class: an additional measurement of class C. A column vector of
%           length  "cases" with integer values 1,2...
%
% OUTPUTS
%
%       mi_vc: a row vector whose ith value is MI(Vi,C).
%       mi_vv: Symmetric matrix with values MI(Vi,Vj).
%       mi_vc_v: Non-sym matrix with values MI(Vi;C|Vj).
%
% CALLED FUNCTIONS
%
%       findentropies: returns entropy values [e.g. H(Vi,Vj)]

% Initialize

%Find the data size and declare some blank arrays
[rows cols]=size(data);
mi_vv=zeros(cols);
mi_vc_v=zeros(cols);

% Find Entropies and Calculate Mutual Informations

% Find the various entropies needed to calculate the MI's
[ h_c, h_v, h_vc, h_vv, h_vcv ] = findentropies( data, class );

% Calculate the value MI(Vi,C)
mi_vc = h_v + h_c - h_vc;

% For each variable Vj, calculate MI(Vi,Vj) and MI(Vi;C|Vj)
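% using the entropy identities implemented in the loop below:
%       MI(Vi;Vj)   = H(Vi) + H(Vj) - H(Vi,Vj)
%       MI(Vi;C|Vj) = H(Vi,Vj) + H(Vj,C) - H(Vj) - H(Vi,Vj,C)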

for i=1:cols
for j=1:cols
        mi_vv (i,j) = h_v(i) + h_v(j) - h_vv(i,j);
        mi_vc_v(i,j) = h_vv(i,j) -  h_v(j)+ h_vc(j) - h_vcv(i,j);
end
end

end

Coarse optimized binning

function [mi boundary binneddata] = opt2bin (rawdata, class, steps,...
typesearch, minint, maxint)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% opt2bin finds the best single boundary for each variable to maximize MI
%
% DESCRIPTION
%       This function takes an array of continuous data, with cases in rows
%       and variables in columns, along with a vector "class" which holds
%       the known class of each of the cases, and returns an array
%       "binneddata" that holds the 2 bin discretized data.  The
%       discretization bin boundary is found by maximizing the mutual
%       information with the class; the resulting MI and boundary are also
%       returned. The starting boundaries for the search can be given in
%       the vectors minint and maxint, or either one, or neither, in which
%       case the data values determine the search boundaries.
%
% USAGE
%       [mi boundary binneddata] = opt2bin(rawdata, class, steps,
%           typesearch [, minint, maxint])
%
% INPUTS
%       rawdata: double array of continuous values, cases in rows and
%           variables in columns. Distribution is unknown.
%       class: double column vector, values 1:c representing classification
%           of each case.
%       steps: Number of steps to test at while finding maximum MI
%       typesearch =0: starting bndry based on data's actual max/min values
%                  =1: use the value passed in maxint as maximum (right) value
%                  =-1: use the value passed in minint as minimum (left) value
%                  =2: use values passed via maxint, minint
%       the two optional arguments are vectors whose values limit the range
%       of search for each variable's boundaries.
%
% OUTPUTS
%
%       mi: row vector holding the maximum values of MI(C;Vi) found
%       boundary: The location used to bin the data to get max MI
%       binneddata: The resulting data binned into "1" (low) or "2" (hi)
%
% CALLED FUNCTIONS
%
%       MIarray: Finds the MI of each col in an array with a separate
%           vector (the class in this case)

% Initialize
[rows cols]=size(rawdata);
mi=zeros(1,cols);
boundary=zeros(1,cols);
binneddata=zeros(rows,cols);
currentmi=zeros(steps,cols);

% if not passed, find the left and rightmost possible bin boundaries from
% data

if nargin~=6
    minint=min(rawdata,[],1);
    maxint=max(rawdata,[],1);
elseif typesearch==1
    minint=min(rawdata,[],1);
elseif typesearch==-1
    maxint=max(rawdata,[],1);
elseif typesearch==2
    disp('using passed values')
else
disp('typesearch must = 0,1,-1,2')
return
end

% Find best boundary

for peak=1:cols %look at each variable separately

    % Create an array of possible bin boundary locations, min->max
    checkpoints=repmat(linspace(minint(peak),maxint(peak),steps),rows,1);

    % discretize the variable's values at each of these possible
    % boundaries, putting 2's everywhere (value > boundary), 1 elsewhere
    binarray=(repmat(rawdata(:,peak), 1, steps)>checkpoints)+1;

    % Send this array off to find the MI(C,V) for each possible binning
    currentmi(1:steps,peak)=MIarray(binarray,class);

    % Now pick out the highest MI, i.e. best bin boundary
    [mi(peak) atstep]=max(currentmi(:,peak));
    boundary(peak)=checkpoints(1,atstep);

    % and record the binned data using that boundary.
    binneddata(:,peak)=binarray(:,atstep);
end

end

Find Bayes network parameters

function p=FindProbTables(data, class)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% FindProbTables estimates the probabilities P(data=d|class=c)
%
% DESCRIPTION
%       Input a training group of data arranged with cases in rows and
%       variables in columns, as well as the class value c for that vector.
%       Each case represents a data vector V.  For each possible data value
%       d, and each variable Vi, it calculates P(Vi=d|C=c) -- the
%       likelihoods needed by TestCases -- and stores that result in a 3-D
%       table.  The table is arranged with the dimensions (class value,
%       data value, variable number).
%
% USAGE
%       probtable = FindProbTables(data, class)
%
% INPUTS
%       data: double array of discrete integer (1:n) values, cases in rows
%           and variables in columns.
%       class: double column vector, also 1:n. Classification of each case.
%
% OUTPUTS
%
%       probtable: 3-D array whose (c,d,v) value is P(data=d|class=c) for
%           variable v.
%
% CALLED FUNCTIONS
%
%       None.
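%
% EXAMPLE
%       A worked sketch (values hypothetical):
%       data = [1 2; 2 2; 1 1; 2 1]; class = [1;1;2;2];
%       p = FindProbTables(data, class);
%       % p(1,2,1) = P(V1=2 | C=1) = 0.5 (half of class-1 cases have V1=2)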

% Initialize. Find the sizes of the inputs and the number of possible values
[cases numvars]=size(data);
datavals=max(size(unique(data)));
classvals=max(size(unique(class)));
% Build some placeholders and loop indices
p=zeros(classvals, datavals, numvars ); % triplet: (class, value, variable#)
databins=1:datavals;
classbins=1:classvals;

% Find Probabilities. For each classification value, extract the data with
% that class
for c=classbins
    datainthatclass=data(class==c,:); % array of just cases with class=c
    % find the fraction of that class's cases with each possible data
    % value, i.e. P(data=d|class=c); normalize by the class count, not
    % the total case count, so each column sums to 1 within a class
    p(c,:,:)=histc(datainthatclass,databins)/sum(class==c);
end

end

Find all variable entropies

function [ h_c, h_v, h_vc, h_vv, h_vvc ] = findentropies( data, class )

% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% findentropies finds all the entropies H(V), H(V,V), etc.
%
% DESCRIPTION
%       Give a set of data arranged with cases in rows and variables in
%       columns, and an additional variable labeled the "class", this
%       function returns the entropy of each variable's data, the class'
%       data, and the joint entropy of all pairs of two variables, all
%       variables with the class, and all pairs of variable and the class.
%
% USAGE
%       [ h_c, h_v, h_vc, h_vv, h_vvc ] = findentropies( data, class )
%
% INPUTS
%       data: double array of discrete integer (1:n) values, cases in rows
%           and variables in columns.
%       class: double column vector, also 1:n. Classification of each case.
%
% OUTPUTS
%
%       h_v: entropies of the variables, H(Vi), stored in a row vector.
%       h_c: scalar entropy of the class vector, H(C)
%       h_vc: vector whose ith entry is the joint entropy H(Vi,C)
%       h_vv: matrix whose (i,j) entry is the joint entropy H(Vi,Vj)
%       h_vvc: matrix whose (i,j) entry is the joint entropy H(Vi,Vj,C)
%
% CALLED FUNCTIONS
%
%       entropy (vector, num_poss_vals [vector, numvals,...]) see below

% Initialize
%  Find the number of variables (cols) and number of cases, as well as the
%  number of possible values (k) and class values (l)
[rows cols]=size(data);
k=max(size(unique(data))); % # of possible values of data
l=max(size(unique(class))); % # of possible values of class

% Initialize the output matrices
h_v=zeros(1,cols);
h_vc=zeros(1,cols);
h_vv=zeros(cols,cols);
h_vvc=zeros(cols,cols);

% Main processing. Calculate all the various entropy combinations
h_c = entropy (class, l); % see function below

for i=1:cols

    h_v(i) = entropy (data(:,i), k);
    h_vc(i) = entropy (data(:,i), k, class, l);

for j=1:cols
        h_vv(i,j)= entropy (data(:,i), k, data(:,j), k);
        h_vvc(i,j) = entropy (data(:,i), k, data(:,j), k, class, l);
end

end

end

Entropy equation implementation

function ent=entropy(vector1, k, vector2, l, vector3, m)
% (c) Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% entropy finds all the entropies H(V), H(V,V), etc.
%
% DESCRIPTION
%       Calculates the entropy (or joint entropy if more than one argument
%       pair) of a vector (or vectors) whose values are {1,2,...k}. Must
%       send in one or more pairs of arguments ("vector", "num poss vals")
%
% USAGE
%       ent=entropy (vector, num_poss_vals [,vec,numvals [,vec,numvals] ])
%
% INPUTS One to three pairs of
%       vector: vector of integers 1,2,..k representing values of randm var
%       k: number of possible values in vector
%
% OUTPUTS
%       ent: information entropy H(V1)[or H(V1,V2) or H(V1,V2,V3)]
%
% CALLED FUNCTIONS
%
%       None.
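%
% EXAMPLE
%       ent = entropy([1;2;1;2], 2)
%       % returns 1, the entropy in bits of a fair two-valued variable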

% Initialize
n=max(size(vector1)); % Number of cases (not error checked)

% Calculate the Entropy

% single variable entropy formula
if nargin==2
    P_k=hist(vector1,1:k)/n;
    NonZero=find(P_k~=0); % See Note 1
ent=-sum(P_k(NonZero).*log2(P_k(NonZero)));
end

% two variable joint entropy H(V1,V2)
if nargin==4
    ent=0;
for i=1:l
        P_lk=hist(vector1(vector2==i),1:k)/n;
        NonZero=find(P_lk~=0);
        ent=ent-sum(P_lk(NonZero).*log2(P_lk(NonZero)));
end
end

% three variable joint entropy H(V1,V2,V3)
if nargin==6
    ent=0;
    for i=1:l % for all possible values in V2
        for j=1:m % for all possible values in V3
            % empirically find the probability and sum the entropy at
            % each step
            P_lkm=hist(vector1(vector2==i & vector3==j),1:k)/n;
            NonZero=find(P_lkm~=0); % See Note 1
            ent=ent-sum(P_lkm(NonZero).*log2(P_lkm(NonZero)));
        end
    end
end

%  Note 1: we can skip terms with p(a,b,c)=0 since
%               p log (p) = 0 log 0 = 0
%  in that case and it does not contribute to the sum.

end

Find mutual information of a vector with all columns of an array

function MIOut = MIarray(MatrixIn, class)
% by Karl Kuschner, College of William and Mary, Dept. of Physics, 2009.
%
% MIarray finds MI of each column of a data set with a separate vector
%
% DESCRIPTION
%       This function finds the mutual information between a single
%       discrete variable (class) and a matrix of discrete variables
%       (MatrixIn) which have the same number of cases (variables in
%       columns, cases in rows). A row vector containing the values
%       MI(Vi,C) for each variable Vi in the matrix is returned.
%
% USAGE
%       MIOut = MIarray(MatrixIn, class)
%
% INPUTS
%       data: double array of discrete integer (1:n) values, cases in rows
%           and variables in columns.
%       class: double (col) vector, values 1:c representing class of each
%           case. Number of values c can be different than n in the data.
%
% OUTPUTS
%       MIOut: double (row) vector whose entries are the Mutual information
%           between each corresponding column of MatrixIn and the class.
%
% CALLED FUNCTIONS
%       None.
%
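%
% EXAMPLE
%       MIarray([1;1;2;2], [1;1;2;2])
%       % returns 1: the column determines the class completely (1 bit)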


% Initialize and data check: check arguments
if nargin~=2
    disp('wrong number of input arguments')
    disp('need (data_array,class)')
    disp(' ')
    disp('Type "doc MIarray" for more info')
end
%class and MatrixIn must have the same number of rows
[rows cols]=size(MatrixIn);
if size(class,2)==rows
    class=transpose(class);
elseif size(class,1)~=rows
    disp('Dimension mismatch in rows of MI arrays')
    disp('Input arrays must have the same number of rows')
return
end %row dimension check


% States must be integer values, typically 1 to n. If so, record n.
% Similarly, find out the number of states of the class variable.
if sum(any(MatrixIn-round(MatrixIn)))
    disp('Matrix in should be integers 1 to n')
return
else
n=max(size(unique(MatrixIn))); % Number of data states
c=max(size(unique(class)));% Number of class states
end % check if integer

% Variable Prep

MatrixIn=int8(MatrixIn); %optional
class=int8(class); %optional
Pcv = zeros(c,n,cols);

% Main function

% Compute probability tables. Pcv is an array whose (i,j,k) entry is
% Prob(class=state i and variable Vk=state j).  Others are similar.

if c==1 %trap for errors in the case where all classes are the same
Pc = 1;
else
% Create a 3-D array with c rows, each row filled with P(C=ci)
Pc = repmat((hist(class,1:c)/rows)', [1,n,cols]);
end

% Create a 2-D array where (j,k) is P(Vk=vj).  Replicate it to a third
% dimension to prepare for multiplication with the above.
Pv =repmat( reshape ( hist(MatrixIn,1:n)/rows , [1,n,cols] ), [c,1,1] );

% Now multiply these together,  The result is a c by n by cols matrix whose
% (i,j,k) entry is P(C=ci)*P(Vk=vj) for each value of class ci and data vj.
PcPv= Pc.*Pv;

% Now we need a similar sized array with the  (i,j,k) entry equal to P(C=ci
% and Vk=vj) -- the joint probability.
for classstate=1:c
    Pcv(classstate,:,:) = hist(MatrixIn(class==classstate,:),1:n)/rows;
end

% Now we can compute the mutual info using
%
% MI(C=i;Vk=j) = sum i (sum j (Pcv(i,j,k) log [Pcv(i,j,k)/PcPv(i,j,k)] ) )
%
miterms=Pcv.*(log2(Pcv)-log2(PcPv)); % Each term of the double sum above...
miterms(isnan(miterms))=0; % with all the 0 log 0 entries removed

% Do the double summation and squeeze the unused dimensions
MIOut = squeeze(sum(sum(miterms,1),2))';

end