From f0653da318cac9736777495e40de6853227904ec Mon Sep 17 00:00:00 2001
From: Pjotr Prins
Date: Wed, 18 Mar 2015 13:21:12 +0300
Subject: Cleaned up gwas.py to use the 'uses' dispatcher and moved the Redis call back into lmm.py

---
 wqflask/wqflask/my_pylmm/pyLMM/gwas.py       | 70 +++++++++++-----------------
 wqflask/wqflask/my_pylmm/pyLMM/lmm.py        | 10 ++--
 wqflask/wqflask/my_pylmm/pyLMM/standalone.py | 31 +++++++-----
 3 files changed, 52 insertions(+), 59 deletions(-)

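For orientation: the 'uses' dispatcher that gwas() now takes as an argument is a handler lookup supplied by the caller (lmm.py). It maps names such as 'progress', 'debug', 'info' and 'mprint' to the environment's own output functions, so gwas.py no longer writes to sys.stderr itself. The real implementations live in gn2.py and standalone.py and are not part of this diff; the following is only a minimal, hypothetical stand-in that matches the call signature used in the hunks below.

    # Hypothetical stand-in for the 'uses' dispatcher assumed by this patch;
    # the real versions live in gn2.py and standalone.py and are not shown here.
    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger('lmm2')

    def mprint(msg, data):
        """Label and dump a matrix-like object."""
        logger.debug("%s: %s", msg, data)

    def progress(location, count, total):
        logger.info("Progress: %s %d%%", location, round(count * 100.0 / total))

    def uses(*names):
        """Return the requested handler functions, in the order asked for."""
        handlers = {
            'progress': progress,
            'debug': logger.debug,
            'info': logger.info,
            'mprint': mprint,
        }
        return [handlers[name] for name in names]

    # gwas() then unpacks them exactly as in the hunk below:
    # progress,debug,info,mprint = uses('progress','debug','info','mprint')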

diff --git a/wqflask/wqflask/my_pylmm/pyLMM/gwas.py b/wqflask/wqflask/my_pylmm/pyLMM/gwas.py
index b901c0e2..8b344a90 100644
--- a/wqflask/wqflask/my_pylmm/pyLMM/gwas.py
+++ b/wqflask/wqflask/my_pylmm/pyLMM/gwas.py
@@ -19,7 +19,6 @@
 
 import pdb
 import time
-import sys
 # from utility import temp_data
 import lmm2
 
@@ -36,12 +35,10 @@ def formatResult(id,beta,betaSD,ts,ps):
    return "\t".join([str(x) for x in [id,beta,betaSD,ts,ps]]) + "\n"
 
 def compute_snp(j,n,snp_ids,lmm2,REML,q = None):
-   # print("COMPUTE SNP",j,snp_ids,"\n")
    result = []
    for snp_id in snp_ids:
       snp,id = snp_id
       x = snp.reshape((n,1))  # all the SNPs
-      # print "X=",x
       # if refit:
       #    L.fit(X=snp,REML=REML)
       ts,ps,beta,betaVar = lmm2.association(x,REML=REML,returnBeta=True)
@@ -51,32 +48,28 @@ def compute_snp(j,n,snp_ids,lmm2,REML,q = None):
       q = compute_snp.q
    q.put([j,result])
    return j
-      # PS.append(ps)
-      # TS.append(ts)
-      # return len(result)
-      # compute.q.put(result)
-      # return None
 
 def f_init(q):
    compute_snp.q = q
 
-def gwas(Y,G,K,restricted_max_likelihood=True,refit=False,verbose=True):
+def gwas(Y,G,K,uses,restricted_max_likelihood=True,refit=False,verbose=True):
    """
-   Execute a GWAS. The G matrix should be n inds (cols) x m snps (rows)
+   GWAS. The G matrix should be n inds (cols) x m snps (rows)
    """
+   progress,debug,info,mprint = uses('progress','debug','info','mprint')
+
    matrix_initialize()
    cpu_num = mp.cpu_count()
    numThreads = None # for now use all available threads
    kfile2 = False
    reml = restricted_max_likelihood
 
-   sys.stderr.write(str(G.shape)+"\n")
+   mprint("G",G)
    n = G.shape[1] # inds
    inds = n
    m = G.shape[0] # snps
    snps = m
-   sys.stderr.write(str(m)+" SNPs\n")
-   # print "***** GWAS: G",G.shape,G
+   info("%s SNPs",snps)
    assert snps>inds, "snps should be larger than inds (snps=%d,inds=%d)" % (snps,inds)
 
    # CREATE LMM object for association
@@ -85,19 +78,10 @@ def gwas(Y,G,K,restricted_max_likelihood=True,refit=False,verbose=True):
 
    lmm2 = LMM2(Y,K) # ,Kva,Kve,X0,verbose=verbose)
    if not refit:
-      if verbose: sys.stderr.write("Computing fit for null model\n")
+      info("Computing fit for null model")
       lmm2.fit()  # follow GN model in run_other
-      if verbose: sys.stderr.write("\t heritability=%0.3f, sigma=%0.3f\n" % (lmm2.optH,lmm2.optSigma))
-      
-   # outFile = "test.out"
-   # out = open(outFile,'w')
-   out = sys.stderr
-
-   def outputResult(id,beta,betaSD,ts,ps):
-      out.write(formatResult(id,beta,betaSD,ts,ps))
-   def printOutHead(): out.write("\t".join(["SNP_ID","BETA","BETA_SD","F_STAT","P_VALUE"]) + "\n")
-
-   # printOutHead()
+      info("heritability=%0.3f, sigma=%0.3f" % (lmm2.optH,lmm2.optSigma))
+
    res = []
 
    # Set up the pool
@@ -106,26 +90,24 @@ def gwas(Y,G,K,restricted_max_likelihood=True,refit=False,verbose=True):
    p = mp.Pool(numThreads, f_init, [q])
    collect = []
 
-   # Buffers for pvalues and t-stats
-   # PS = []
-   # TS = []
    count = 0
    job = 0
    jobs_running = 0
+   jobs_completed = 0
    for snp in G:
       snp_id = (snp,'SNPID')
       count += 1
       if count % 1000 == 0:
          job += 1
-         if verbose:
-            sys.stderr.write("Job %d At SNP %d\n" % (job,count))
+         debug("Job %d At SNP %d" % (job,count))
          if numThreads == 1:
-            print "Running on 1 THREAD"
+            debug("Running on 1 THREAD")
             compute_snp(job,n,collect,lmm2,reml,q)
             collect = []
             j,lst = q.get()
-            if verbose:
-               sys.stderr.write("Job "+str(j)+" finished\n")
+            debug("Job "+str(j)+" finished")
+            jobs_completed += 1
+            progress("GWAS2",jobs_completed,snps/1000)
             res.append((j,lst))
          else:
             p.apply_async(compute_snp,(job,n,collect,lmm2,reml))
@@ -134,8 +116,9 @@ def gwas(Y,G,K,restricted_max_likelihood=True,refit=False,verbose=True):
             while jobs_running > cpu_num:
                try:
                   j,lst = q.get_nowait()
-                  if verbose:
-                     sys.stderr.write("Job "+str(j)+" finished\n")
+                  debug("Job "+str(j)+" finished")
+                  jobs_completed += 1
+                  progress("GWAS2",jobs_completed,snps/1000)
                   res.append((j,lst))
                   jobs_running -= 1
                except Queue.Empty:
@@ -150,24 +133,23 @@ def gwas(Y,G,K,restricted_max_likelihood=True,refit=False,verbose=True):
 
    if numThreads==1 or count<1000 or len(collect)>0:
       job += 1
-      print "Collect final batch size %i job %i @%i: " % (len(collect), job, count)
+      debug("Collect final batch size %i job %i @%i: " % (len(collect), job, count))
       compute_snp(job,n,collect,lmm2,reml,q)
       collect = []
       j,lst = q.get()
       res.append((j,lst))
-   print "count=",count," running=",jobs_running," collect=",len(collect)
+   debug("count=%i running=%i collect=%i" % (count,jobs_running,len(collect)))
    for job in range(jobs_running):
       j,lst = q.get(True,15) # time out
-      if verbose:
-         sys.stderr.write("Job "+str(j)+" finished\n")
+      debug("Job "+str(j)+" finished")
+      jobs_completed += 1
+      progress("GWAS2",jobs_completed,snps/1000)
       res.append((j,lst))
 
-   print "Before sort",[res1[0] for res1 in res]
+   mprint("Before sort",[res1[0] for res1 in res])
    res = sorted(res,key=lambda x: x[0])
-   # if verbose:
-   #    print "res=",res[0][0:10]
-   print "After sort",[res1[0] for res1 in res]
-   print [len(res1[1]) for res1 in res]
+   mprint("After sort",[res1[0] for res1 in res])
+   info([len(res1[1]) for res1 in res])
    ts = [item[0] for j,res1 in res for item in res1]
    ps = [item[1] for j,res1 in res for item in res1]
    return ts,ps
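The rewritten gwas() keeps the same dispatch pattern as before: SNPs are grouped into jobs of 1000, each job is handed to a multiprocessing pool worker via compute_snp(), and results come back on a shared queue tagged with the job number so they can be re-sorted before the t-stats and p-values are flattened. A minimal, self-contained sketch of that pattern (not the pyLMM code itself, with a placeholder computation in place of lmm2.association) looks like:

    import multiprocessing as mp

    def compute_batch(job, batch, q):
        # Placeholder for compute_snp(): run the per-SNP association here.
        result = [x * 2 for x in batch]
        q.put((job, result))

    def run(snps, batch_size=1000):
        manager = mp.Manager()
        q = manager.Queue()          # a Manager queue can be passed to pool workers
        pool = mp.Pool(mp.cpu_count())
        job, batch, res = 0, [], []
        for snp in snps:
            batch.append(snp)
            if len(batch) == batch_size:
                pool.apply_async(compute_batch, (job, batch, q))
                job, batch = job + 1, []
        if batch:                    # final, partial batch
            pool.apply_async(compute_batch, (job, batch, q))
            job += 1
        pool.close()
        pool.join()
        while not q.empty():
            res.append(q.get())
        res.sort(key=lambda item: item[0])   # restore SNP order by job id
        return [x for _, lst in res for x in lst]

    if __name__ == '__main__':
        print(len(run(range(2500))))

The real gwas() additionally caps the number of outstanding jobs at cpu_count() and calls progress() as each job completes, as the hunks above show.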
diff --git a/wqflask/wqflask/my_pylmm/pyLMM/lmm.py b/wqflask/wqflask/my_pylmm/pyLMM/lmm.py
index eab7d91d..1e00002a 100644
--- a/wqflask/wqflask/my_pylmm/pyLMM/lmm.py
+++ b/wqflask/wqflask/my_pylmm/pyLMM/lmm.py
@@ -57,11 +57,11 @@ import gwas
 # ---- A trick to decide on the environment:
 try:
     from wqflask.my_pylmm.pyLMM import chunks
-    from gn2 import uses, set_progress_storage
+    from gn2 import uses, progress_set_func
 except ImportError:
     has_gn2=False
     import standalone as handlers
-    from standalone import uses, set_progress_storage
+    from standalone import uses, progress_set_func
     sys.stderr.write("WARNING: LMM standalone version missing the Genenetwork2 environment\n")
     pass
 
@@ -348,6 +348,7 @@ def run_other_new(pheno_vector,
         t_stats, p_values = gwas.gwas(Y,
                                       G.T,
                                       K,
+                                      uses,
                                       restricted_max_likelihood=True,
                                       refit=False,verbose=True)
     Bench().report()
@@ -812,7 +813,10 @@ def gn2_redis(key,species,new_code=True):
     params = json.loads(json_params)
     
     tempdata = temp_data.TempData(params['temp_uuid'])
-    set_progress_storage(tempdata)
+    def update_tempdata(loc,i,total):
+        tempdata.store("percent_complete",round(i*100.0/total))
+        debug("Updating REDIS percent_complete=%d" % (round(i*100.0/total)))
+    progress_set_func(update_tempdata)
 
     print('kinship', np.array(params['kinship_matrix']))
     print('pheno', np.array(params['pheno_vector']))
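The effect of the lmm.py change is that the progress plumbing no longer needs to know about TempData or Redis: gn2_redis() registers a callback via progress_set_func(), and the generic progress() throttle calls it with (location, count, total). A hedged sketch of that wiring, with a stand-in for temp_data.TempData (the real class persists to Redis), is:

    class FakeTempData(object):
        """Stand-in for temp_data.TempData; the real class writes to Redis."""
        def __init__(self):
            self.values = {}
        def store(self, key, value):
            self.values[key] = value

    tempdata = FakeTempData()

    def update_tempdata(loc, i, total):
        tempdata.store("percent_complete", round(i * 100.0 / total))

    # progress_set_func(update_tempdata)   # as done in gn2_redis() above
    update_tempdata("GWAS2", 7, 20)
    assert tempdata.values["percent_complete"] == 35

How often this callback fires is decided by progress(), shown for the standalone module in the next diff; gn2.py presumably provides the matching hook for the web environment.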
diff --git a/wqflask/wqflask/my_pylmm/pyLMM/standalone.py b/wqflask/wqflask/my_pylmm/pyLMM/standalone.py
index 7cc3e871..36bf8fd5 100644
--- a/wqflask/wqflask/my_pylmm/pyLMM/standalone.py
+++ b/wqflask/wqflask/my_pylmm/pyLMM/standalone.py
@@ -17,24 +17,31 @@ logger = logging.getLogger('lmm2')
 logging.basicConfig(level=logging.DEBUG)
 np.set_printoptions(precision=3,suppress=True)
 
-last_location = None
-last_progress = 0
+progress_location = None
+progress_current = None
+progress_prev_perc = None
 
-def set_progress_storage(location):
-    global storage
-    storage = location
+def progress_default_func(location,count,total):
+    global progress_current
+    value = round(count*100.0/total)
+    progress_current = value
+    
+progress_func = progress_default_func
+
+def progress_set_func(func):
+    global progress_func
+    progress_func = func
     
 def progress(location, count, total):
-    global last_location
-    global last_progress
+    global progress_location
+    global progress_prev_perc
     
     perc = round(count*100.0/total)
-    # print(last_progress,";",perc)
-    if perc != last_progress and (location != last_location or perc > 98 or perc > last_progress + 5):
-        storage.store("percent_complete",perc)
+    if perc != progress_prev_perc and (location != progress_location or perc > 98 or perc > progress_prev_perc + 5):
+        progress_func(location, count, total)
         logger.info("Progress: %s %d%%" % (location,perc))
-        last_location = location
-        last_progress = perc
+        progress_location = location
+        progress_prev_perc = perc
 
 def mprint(msg,data):
     """