Dimensionality problem with CustomDist

Hello!! I am relatively new to PyMC, but I have an issue with a model I am trying to estimate. I was hoping someone could help, as I am relatively new to PyMC.

The model is:

def logp(value: TensorVariable,
c: TensorVariable,
beta: TensorVariable,
omega_vec: TensorVariable,
sign_mat: TensorVariable) → TensorVariable:

Omega    = sign_mat * pt.reshape(omega_vec, (N, N))

raw_det = pt.nlinalg.det(Omega)
logabsdet = pt.log(pt.abs(raw_det))


mup   = c + pt.einsum('tnk,k->tn', X_data, beta)
inv_O = pt.nlinalg.matrix_inverse(Omega)
eps   = pt.dot(value - mup, inv_O.T)


student_t_dist = pm.StudentT.dist(nu=5.0, mu=0.0, sigma=1.0)
logp_eps       = pm.logp(student_t_dist, eps)
logp_per_time  = pt.sum(logp_eps, axis=1)

return logp_per_time - logabsdet

with pm.Model() as model:

beta = pm.Normal("beta", mu=0, sigma=10, shape=(K,))
c    = pm.Normal("c",    mu=0, sigma=10, shape=(N,))


omega_vec = pm.HalfCauchy(
    "omega_vec",
    beta=2.0,
    shape=(N * N,)
)


sign_mat = pt.constant(signs)

# Now, sample
pm.CustomDist(
    'custom_dist',
    c,
    beta,
    omega_vec,
    sign_mat,
    logp=logp,
    observed=y_data,
)

trace = pm.sample()

Unfortunately, I get an error:

ValueError: Could not broadcast dimensions. Incompatible shapes were [(ScalarConstant(ScalarType(int64), data=1), ScalarConstant(ScalarType(int64), data=2)), (ScalarConstant(ScalarType(int64), data=1), ScalarConstant(ScalarType(int64), data=4)), (ScalarConstant(ScalarType(int64), data=1), ScalarConstant(ScalarType(int64), data=4)), (ScalarConstant(ScalarType(int64), data=2), ScalarConstant(ScalarType(int64), data=2))]

I think the shapes are ok. I have this other model that works well - but I would like to be able to compute the WAIC, for example, so I hoped to be able to run it using a CustomDist

with pm.Model() as model:

beta = pm.Normal("beta", mu=0, sigma=10, shape=(K,))
c    = pm.Normal("c",    mu=0, sigma=10, shape=(N,))


omega_vec = pm.HalfCauchy(
    "omega_vec",
    beta=2.0,
    shape=(N * N,)
)


sign_mat = pt.constant(signs)
Omega    = sign_mat * pt.reshape(omega_vec, (N, N))


raw_det = pt.nlinalg.det(Omega)
logabsdet = pt.log(pt.abs(raw_det))


pm.Potential("jacobian", -T * logabsdet)


mu    = c + pt.einsum('tnk,k->tn', X_data, beta)
inv_O = pt.nlinalg.matrix_inverse(Omega)
eps   = pt.dot(y_data - mu, inv_O.T)

student_t_dist = pm.StudentT.dist(nu=nup, mu=0.0, sigma=1.0)
logp_eps       = pm.logp(student_t_dist, eps)
pm.Potential("likelihood", pt.sum(logp_eps))

trace = pm.sample()