[go: up one dir, main page]

Skip to content

Commit

Permalink
Fix some errors
Browse files Browse the repository at this point in the history
  • Loading branch information
chenqy4933 committed Jul 18, 2023
1 parent 0f2b055 commit fefe629
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions src/kern/gpu/kernel_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -591,12 +591,9 @@ void llm_matmul_compute_with_head_stride_float(
float alpha = 1.f;
float beta = 0.f;
CUBLAS_CHECK(cublasSetStream(cublas_handle, stream));
for (uint32_t h = 0; h < head; h++) {
CUBLAS_CHECK(cublasSgemm(
cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, N, M, K, &alpha,
srck + h * head_embd, embd, srcq + h * head_embd, embd, &beta,
dst + h * M * N, N));
}
CUBLAS_CHECK(cublasSgemmStridedBatched(
cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, N, M, K, &alpha, srck, embd,
head_embd, srcq, embd, head_embd, &beta, dst, N, M * N, head));
}

void llm_head_batched_matmul_compute_float(
Expand All @@ -612,12 +609,9 @@ void llm_head_batched_matmul_compute_float(
float beta = 0.f;

CUBLAS_CHECK(cublasSetStream(cublas_handle, stream));
for (uint32_t h = 0; h < head; h++) {
CUBLAS_CHECK(cublasSgemm(
cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha,
v + h * head_embd, embd, qk + h * K * N, K, &beta, dst + h * head_embd,
embd));
}
CUBLAS_CHECK(cublasSgemmStridedBatched(
cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, M, N, K, &alpha, v, embd,
head_embd, qk, K, K * N, &beta, dst, embd, head_embd, head));
}

void llm_glm_gmask_inf_float(
Expand Down

0 comments on commit fefe629

Please sign in to comment.