From 6b63fd498a0901329d8ba68ad2f92d38ffb3ca36 Mon Sep 17 00:00:00 2001 From: andres Date: Sat, 20 Jan 2024 13:13:59 +0100 Subject: [PATCH] fix: refactor two commands in saving grade to prevent race conditions --- .../is-order-by-constraint.decorator.ts | 2 +- src/modules/auth/auth.service.ts | 5 +- .../auth/use-cases/refresh-token-use-case.ts | 5 +- .../auth/use-cases/reset-password-use-case.ts | 5 +- .../cards/infrastructure/cards.repository.ts | 72 +++++++++++++++++- src/modules/core/validation/notification.ts | 5 +- .../core/validation/validation.utils.ts | 5 +- src/modules/decks/decks.controller.ts | 8 +- .../decks/use-cases/create-deck-use-case.ts | 5 +- .../use-cases/delete-deck-by-id-use-case.ts | 5 +- .../get-random-card-in-deck-use-case.ts | 76 ++----------------- .../decks/use-cases/save-grade-use-case.ts | 34 ++++++--- src/modules/users/api/users.controller.ts | 5 +- src/settings/app-settings.ts | 5 +- 14 files changed, 140 insertions(+), 97 deletions(-) diff --git a/src/infrastructure/decorators/is-order-by-constraint.decorator.ts b/src/infrastructure/decorators/is-order-by-constraint.decorator.ts index 4f37bfa..8f1b9c6 100644 --- a/src/infrastructure/decorators/is-order-by-constraint.decorator.ts +++ b/src/infrastructure/decorators/is-order-by-constraint.decorator.ts @@ -25,7 +25,7 @@ export class IsOrderByConstraint implements ValidatorConstraintInterface { return true } - defaultMessage(args: ValidationArguments) { + defaultMessage(_args: ValidationArguments) { return 'Invalid format. Expected format is "key-direction". Direction must be "asc" or "desc".' } } diff --git a/src/modules/auth/auth.service.ts b/src/modules/auth/auth.service.ts index a95ef5e..048fa16 100644 --- a/src/modules/auth/auth.service.ts +++ b/src/modules/auth/auth.service.ts @@ -10,7 +10,10 @@ import { UsersRepository } from '../users/infrastructure/users.repository' @Injectable() export class AuthService { - constructor(private usersRepository: UsersRepository, private prisma: PrismaService) {} + constructor( + private usersRepository: UsersRepository, + private prisma: PrismaService + ) {} async createJwtTokensPair(userId: string, rememberMe?: boolean) { const accessSecretKey = process.env.ACCESS_JWT_SECRET_KEY diff --git a/src/modules/auth/use-cases/refresh-token-use-case.ts b/src/modules/auth/use-cases/refresh-token-use-case.ts index 19befac..7230263 100644 --- a/src/modules/auth/use-cases/refresh-token-use-case.ts +++ b/src/modules/auth/use-cases/refresh-token-use-case.ts @@ -5,7 +5,10 @@ import * as jwt from 'jsonwebtoken' import { AuthRepository } from '../infrastructure/auth.repository' export class RefreshTokenCommand { - constructor(public readonly userId: string, public readonly shortAccessToken: boolean) {} + constructor( + public readonly userId: string, + public readonly shortAccessToken: boolean + ) {} } @CommandHandler(RefreshTokenCommand) diff --git a/src/modules/auth/use-cases/reset-password-use-case.ts b/src/modules/auth/use-cases/reset-password-use-case.ts index 7aa7797..640d1de 100644 --- a/src/modules/auth/use-cases/reset-password-use-case.ts +++ b/src/modules/auth/use-cases/reset-password-use-case.ts @@ -5,7 +5,10 @@ import { UsersRepository } from '../../users/infrastructure/users.repository' import { UsersService } from '../../users/services/users.service' export class ResetPasswordCommand { - constructor(public readonly resetPasswordToken: string, public readonly newPassword: string) {} + constructor( + public readonly resetPasswordToken: string, + public readonly newPassword: string + ) {} } @CommandHandler(ResetPasswordCommand) diff --git a/src/modules/cards/infrastructure/cards.repository.ts b/src/modules/cards/infrastructure/cards.repository.ts index 3f262dc..e21c199 100644 --- a/src/modules/cards/infrastructure/cards.repository.ts +++ b/src/modules/cards/infrastructure/cards.repository.ts @@ -1,4 +1,5 @@ -import { Injectable, InternalServerErrorException, Logger } from '@nestjs/common' +import { Injectable, InternalServerErrorException, Logger, NotFoundException } from '@nestjs/common' +import { pick } from 'remeda' import { createPrismaOrderBy, @@ -224,4 +225,73 @@ export class CardsRepository { throw new InternalServerErrorException(e?.message) } } + + private async getSmartRandomCard(cards: Array): Promise { + const selectionPool: Array = [] + + cards.forEach(card => { + // Calculate the average grade for the card + const averageGrade = + card.grades.length === 0 + ? 0 + : card.grades.reduce((acc, grade) => acc + grade.grade, 0) / card.grades.length + // Calculate weight for the card, higher weight for lower grade card + const weight = 6 - averageGrade + + // Add the card to the selection pool `weight` times + for (let i = 0; i < weight; i++) { + selectionPool.push(card) + } + }) + + return selectionPool[Math.floor(Math.random() * selectionPool.length)] + } + + private async getNotDuplicateRandomCard( + cards: Array, + previousCardId: string + ): Promise { + const randomCard = await this.getSmartRandomCard(cards) + + if (!randomCard) { + this.logger.error(`No cards found in deck}`, { + previousCardId, + randomCard, + cards, + }) + throw new NotFoundException(`No cards found in deck`) + } + if (randomCard.id === previousCardId && cards.length !== 1) { + return this.getNotDuplicateRandomCard(cards, previousCardId) + } + + return randomCard + } + + async getRandomCardInDeck(deckId: string, userId: string, previousCardId: string) { + const cards = await this.findCardsByDeckIdWithGrade(userId, deckId) + + if (!cards.length) { + throw new NotFoundException(`No cards found in deck with id ${deckId}`) + } + + const smartRandomCard = await this.getNotDuplicateRandomCard(cards, previousCardId) + + return { + ...pick(smartRandomCard, [ + 'id', + 'question', + 'answer', + 'deckId', + 'questionImg', + 'answerImg', + 'questionVideo', + 'answerVideo', + 'created', + 'updated', + 'shots', + ]), + grade: smartRandomCard.grades[0]?.grade || 0, + } + } } diff --git a/src/modules/core/validation/notification.ts b/src/modules/core/validation/notification.ts index da118a9..f1fe399 100644 --- a/src/modules/core/validation/notification.ts +++ b/src/modules/core/validation/notification.ts @@ -24,7 +24,10 @@ export class ResultNotification { } export class NotificationExtension { - constructor(public message: string, public key: string | null) {} + constructor( + public message: string, + public key: string | null + ) {} } export class DomainResultNotification extends ResultNotification { diff --git a/src/modules/core/validation/validation.utils.ts b/src/modules/core/validation/validation.utils.ts index 75ab4da..4fddcd3 100644 --- a/src/modules/core/validation/validation.utils.ts +++ b/src/modules/core/validation/validation.utils.ts @@ -6,7 +6,10 @@ import { validationErrorsMapper, ValidationPipeErrorType } from '../../../settin import { DomainResultNotification, ResultNotification } from './notification' export class DomainError extends Error { - constructor(message: string, public resultNotification: ResultNotification) { + constructor( + message: string, + public resultNotification: ResultNotification + ) { super(message) } } diff --git a/src/modules/decks/decks.controller.ts b/src/modules/decks/decks.controller.ts index cb786d9..06f59b5 100644 --- a/src/modules/decks/decks.controller.ts +++ b/src/modules/decks/decks.controller.ts @@ -192,13 +192,9 @@ export class DecksController { description: 'Save the grade of a card', summary: 'Save the grade of a card', }) - async saveGrade(@Param('id') deckId: string, @Req() req, @Body() body: SaveGradeDto) { - const saved = await this.commandBus.execute( - new SaveGradeCommand(req.user.id, { cardId: body.cardId, grade: body.grade }) - ) - + async saveGrade(@Req() req, @Body() body: SaveGradeDto) { return await this.commandBus.execute( - new GetRandomCardInDeckCommand(req.user.id, saved.deckId, saved.id) + new SaveGradeCommand(req.user.id, { cardId: body.cardId, grade: body.grade }) ) } } diff --git a/src/modules/decks/use-cases/create-deck-use-case.ts b/src/modules/decks/use-cases/create-deck-use-case.ts index 53823be..34357da 100644 --- a/src/modules/decks/use-cases/create-deck-use-case.ts +++ b/src/modules/decks/use-cases/create-deck-use-case.ts @@ -6,7 +6,10 @@ import { Deck } from '../entities/deck.entity' import { DecksRepository } from '../infrastructure/decks.repository' export class CreateDeckCommand { - constructor(public readonly deck: CreateDeckDto, public readonly cover: Express.Multer.File) {} + constructor( + public readonly deck: CreateDeckDto, + public readonly cover: Express.Multer.File + ) {} } @CommandHandler(CreateDeckCommand) diff --git a/src/modules/decks/use-cases/delete-deck-by-id-use-case.ts b/src/modules/decks/use-cases/delete-deck-by-id-use-case.ts index 57f4714..fff830e 100644 --- a/src/modules/decks/use-cases/delete-deck-by-id-use-case.ts +++ b/src/modules/decks/use-cases/delete-deck-by-id-use-case.ts @@ -4,7 +4,10 @@ import { CommandHandler, ICommandHandler } from '@nestjs/cqrs' import { DecksRepository } from '../infrastructure/decks.repository' export class DeleteDeckByIdCommand { - constructor(public readonly id: string, public readonly userId: string) {} + constructor( + public readonly id: string, + public readonly userId: string + ) {} } @CommandHandler(DeleteDeckByIdCommand) diff --git a/src/modules/decks/use-cases/get-random-card-in-deck-use-case.ts b/src/modules/decks/use-cases/get-random-card-in-deck-use-case.ts index 2501960..e8ad5f9 100644 --- a/src/modules/decks/use-cases/get-random-card-in-deck-use-case.ts +++ b/src/modules/decks/use-cases/get-random-card-in-deck-use-case.ts @@ -1,7 +1,5 @@ import { ForbiddenException, Logger, NotFoundException } from '@nestjs/common' import { CommandHandler, ICommandHandler } from '@nestjs/cqrs' -import { Prisma } from '@prisma/client' -import { pick } from 'remeda' import { CardsRepository } from '../../cards/infrastructure/cards.repository' import { DecksRepository } from '../infrastructure/decks.repository' @@ -14,58 +12,15 @@ export class GetRandomCardInDeckCommand { ) {} } -type CardWithGrade = Prisma.cardGetPayload<{ include: { grades: true } }> - @CommandHandler(GetRandomCardInDeckCommand) export class GetRandomCardInDeckHandler implements ICommandHandler { logger = new Logger(GetRandomCardInDeckHandler.name) + constructor( private readonly cardsRepository: CardsRepository, private readonly decksRepository: DecksRepository ) {} - private async getSmartRandomCard(cards: Array): Promise { - const selectionPool: Array = [] - - cards.forEach(card => { - // Calculate the average grade for the card - const averageGrade = - card.grades.length === 0 - ? 0 - : card.grades.reduce((acc, grade) => acc + grade.grade, 0) / card.grades.length - // Calculate weight for the card, higher weight for lower grade card - const weight = 6 - averageGrade - - // Add the card to the selection pool `weight` times - for (let i = 0; i < weight; i++) { - selectionPool.push(card) - } - }) - - return selectionPool[Math.floor(Math.random() * selectionPool.length)] - } - - private async getNotDuplicateRandomCard( - cards: Array, - previousCardId: string - ): Promise { - const randomCard = await this.getSmartRandomCard(cards) - - if (!randomCard) { - this.logger.error(`No cards found in deck}`, { - previousCardId, - randomCard, - cards, - }) - throw new NotFoundException(`No cards found in deck`) - } - if (randomCard.id === previousCardId && cards.length !== 1) { - return this.getNotDuplicateRandomCard(cards, previousCardId) - } - - return randomCard - } - async execute(command: GetRandomCardInDeckCommand) { const deck = await this.decksRepository.findDeckById(command.deckId) @@ -74,32 +29,11 @@ export class GetRandomCardInDeckHandler implements ICommandHandler { + private readonly logger = new Logger(SaveGradeHandler.name) + constructor( private readonly decksRepository: DecksRepository, - private readonly gradesRepository: GradesRepository + private readonly gradesRepository: GradesRepository, + private readonly cardsRepository: CardsRepository ) {} async execute(command: SaveGradeCommand) { @@ -28,14 +37,21 @@ export class SaveGradeHandler implements ICommandHandler { throw new NotFoundException(`Deck containing card with id ${command.args.cardId} not found`) if (deck.userId !== command.userId && deck.isPrivate) { - throw new ForbiddenException(`You can't save cards to a private deck that you don't own`) + throw new ForbiddenException(`You can't save cards to a private deck that you don't own`) } - return await this.gradesRepository.createGrade({ - userId: command.userId, - grade: command.args.grade, - cardId: command.args.cardId, - deckId: deck.id, - }) + try { + await this.gradesRepository.createGrade({ + userId: command.userId, + grade: command.args.grade, + cardId: command.args.cardId, + deckId: deck.id, + }) + } catch (e) { + this.logger.error(e) + throw new InternalServerErrorException(e?.message) + } + + return this.cardsRepository.getRandomCardInDeck(deck.id, command.userId, command.args.cardId) } } diff --git a/src/modules/users/api/users.controller.ts b/src/modules/users/api/users.controller.ts index e32f0a9..50a52b1 100644 --- a/src/modules/users/api/users.controller.ts +++ b/src/modules/users/api/users.controller.ts @@ -21,7 +21,10 @@ import { UsersService } from '../services/users.service' @ApiTags('Admin') @Controller('users') export class UsersController { - constructor(private usersService: UsersService, private commandBus: CommandBus) {} + constructor( + private usersService: UsersService, + private commandBus: CommandBus + ) {} @Get() async findAll(@Query() query) { diff --git a/src/settings/app-settings.ts b/src/settings/app-settings.ts index 0682333..d2de0a8 100644 --- a/src/settings/app-settings.ts +++ b/src/settings/app-settings.ts @@ -38,7 +38,10 @@ class AuthSettings { } export class AppSettings { - constructor(public env: EnvironmentSettings, public auth: AuthSettings) {} + constructor( + public env: EnvironmentSettings, + public auth: AuthSettings + ) {} } const env = new EnvironmentSettings((process.env.NODE_ENV || 'DEVELOPMENT') as EnvironmentsTypes) const auth = new AuthSettings(process.env)