LuluShi Blog
open main menu
Part of series:StatGenetic

SBayesRC

/ 30 min read

(1) SBayesRC的直观解释

可以把它理解成三层结构。

  1. 是把个体层面的多 SNP 线性模型,改写成只依赖 GWAS summary statistics 和 LD 的模型。

  2. 是把这个 summary-data 模型做成低秩形式,让它能在全基因组几百万个 SNP 上跑得动,而且对 GWAS 与 LD 参考面板之间的 LD 不匹配更稳健。

  3. 是在 SNP 效应的先验里引入功能注释,让注释既能影响某个 SNP 是不是非零效应,也能影响它更可能落在哪个效应大小区间里。最后再用 Gibbs 采样,把 SNP 效应、混合成分归属、注释效应、残差方差和遗传方差一起学出来。

SBayesRC整体流程图

SBayesRC 的出发点,仍然是标准的加性线性模型。设有 N 个个体、m 个 SNP。把表型和每个 SNP 基因型都做中心化,且把基因型标准化到均值为 0、方差为 1。则个体层面的模型写成

y=Xβ+e\mathbf y = \mathbf X \boldsymbol{\beta} + \mathbf e eN(0,Iσe2)\mathbf e \sim N\left(\mathbf 0,\mathbf I \sigma_e^2\right)

这里的含义非常直接。矩阵 XX 的每一列是一个 SNP 的标准化基因型,向量 β\beta 是每个 SNP 的联合效应,也就是在把所有 SNP 同时放进模型后对应的真实效应,ee 是未被 SNP 解释的剩余部分。

  1. 从个体层模型推到 summary-data 模型

先定义 GWAS 中每个 SNP 的边际效应估计。对于标准化基因型,边际效应向量可以写成

b=1NXy\mathbf b = \frac{1}{N}\mathbf X^\top \mathbf y

把个体模型代进去,就得到

b=1NXXβ+1NXe\mathbf b = \frac{1}{N}\mathbf X^\top \mathbf X \boldsymbol{\beta} + \frac{1}{N}\mathbf X^\top \mathbf e

再定义 LD 相关矩阵

R=1NXX\mathbf R = \frac{1}{N}\mathbf X^\top \mathbf X

以及噪声项

ε=1NXe\boldsymbol{\varepsilon} = \frac{1}{N}\mathbf X^\top \mathbf e

于是就得到 SBayesRC 的 summary-data 基本方程

b=Rβ+ε\boxed {\mathbf b = \mathbf R \boldsymbol{\beta} + \boldsymbol{\varepsilon}}

并且

Var(ε)=σe2NR\mathrm{Var}\left(\boldsymbol{\varepsilon}\right) = \frac{\sigma_e^2}{N}\mathbf R

这个式子是整个方法的根。它在说,GWAS 的边际效应不是直接等于真实效应,而是等于 LD 矩阵把真实联合效应“卷”了一遍,再加上一个相关噪声。也正因为这里噪声的协方差正比于 R,所以如果直接在这个模型上做全基因组 MCMC,计算会很重,而且对 LD 误差会敏感。

如果 GWAS 提供的是 0/1/2 编码基因型上的边际效应估计,也就是常见的 bb^* 和标准误 sese,那么在 SBayesRC 里会先把它们转换到标准化基因型尺度:

bj=sjbj,sj=σy2Njsej2+(bj)2b_j = s_j b_j^{\ast}, \quad s_j = \sqrt{\frac{\sigma_y^2}{N_j \mathrm{se}_j^2 + \left(b_j^{\ast}\right)^2}}

如果先把表型方差标准化为 1,那么就变成

sj=(Njsej2+(bj)2)1/2s_j = \left(N_j \mathrm{se}_j^2 + \left(b_j^{\ast}\right)^2\right)^{-1/2}

这一步的意义是,把不同 SNP 的 GWAS 边际效应都放到同一个、和理论推导匹配的尺度上。最后在输出联合效应时,再按同样的比例缩放回原始表型尺度。

  1. 为什么要做低秩变换

如果直接用

b=Rβ+ε\boxed {\mathbf b = \mathbf R \boldsymbol{\beta} + \boldsymbol{\varepsilon}}

