SNAP Library 2.0, Developer Reference  2013-05-13 16:33:57
SNAP, a general-purpose, high-performance system for analysis and manipulation of large networks
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
agmfast.h
Go to the documentation of this file.
00001 #ifndef snap_agmfast_h
00002 #define snap_agmfast_h
00003 #include "Snap.h"
00004 
00007 class TAGMFast { // Holds the graph G, a sparse node-by-community membership table F, and the settings used to fit F to G; most methods are only declared here
00008 private:
00009   PUNGraph G; // graph to fit
00010   TVec<TIntFltH> F; // membership for each user (Size: Nodes * Coms); F[NID] is a hash keyed by community ID, and a missing key reads as 0.0 through GetCom
00011   TRnd Rnd; // random number generator, constructed from RndSeed
00012   TIntV NIDV; // original node ID vector
00013   TFlt RegCoef; // regularization coefficient when we fit for P_c (+: L1, -: L2); the constructor starts it at 0
00014   TFltV SumFV; // sum_u F_uc for each community c, needed for efficient calculation; AddCom and DelCom keep it in step with F
00015   TBool NodesOk; // whether node IDs are 0 ~ N-1; the constructor starts it as true
00016   TInt NumComs; // number of communities
00017 public:
00018   TVec<TIntSet> HOVIDSV; // NID pairs to hold out for cross validation
00019   TFlt MinVal; // minimum value of F (0)
00020   TFlt MaxVal; // maximum value of F (for numerical reason); the constructor uses 1000.0
00021   TFlt NegWgt; // weight of negative example (a pair of nodes without an edge); the constructor uses 1.0
00022   TFlt PNoCom; // base probability \varepsilon (edge probability between a pair of nodes sharing no community); NOTE(review): absent from the constructor's initializer list -- presumably assigned in SetGraph, confirm
00023   TBool DoParallel; // whether to use parallelism for computation; NOTE(review): also absent from the constructor's initializer list -- confirm where it is set
00024 
00025   TAGMFast(const PUNGraph& GraphPt, const int& InitComs, const int RndSeed = 0): Rnd(RndSeed), RegCoef(0), 
00026     NodesOk(true), MinVal(0.0), MaxVal(1000.0), NegWgt(1.0) { SetGraph(GraphPt); RandomInit(InitComs); } // seeds Rnd, sets the default values, then calls SetGraph followed by RandomInit
00027   void SetGraph(const PUNGraph& GraphPt);
00028   void SetRegCoef(const double _RegCoef) { RegCoef = _RegCoef; } // stores _RegCoef in RegCoef
00029   double GetRegCoef() { return RegCoef; } // returns the stored RegCoef
00030   void RandomInit(const int InitComs);
00031   void NeighborComInit(const int InitComs);
00032   void SetCmtyVV(const TVec<TIntV>& CmtyVV);
00033   double Likelihood(const bool DoParallel = false); // NOTE: the DoParallel parameter hides the data member of the same name inside this function
00034   double LikelihoodForRow(const int UID);
00035   double LikelihoodForRow(const int UID, const TIntFltH& FU);
00036   int MLENewton(const double& Thres, const int& MaxIter, const TStr PlotNm = TStr());
00037   void GradientForRow(const int UID, TIntFltH& GradU, const TIntSet& CIDSet);
00038   double GradientForOneVar(const TFltV& AlphaKV, const int UID, const int CID, const double& Val);
00039   double HessianForOneVar(const TFltV& AlphaKV, const int UID, const int CID, const double& Val);
00040   double LikelihoodForOneVar(const TFltV& AlphaKV, const int UID, const int CID, const double& Val);
00041   void GetCmtyVV(TVec<TIntV>& CmtyVV);
00042   void GetCmtyVV(TVec<TIntV>& CmtyVV, const double Thres, const int MinSz = 3);
00043   int FindComsByCV(TIntV& ComsV, const double HOFrac = 0.2, const int NumThreads = 20, const TStr PlotLFNm = TStr(), const double StepAlpha = 0.3, const double StepBeta = 0.1); // NOTE(review): StepBeta defaults to 0.1 here but to 0.3 in the overload on the next line -- confirm this is intended
00044   int FindComsByCV(const int NumThreads, const int MaxComs, const int MinComs, const int DivComs, const TStr OutFNm, const double StepAlpha = 0.3, const double StepBeta = 0.3); // NOTE(review): StepBeta default (0.3) differs from the overload above (0.1)
00045   double LikelihoodHoldOut(const bool DoParallel = false); // NOTE: the DoParallel parameter hides the data member of the same name inside this function
00046   double GetStepSizeByLineSearch(const int UID, const TIntFltH& DeltaV, const TIntFltH& GradV, const double& Alpha, const double& Beta, const int MaxIter = 10);
00047   int MLEGradAscent(const double& Thres, const int& MaxIter, const TStr PlotNm, const double StepAlpha = 0.3, const double StepBeta = 0.1);
00048   int MLEGradAscentParallel(const double& Thres, const int& MaxIter, const int ChunkNum, const int ChunkSize, const TStr PlotNm, const double StepAlpha = 0.3, const double StepBeta = 0.1); // form with an explicit ChunkSize; PlotNm has no default here
00049   int MLEGradAscentParallel(const double& Thres, const int& MaxIter, const int ChunkNum, const TStr PlotNm = TStr(), const double StepAlpha = 0.3, const double StepBeta = 0.1) { // convenience overload: computes ChunkSize, then forwards to the overload above
00050     int ChunkSize = G->GetNodes() / 10 / ChunkNum; // integer division; NOTE(review): ChunkNum == 0 divides by zero here
00051     if (ChunkSize == 0) { ChunkSize = 1; } // never forward a chunk size of 0
00052     return MLEGradAscentParallel(Thres, MaxIter, ChunkNum, ChunkSize, PlotNm, StepAlpha, StepBeta);
00053   }
00054   //double FindOptimalThres(const TVec<TIntV>& TrueCmtyVV, TVec<TIntV>& CmtyVV);
00055   void Save(TSOut& SOut);
00056   void Load(TSIn& SIn, const int& RndSeed = 0);
00057   double inline GetCom(const int& NID, const int& CID) { // value stored for community CID on node NID, or 0.0 if there is none
00058     if (F[NID].IsKey(CID)) {
00059       return F[NID].GetDat(CID);
00060     } else {
00061       return 0.0;
00062     }
00063   }
00064   void inline AddCom(const int& NID, const int& CID, const double& Val) { // sets the (NID, CID) entry of F to Val and keeps SumFV[CID] consistent
00065     if (F[NID].IsKey(CID)) {
00066       SumFV[CID] -= F[NID].GetDat(CID); // take the value being overwritten out of the running sum
00067     }
00068     F[NID].AddDat(CID) = Val;
00069     SumFV[CID] += Val;
00070   }
00071 
00072   void inline DelCom(const int& NID, const int& CID) { // removes CID from node NID if present and subtracts its value from SumFV[CID]; otherwise does nothing
00073     if (F[NID].IsKey(CID)) {
00074       SumFV[CID] -= F[NID].GetDat(CID);
00075       F[NID].DelKey(CID);
00076     }
00077   }
00078   double inline DotProduct(const TIntFltH& UV, const TIntFltH& VV) { // sum of UV[k] * VV[k] over the keys k present in both hashes
00079     double DP = 0;
00080     if (UV.Len() > VV.Len()) { // scans the hash with more entries and looks each key up in the other (equal sizes scan VV)
00081       for (TIntFltH::TIter HI = UV.BegI(); HI < UV.EndI(); HI++) {
00082         if (VV.IsKey(HI.GetKey())) { 
00083           DP += VV.GetDat(HI.GetKey()) * HI.GetDat(); 
00084         }
00085       }
00086     } else {
00087       for (TIntFltH::TIter HI = VV.BegI(); HI < VV.EndI(); HI++) {
00088         if (UV.IsKey(HI.GetKey())) { 
00089           DP += UV.GetDat(HI.GetKey()) * HI.GetDat(); 
00090         }
00091       }
00092     }
00093     return DP;
00094   }
00095   double inline DotProduct(const int& UID, const int& VID) { // same, applied to the membership rows F[UID] and F[VID]
00096     return DotProduct(F[UID], F[VID]);
00097   }
00098   double inline Prediction(const TIntFltH& FU, const TIntFltH& FV) { // returns exp(-(log(1 / (1 - PNoCom)) + DotProduct(FU, FV)))
00099     double DP = log (1.0 / (1.0 - PNoCom)) + DotProduct(FU, FV);
00100     IAssertR(DP > 0.0, TStr::Fmt("DP: %f", DP)); // asserts DP > 0.0; the formatted DP is passed as the message argument
00101     return exp(- DP);
00102   }
00103   double inline Prediction(const int& UID, const int& VID) { // same, applied to the membership rows F[UID] and F[VID]
00104     return Prediction(F[UID], F[VID]);
00105   }
00106   double inline Sum(const TIntFltH& UV) { // sum of all values stored in UV
00107     double N = 0.0;
00108     for (TIntFltH::TIter HI = UV.BegI(); HI < UV.EndI(); HI++) {
00109       N += HI.GetDat();
00110     }
00111     return N;
00112   }
00113   double inline Norm2(const TIntFltH& UV) { // sum of the squared values stored in UV (no square root is taken)
00114     double N = 0.0;
00115     for (TIntFltH::TIter HI = UV.BegI(); HI < UV.EndI(); HI++) {
00116       N += HI.GetDat() * HI.GetDat();
00117     }
00118     return N;
00119   }
00120 };
00121 
00122 #endif