最大的麻烦是,ϵ\epsilon 的协方差不是对角阵,而是和 RR 成比例。也就是说,误差项彼此相关。对几百万 SNP 来说,这既不稳,也不快。

SBayesRC 的关键工程突破,就是把每个 LD block 的 RR 做特征分解。对于某个 block,

R=UΛU\mathbf R = \mathbf U \boldsymbol{\Lambda} \mathbf U^\top

这里 UU 是特征向量矩阵,Λ\Lambda 是特征值对角矩阵。然后对 summary-data 模型左乘一个白化变换

Λ1/2U\boldsymbol{\Lambda}^{-1/2}\mathbf U^\top

于是得到新的模型

w=Qβ+ϵ\mathbf w = \mathbf Q \boldsymbol{\beta} + \boldsymbol{\epsilon}

其中

w=Λ1/2Ub,Q=Λ1/2U\mathbf w = \boldsymbol{\Lambda}^{-1/2}\mathbf U^\top \mathbf b, \quad \mathbf Q = \boldsymbol{\Lambda}^{1/2}\mathbf U^\top

从方差传播就能看出,噪声项从之前的 Var(ε)=σe2NR\mathrm{Var}\left(\boldsymbol{\varepsilon}\right) = \frac{\sigma_e^2}{N}\mathbf R 变成了独立形式

Var(ϵ)=σϵ2NI\mathrm{Var}\left(\boldsymbol{\epsilon}\right) = \frac{\sigma_\epsilon^2}{N}\mathbf I

主文有时把上式里的 NN 吸收到 σϵ\sigma_ \epsilon 的定义里,因此也会写成单位阵乘以 σϵ\sigma_ \epsilon 方差。两种写法只是记号归一化不同,算法结构完全一样。

这个变换最漂亮的地方在于,它只压缩了观测空间,没有压缩 SNP 参数空间。原来在一个 block 内要拟合 m 维的 bb,现在只需要拟合 q 维的 ww,其中 q 远小于 m,但要估计的仍然是 m 个 SNP 联合效应 β\beta

低秩变换与参数空间保持不变

GCTB 里真正做的时候,并不会保留全部特征向量,而是只取累计解释至少 ρ\rho 比例 LD 方差的前 q 个主成分。论文默认用了

ρ=99.5%\rho = 99.5\%

这就使得

qmq \ll m

同时又不会丢掉太多 LD 信息。

  1. 低秩为什么不会破坏功能注释学习

这部分很多人第一次看会迷糊,因为会本能地觉得,既然已经把 block 内的数据压成 q 维了,那功能注释是不是也被压坏了。

答案是否定的。原因很简单,注释不是加在 ww 上,而是加在 β\beta 的先验上。低秩变换之后,模型仍然是在估计每个 SNP 的 βj\beta_j。只不过似然项从原来的

bβ\mathbf b \mid \boldsymbol{\beta}

变成了

wβ\mathbf w \mid \boldsymbol{\beta}

也就是说,观测层换了,参数层没换。于是功能注释仍然能一对一对应到每个 SNP 的 βj\beta_j 上。

补充材料专门强调了一个对照。如果进一步把参数空间也压缩成主成分效应,比如定义

β=ΛUβ\boldsymbol{\beta}^{\ast} = \boldsymbol{\Lambda}\mathbf U^\top \boldsymbol{\beta}

然后去拟合

b=Uβ+ε\mathbf b = \mathbf U \boldsymbol{\beta}^{\ast} + \boldsymbol{\varepsilon}

那就不行了。因为这时 β\beta^ \star 的每个元素都不再对应某个具体 SNP,而是许多 SNP 联合效应的线性组合。此时功能注释就失去了明确对象,注释学习也失去了解释性。

所以,SBayesRC 低秩化之所以成立,不只是因为它更快,更因为它保住了 SNP 级别参数这一层。


(2) SBayesRC的混合先验

在 SBayesR 里,每个 SNP 的效应都来自同一个全局混合分布。而在 SBayesRC 里,每个 SNP 的混合权重变成了 SNP 特异的,会随功能注释而变。

核心先验

βjk=15πjkN(0,γkσg2)\beta_j \sim \sum_{k=1}^{5}\pi_{jk} N\left(0,\gamma_k \sigma_g^2\right)

这里有五个混合成分。第一个成分是零效应,后四个成分是不同方差大小的正态分布。论文默认的缩放因子是

γ=[0,0.001,0.01,0.1,1]%\gamma = \left[0,0.001,0.01,0.1,1\right]^\top \%

也就是等价于

γ=[0,105,104,103,102]\gamma = \left[0,10^{-5},10^{-4},10^{-3},10^{-2}\right]^\top

这样设计的直觉非常强。它不是说效应大小连续地无边无际,而是先粗分成几个桶。某个 SNP 要么是零效应,要么是很小效应,要么是中等效应,要么是更大效应。每个桶的方差尺度按总遗传方差 σg2\sigma_g^2 的固定比例来定。

这里的重点不在 γ\gamma 本身,而在每个 SNP 的桶概率 πjk\pi_{jk} 不一样。它由功能注释来决定。

(3) SBayesRC的注释

  1. 注释如何进入混合先验

AA 表示注释矩阵,维度是 m 乘以 c,也就是 m 个 SNP、c 个注释。对于 SNP jj 和混合成分 kk,文章给出的模型是

f(πjk)=μk+l=1cAjlαklf\left(\pi_{jk}\right) = \mu_k + \sum_{l=1}^{c} A_{jl}\alpha_{kl}

这里的含义可以逐项理解。

μk\mu_k 是第 kk 个混合成分在全基因组里的基线倾向。

AjlA_{jl} 是 SNP jj 在第 ll 个注释上的取值。二元注释就取 0 或 1,连续注释则先标准化到均值 0、方差 1。

αkl\alpha_{kl} 则是“第 ll 个注释把 SNP 推向第 kk 个效应桶的强度”。

所以一个很重要的认识是,SBayesRC 不是让功能注释直接改写 βj\beta_j 的数值,而是先改写 βj\beta_j 属于哪类分布的概率,再通过这个分布去约束 βj\beta_j

  1. 为什么不能直接对 πjk\pi_{jk} 采样

直接对每个 SNP 的

πj1,πj2,πj3,πj4,πj5\pi_{j1},\pi_{j2},\pi_{j3},\pi_{j4},\pi_{j5}

采样会有一个明显问题,就是它们必须满足

k=15πjk=1\sum_{k=1}^{5}\pi_{jk}=1

也就是说,同一个 SNP 的五个概率不是独立变量。这样一来,MCMC 更新就会很别扭。你每更新一个,另外几个都要跟着联动,Gibbs 采样不方便,Metropolis 调参也会麻烦。

所以 SBayesRC 做了一个很巧妙的重参数化。先定义混合成分指示变量

δj=kwith probability πjk\delta_j = k \quad \text{with probability } \pi_{jk}

然后不直接建模 πjk\pi_{jk},而是建模一个“逐级爬梯子”的条件概率

pjk=Pr(δjkδjk1),k2p_{jk} = \Pr\left(\delta_j \ge k \mid \delta_j \ge k-1\right),\qquad k \ge 2

它的意思是,SNP 已经跨过前一个门槛以后,还有多大概率继续往更大效应的桶里爬。

于是五个混合权重可以写成

πj1=1pj2\pi_{j1} = 1 - p_{j2} πj2=(1pj3)pj2\pi_{j2} = \left(1-p_{j3}\right)p_{j2} πj3=(1pj4)pj3pj2\pi_{j3} = \left(1-p_{j4}\right)p_{j3}p_{j2} πj4=(1pj5)pj4pj3pj2\pi_{j4} = \left(1-p_{j5}\right)p_{j4}p_{j3}p_{j2} πj5=pj5pj4pj3pj2\pi_{j5} = p_{j5}p_{j4}p_{j3}p_{j2}

这样一改以后,真正被注释模型驱动的是彼此独立得多的 pjkp_{jk},而不是受和为 1 约束的 πjk\pi_{jk}

从条件概率 p 到混合权重 pi 的阶梯图

从直觉上看,pj2p_{j2} 决定的是这个 SNP 会不会非零。pj3p_{j3} 决定的是,在已经非零的前提下,它是留在“小效应”还是继续向上。pj4p_{j4}pj5p_{j5} 依次决定它会不会继续进入更大的效应桶。

  1. probit 模型为什么能让注释参数做 Gibbs 采样

为了让 pjkp_{jk} 能用 Gibbs 采样,SBayesRC 选的是 probit 链接。把它写成最标准的形式就是

Φ1(pjk)=μk+c=1CAjcαkc\Phi^{-1}\left(p_{jk}\right)=\mu_k+\sum_{c=1}^{C}A_{jc}\alpha_{kc}

等价地,也可以写成

pjk=Φ(μk+c=1CAjcαkc)p_{jk}=\Phi\left(\mu_k+\sum_{c=1}^{C}A_{jc}\alpha_{kc}\right)

这里 Φ\Phi 是标准正态分布函数。

之所以这一步特别重要,是因为 probit 模型可以用 Albert-Chib 的潜变量技巧,立刻把二项变量更新问题改写成正态线性模型更新问题。

先引入一个伯努利指示变量

zjkBernoulli(pjk)z_{jk} \sim \mathrm{Bernoulli}\left(p_{jk}\right)

再引入潜变量

ljk=μk+c=1CAjcαkc+ηjkl_{jk} = \mu_k + \sum_{c=1}^{C}A_{jc}\alpha_{kc} + \eta_{jk} ηjkN(0,1)\eta_{jk}\sim N\left(0,1\right)

并定义

zjk=1    ljk>0z_{jk}=1 \iff l_{jk}>0

这样一来,原来不好直接采样的离散概率模型,就转成了一个带截断正态潜变量的高斯模型。于是

ljkzjk=1,μk,αkTN(μk+Ajαk,1;0,)l_{jk}\mid z_{jk}=1,\mu_k,\boldsymbol{\alpha}_k \sim TN\left(\mu_k+\mathbf A_j^\top \boldsymbol{\alpha}_k,1;0,\infty\right) ljkzjk=0,μk,αkTN(μk+Ajαk,1;,0)l_{jk}\mid z_{jk}=0,\mu_k,\boldsymbol{\alpha}_k \sim TN\left(\mu_k+\mathbf A_j^\top \boldsymbol{\alpha}_k,1;-\infty,0\right)

这个时候,对每个混合层级 kk,所有注释效应 αkc\alpha_{kc} 的更新就都退化成标准的单变量高斯后验。

  1. 注释效应 αkc\alpha_{kc} 的全条件分布

在引入潜变量 ljkl_{jk} 以后,αkc\alpha_{kc} 的更新就和普通 Bayesian 线性回归几乎一样。文中的全条件分布是

αkclk,αk,c,σαk2N(rkcCkc,1Ckc)\alpha_{kc}\mid \mathbf l_k,\alpha_{k,-c},\sigma_{\alpha_k}^2 \sim N\left(\frac{r_{kc}}{C_{kc}},\frac{1}{C_{kc}}\right)

其中

rkc=Ac(lkccAcαkc)r_{kc}=\mathbf A_c^\top\left(\mathbf l_k-\sum_{c'\ne c}\mathbf A_{c'}\alpha_{kc'}\right) Ckc=AcAc+1σαk2C_{kc}=\mathbf A_c^\top\mathbf A_c+\frac{1}{\sigma_{\alpha_k}^2}

它的含义很直观。rkcr_{kc} 是“当前潜变量残差”和注释列 AcA_c 的相关程度,谁更能解释当前这层的潜变量,谁的 αkc\alpha_{kc} 就会更偏离 0。CkcC_{kc} 则相当于后验精度,等于数据精度加先验精度。

SBayesRC 对注释效应使用的是正态先验

αklN(0,σαk2)\alpha_{kl}\sim N\left(0,\sigma_{\alpha_k}^2\right)

再对这个方差放一个逆卡方先验

σαk2χ2(να,τα2)\sigma_{\alpha_k}^2 \sim \chi^{-2}\left(\nu_\alpha,\tau_\alpha^2\right)

论文设置的是

να=4,τα2=1\nu_\alpha=4,\qquad \tau_\alpha^2=1

于是其全条件分布仍然是逆卡方

σαk2αkχ2(ν~α,τ~α2)\sigma_{\alpha_k}^2 \mid \boldsymbol{\alpha}_k \sim \chi^{-2}\left(\tilde{\nu}_\alpha,\tilde{\tau}_\alpha^2\right)

其中

ν~α=C+να\tilde{\nu}_\alpha = C+\nu_\alpha τ~α2=αkαk+νατα2ν~α\tilde{\tau}_\alpha^2 = \frac{\boldsymbol{\alpha}_k^\top \boldsymbol{\alpha}_k + \nu_\alpha \tau_\alpha^2}{\tilde{\nu}_\alpha}

所以在这一层,Gibbs 采样几乎是纯闭式的。


(4) MCMC过程参数的更新

  1. SNP 效应 βj\beta_j 的更新

回到低秩观测模型

w=Qβ+ϵ\mathbf w=\mathbf Q\boldsymbol{\beta}+\boldsymbol{\epsilon}

假设当前 SNP jj 落在第 kk 个混合成分,也就是

δj=k\delta_j = k

那么 βj\beta_j 的先验就是

βjN(0,γkσg2)\beta_j \sim N\left(0,\gamma_k \sigma_g^2\right)

把这个先验和低秩似然合在一起,就能得到 βj\beta_j 的全条件分布是单变量正态。论文把它写成“均值是一个局部 BLUP,方差由当前噪声方差和当前混合成分共同决定”的形式。常见写法是

βjw,βj,δj=k,σg2,σϵ2N(rjCj,σϵ2Cj)\beta_j\mid \mathbf w,\boldsymbol{\beta}_{-j},\delta_j=k,\sigma_g^2,\sigma_\epsilon^2 \sim N\left(\frac{r_j}{C_j},\frac{\sigma_\epsilon^2}{C_j}\right)

其中

rj=Qj(wjjQjβj)r_j=\mathbf Q_j^\top\left(\mathbf w-\sum_{j'\ne j}\mathbf Q_{j'}\beta_{j'}\right) Cj=1+σϵ2γkσg2C_j=1+\frac{\sigma_\epsilon^2}{\gamma_k \sigma_g^2}

这一步的直觉也很清楚。rjr_j 是在把其他 SNP 当前效应扣掉以后,剩下还能被 SNP jj 解释的那部分信号。 CjC_j 则是在平衡数据证据和先验收缩。如果当前混合成分的方差 γkσg2\gamma_k \sigma_g^2 很小,那么 βj\beta_j 就会被更强地往 0 收缩。

  1. δj\delta_j 这个混合成员指标是怎样更新的

有了当前的 βj\beta_{j} 以及每个成分的先验权重 πjk\pi_{jk} 以后,δj\delta_j 的后验概率就是五个候选分量竞争的结果

Pr(δj=kw,β,σg2,σϵ2)πjkf(wδj=k,β,σg2,σϵ2)\Pr\left(\delta_j=k\mid \mathbf w,\boldsymbol{\beta},\sigma_g^2,\sigma_\epsilon^2\right)\propto \pi_{jk}\,f\left(\mathbf w\mid \delta_j=k,\boldsymbol{\beta},\sigma_g^2,\sigma_\epsilon^2\right)

再对

k=1,2,3,4,5k=1,2,3,4,5

归一化即可。


(5) 遗传方差、遗传力和功能富集的估计

在每次 MCMC 迭代里,只要当前有一组 SNP 效应样本 β\beta,就可以直接算总遗传方差。文章给出的推导是

σg2=βRβ\sigma_g^2 = \boldsymbol{\beta}^\top \mathbf R \boldsymbol{\beta}

又因为

R=QQ\mathbf R=\mathbf Q^\top \mathbf Q

所以也可写成

σg2=βQQβ\sigma_g^2 = \boldsymbol{\beta}^\top \mathbf Q^\top \mathbf Q \boldsymbol{\beta}

如果定义

w^=Qβ\hat{\mathbf w}=\mathbf Q\boldsymbol{\beta}

那么

σg2=w^w^\sigma_g^2=\hat{\mathbf w}^\top \hat{\mathbf w}

在假定表型方差为 1 的标准化情形下,就有

hSNP2=σg2h_{\mathrm{SNP}}^2=\sigma_g^2

也就是说,SBayesRC 的 SNP 遗传力不是额外另开一套模型估出来的,而是直接从当前 β\beta 样本诱导出来的。

对于二元注释 cc,文章把该注释内 SNP 解释的总方差写成

σc2=jcβj2\sigma_c^2=\sum_{j\in c}\beta_j^2

然后定义 per-SNP heritability enrichment 为

θc=σc2/mcσg2/m\theta_c=\frac{\sigma_c^2/m_c}{\sigma_g^2/m}

其中 m_c 是该注释里 SNP 的数量,m 是全基因组总 SNP 数。

对于连续注释,文中采用的是回归斜率定义

E[βj2]=μc+AjcωcE\left[\beta_j^2\right]=\mu_c+A_{jc}\omega_c

于是

θc=1+ωc\theta_c=1+\omega_c

这些量都是在每次 MCMC 迭代里计算一次,最后用后验均值做估计。


(6) MCMC 流程

到这里,其实整套算法已经拼完整了。把它压成一句一句就是下面这个顺序。

SBayesRC的MCMC循环图

(7) SBayesRC 和 SBayesR 的本质差别

SBayesR 也有混合正态先验,但它默认所有 SNP 共用一套全局混合权重,也就是所有 SNP 的“进桶概率”都一样。

SBayesRC 把这个地方改成了 SNP 特异的。不同 SNP 因为注释不同,会拥有不同的 πjk\pi_{jk}。于是它不仅能学习“总体上有多少 SNP 是大效应”,还能学习“哪些注释上的 SNP 更容易是大效应”。

再往工程上说,SBayesRC 还把 SBayesR 的 summary-data 模型进一步升级成了低秩形式,这也是它能稳定分析全基因组高密度 SNP 的关键。

(8) 小结

SBayesRC 的数学本质可以概括成下面这句话。

它把

b=Rβ+ε\boxed {\mathbf b = \mathbf R\boldsymbol{\beta}+\boldsymbol{\varepsilon}}

这个 summary-data 观测模型,先通过特征分解变成更稳健的低秩模型

w=Qβ+ϵ\boxed {\mathbf w=\mathbf Q\boldsymbol{\beta}+\boldsymbol{\epsilon}}

再给每个 SNP 的 βj\beta_j 放一个由功能注释驱动的混合正态先验

βjk=15πjkN(0,γkσg2)\boxed {\beta_j\sim \sum_{k=1}^{5}\pi_{jk}N\left(0,\gamma_k \sigma_g^2\right)}

最终通过 Gibbs 采样,把 β\betaδ\deltaα\alphaσg2\sigma_g^2σϵ2\sigma_{\epsilon}^2 一起估计出来。

(9) SBayesRC 的 python 实现

import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def build_pi_from_annotations(A, alpha_seq):
    """
    根据 annotation 构造每个 SNP 的 5 组 mixture 概率 pi_jk

    A: (m, c)
       m 个 SNP, c 个 annotations

    alpha_seq: (c+1, 4)
       对应论文里的顺序条件概率 p_{j2}, p_{j3}, p_{j4}, p_{j5}
       第一行是截距,后面每行对应一个 annotation 的系数

    返回:
    pi: (m, 5)
        每个 SNP 属于 5 个 mixture component 的概率
    """
    A1 = np.column_stack([np.ones(A.shape[0]), A])  # 加截距
    eta = A1 @ alpha_seq                            # (m, 4)
    p = np.clip(sigmoid(eta), 1e-6, 1 - 1e-6)      # 教学简化: 用 sigmoid

    # 顺序条件概率 -> 最终 5 成分概率
    # component 0: 零效应
    # component 1..4: 非零不同方差层
    pi = np.zeros((A.shape[0], 5))
    pi[:, 0] = 1 - p[:, 0]
    pi[:, 1] = p[:, 0] * (1 - p[:, 1])
    pi[:, 2] = p[:, 0] * p[:, 1] * (1 - p[:, 2])
    pi[:, 3] = p[:, 0] * p[:, 1] * p[:, 2] * (1 - p[:, 3])
    pi[:, 4] = p[:, 0] * p[:, 1] * p[:, 2] * p[:, 3]

    # 数值稳定一下
    pi /= pi.sum(axis=1, keepdims=True)
    return pi


def low_rank_transform(R, b, rho=0.995):
    """
      low-rank 变换:
      R = U diag(lam) U'
      w = lam^{-1/2} U' b
      Q = lam^{1/2} U'

    只保留累计解释 rho 比例方差的主成分
    """
    evals, evecs = np.linalg.eigh(R)
    idx = np.argsort(evals)[::-1]
    evals = np.clip(evals[idx], 1e-12, None)
    evecs = evecs[:, idx]

    cum = np.cumsum(evals) / np.sum(evals)
    q = np.searchsorted(cum, rho) + 1

    U = evecs[:, :q]
    lam = evals[:q]

    Q = (np.sqrt(lam)[:, None] * U.T)   # q x m
    w = (U.T @ b) / np.sqrt(lam)        # q

    return w, Q, q


def invgamma_sample(shape, scale, rng):
    """
    采样 Inv-Gamma(shape, scale)
    """
    return 1.0 / rng.gamma(shape, 1.0 / scale)


def run_toy_sbayesrc(
    b,
    R,
    A,
    n,
    alpha_seq,
    n_iter=800,
    burn_in=300,
    rho=0.995,
    seed=1,
):
    """
    SBayesRC-like Gibbs sampler

    参数
    ----
    b : (m,)
        GWAS marginal effects
    R : (m, m)
        LD correlation matrix
    A : (m, c)
        annotation matrix
    n : int
        GWAS sample size
    alpha_seq : (c+1, 4)
        annotation -> mixture probability 的固定系数
    """
    rng = np.random.default_rng(seed)

    m = len(b)

    # 论文里的 5 个 mixture 方差比例:
    # [0, 0.001, 0.01, 0.1, 1]% = [0, 1e-5, 1e-4, 1e-3, 1e-2]
    gamma = np.array([0.0, 1e-5, 1e-4, 1e-3, 1e-2])

    # 低秩变换
    w, Q, q = low_rank_transform(R, b, rho=rho)

    # 由 annotation 生成每个 SNP 的先验 mixture 概率
    pi = build_pi_from_annotations(A, alpha_seq)

    # 初始化
    beta = np.zeros(m)
    z = np.zeros(m, dtype=int)   # 每个 SNP 属于哪一个 mixture component
    sigma_g2 = 0.05
    sigma_e2 = 1.0

    resid = w.copy()             # resid = w - Q @ beta
    qq = np.sum(Q * Q, axis=0)   # 每个 SNP 列向量 q_j 的平方范数

    beta_sum = np.zeros(m)
    pip_sum = np.zeros(m)

    # 超参数(教学版随便给得温和一些)
    a_e, b_e = 2.0, 1.0
    a_g, b_g = 2.0, 1e-3

    for it in range(n_iter):
        noise_var = sigma_e2 / n

        # 逐个 SNP 做 Gibbs 更新
        for j in range(m):
            qj = Q[:, j]

            # 先把旧 beta_j 加回残差
            if beta[j] != 0:
                resid += qj * beta[j]

            s = qj @ resid

            logw = np.empty(5)
            post_mean = np.zeros(5)
            post_var = np.zeros(5)

            # k = 0: 零效应分量
            logw[0] = np.log(pi[j, 0] + 1e-300)

            # k = 1..4: 非零分量
            for k in range(1, 5):
                tau2 = gamma[k] * sigma_g2

                # 单坐标正态-正态共轭更新
                v = 1.0 / (qq[j] / noise_var + 1.0 / tau2)
                m_k = v * (s / noise_var)

                # 积分掉 beta_j 后的 marginal weight
                logw[k] = (
                    np.log(pi[j, k] + 1e-300)
                    + 0.5 * (np.log(v) - np.log(tau2) + (m_k * m_k) / v)
                )

                post_mean[k] = m_k
                post_var[k] = v

            # 采样 mixture membership z_j
            logw -= np.max(logw)
            ww = np.exp(logw)
            ww /= ww.sum()
            z[j] = rng.choice(5, p=ww)

            # 再采样 beta_j
            if z[j] == 0:
                beta[j] = 0.0
            else:
                beta[j] = rng.normal(post_mean[z[j]], np.sqrt(post_var[z[j]]))

            # 把新 beta_j 从残差中减掉
            if beta[j] != 0:
                resid -= qj * beta[j]

        # 更新 sigma_e^2
        # whitened residual: resid ~ N(0, sigma_e2 / n * I)
        rss = resid @ resid
        sigma_e2 = invgamma_sample(
            a_e + q / 2.0,
            b_e + 0.5 * n * rss,
            rng,
        )

        # 更新 sigma_g^2
        active = z > 0
        if np.any(active):
            scaled_ss = np.sum(beta[active] ** 2 / gamma[z[active]])
            sigma_g2 = invgamma_sample(
                a_g + active.sum() / 2.0,
                b_g + 0.5 * scaled_ss,
                rng,
            )
        else:
            sigma_g2 = invgamma_sample(a_g, b_g, rng)

        # 存后验均值
        if it >= burn_in:
            beta_sum += beta
            pip_sum += active.astype(float)

    n_keep = n_iter - burn_in
    return {
        "beta_mean": beta_sum / n_keep,
        "pip": pip_sum / n_keep,
        "sigma_e2": sigma_e2,
        "sigma_g2": sigma_g2,
        "pi_prior": pi,
    }


# =========================================================
# 下面构造一个“可跑通、可教学”的小例子
# =========================================================
if __name__ == "__main__":
    rng = np.random.default_rng(123)

    m = 40          # 40 个 SNP,适合教学
    n = 80000       # GWAS sample size
    c = 2           # 两个 annotations

    # 1) 先造一个小 LD 块,这里用 AR(1) 相关结构
    rho_ld = 0.6
    idx = np.arange(m)
    R = rho_ld ** np.abs(idx[:, None] - idx[None, :])

    # 2) 构造 annotation
    # A[:,0] 是二值注释,例如“是否在功能区”
    # A[:,1] 是连续注释,例如某个 conservation score
    A = np.column_stack([
        rng.binomial(1, 0.25, size=m),
        rng.normal(size=m),
    ])
    A[:, 1] = (A[:, 1] - A[:, 1].mean()) / A[:, 1].std()

    # 3) 为了模拟数据,先设一个“真实”的 annotation -> mixture 关系
    #    行: [截距, annotation1, annotation2]
    #    列: p2, p3, p4, p5 的线性预测值
    alpha_true = np.array([
        [-2.8, -1.4, -1.6, -1.8],  # intercept
        [ 2.0,  1.5,  1.1,  0.8],  # binary annotation 让 SNP 更容易进入较大效应成分
        [ 0.8,  0.5,  0.3,  0.2],  # quantitative annotation 也有一点加成
    ])

    pi_true = build_pi_from_annotations(A, alpha_true)

    # 5 个方差层
    gamma = np.array([0.0, 1e-5, 1e-4, 1e-3, 1e-2])

    sigma_g2_true = 0.15
    sigma_e2_true = 1.0

    # 4) 按真实 mixture 先验抽样真实 z 和 beta
    z_true = np.array([rng.choice(5, p=pi_true[j]) for j in range(m)])

    beta_true = np.zeros(m)
    for j in range(m):
        if z_true[j] > 0:
            beta_true[j] = rng.normal(
                0.0,
                np.sqrt(gamma[z_true[j]] * sigma_g2_true),
            )

    # 5) 根据 summary model 生成 GWAS marginal effect b
    #    b = R beta + eps, eps ~ N(0, R * sigma_e2 / n)
    L = np.linalg.cholesky(R + 1e-10 * np.eye(m))
    eps = L @ rng.normal(size=m) * np.sqrt(sigma_e2_true / n)
    b = R @ beta_true + eps

    # 6) SBayesRC 拟合
    fit = run_toy_sbayesrc(
        b=b,
        R=R,
        A=A,
        n=n,
        alpha_seq=alpha_true,   # 教学简化:这里先假设 annotation 参数已知
        n_iter=800,
        burn_in=300,
        rho=0.995,
        seed=42,
    )

    beta_hat = fit["beta_mean"]
    pip = fit["pip"]

    # 7) 看看前几个 SNP
    top = np.argsort(-pip)[:10]

    print("Top SNPs by posterior inclusion probability")
    print("idx\tPIP\tbeta_true\tbeta_hat\tannot1")
    for j in top:
        print(
            f"{j}\t{pip[j]:.3f}\t{beta_true[j]:.4f}\t{beta_hat[j]:.4f}\t{int(A[j,0])}"
        )

    print("\nposterior sigma_g2 =", fit["sigma_g2"])
    print("posterior sigma_e2 =", fit["sigma_e2"])
    print("corr(beta_true, beta_hat) =", np.corrcoef(beta_true, beta_hat)[0, 1])