diff --git a/.env.stripe-sample b/.env.stripe-sample new file mode 100644 index 000000000..c35dff0a5 --- /dev/null +++ b/.env.stripe-sample @@ -0,0 +1,2 @@ +STRIPE_SECRET_KEY=sk_test_changeme +STRIPE_WEBHOOK_SECRET=whsec_changeme \ No newline at end of file diff --git a/.github/workflows/build-latest.yml b/.github/workflows/build-and-deploy-latest.yml similarity index 56% rename from .github/workflows/build-latest.yml rename to .github/workflows/build-and-deploy-latest.yml index 75dca0296..b8f7414c0 100644 --- a/.github/workflows/build-latest.yml +++ b/.github/workflows/build-and-deploy-latest.yml @@ -29,3 +29,21 @@ jobs: platforms: linux/amd64,linux/arm64 cache-from: type=gha cache-to: type=gha,mode=max + + deploy: + runs-on: ubuntu-latest + needs: build-and-push + steps: + - name: Install doctl and authenticate + run: | + sudo snap install doctl jq + doctl auth init -t $DIGITALOCEAN_TOKEN + env: + DIGITALOCEAN_TOKEN: ${{ secrets.DIGITALOCEAN_TOKEN }} + + - name: Trigger deployment on DigitalOcean App Platform + run: | + app_id=$(doctl apps list --output json| jq '.[] | select(.spec.name == "hushline-staging") | .id' -r) + doctl apps create-deployment $app_id --force-rebuild --wait + env: + DIGITALOCEAN_APP_ID: ${{ secrets.DIGITALOCEAN_APP_ID }} diff --git a/.gitignore b/.gitignore index ac297b095..8d8138476 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ hushline/static/data/users_directory.json /node_modules .coverage htmlcov +.env.stripe \ No newline at end of file diff --git a/.prettierignore b/.prettierignore index 95a19ccfb..ffc67e7cd 100644 --- a/.prettierignore +++ b/.prettierignore @@ -4,3 +4,4 @@ coverage hushline/static/vendor/* hushline/templates/* .pytest_cache +.venv diff --git a/Dockerfile.dev b/Dockerfile.dev index 488f8055f..8bf4440b8 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -34,3 +34,6 @@ ENV PYTHONPATH=/app COPY pyproject.toml poetry.lock . RUN poetry install + +ENV FLASK_APP="hushline" +CMD ["./scripts/dev_start.sh"] diff --git a/Dockerfile.prod b/Dockerfile.prod index d5a0ea5bd..cb74b674c 100644 --- a/Dockerfile.prod +++ b/Dockerfile.prod @@ -29,4 +29,4 @@ EXPOSE 8080 # Run! ENV FLASK_APP="hushline" -CMD ["./start.sh"] +CMD ["./scripts/prod_start.sh"] diff --git a/Makefile b/Makefile index 89ba610ba..52b775b60 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,10 @@ migrate-dev: ## Run dev env migrations migrate-prod: ## Run prod env (alembic) migrations poetry run flask db upgrade +.PHONY: dev-data +dev-data: migrate-dev ## Run dev env migrations, and add dev data + poetry run ./scripts/dev_data.py + .PHONY: lint lint: ## Lint the code poetry run ruff format --check && \ diff --git a/dev_env.sh b/dev_env.sh index 9253e815b..f1ac4800e 100644 --- a/dev_env.sh +++ b/dev_env.sh @@ -5,3 +5,9 @@ export SQLALCHEMY_DATABASE_URI=postgresql://hushline:hushline@127.0.0.1:5432/hus export REGISTRATION_CODES_REQUIRED=False export SESSION_COOKIE_NAME=session export NOTIFICATIONS_ADDRESS=notifications@hushline.app + +# Stripe +export STRIPE_PUBLISHABLE_KEY=pk_test_51OhDeALcBPqjxU07I70UA6JYGDPUmkxEwZW0lvGyNXGlJ4QPfWIBFZJau7XOb3QDzDWrVutBVkz9SNrSjq2vRawm00TwfyFuma +# set these manually: +# export STRIPE_SECRET_KEY= +# export STRIPE_WEBHOOK_SECRET= diff --git a/docker-compose.stripe.yaml b/docker-compose.stripe.yaml new file mode 100644 index 000000000..49ae138b7 --- /dev/null +++ b/docker-compose.stripe.yaml @@ -0,0 +1,49 @@ +--- +services: + app: &app_env + build: + context: . + dockerfile: Dockerfile.dev + ports: + - 127.0.0.1:8080:8080 + environment: + FLASK_APP: hushline + FLASK_ENV: development + ENCRYPTION_KEY: bi5FDwhZGKfc4urLJ_ChGtIAaOPgxd3RDOhnvct10mw= + SECRET_KEY: cb3f4afde364bfb3956b97ca22ef4d2b593d9d980a4330686267cabcd2c0befd + SQLALCHEMY_DATABASE_URI: postgresql://hushline:hushline@postgres:5432/hushline + REGISTRATION_CODES_REQUIRED: False + SESSION_COOKIE_NAME: session + NOTIFICATIONS_ADDRESS: notifications@hushline.app + env_file: + - .env.stripe + volumes: + - ./:/app + depends_on: + - dev_data + restart: always + + worker: + <<: *app_env + ports: [] + restart: always + command: poetry run flask stripe start-worker + depends_on: + - app + + dev_data: + <<: *app_env + ports: [] + restart: on-failure + command: make dev-data + depends_on: + - postgres + + postgres: + image: postgres:16.4-alpine3.20 + environment: + POSTGRES_USER: hushline + POSTGRES_PASSWORD: hushline + POSTGRES_DB: hushline + ports: + - 127.0.0.1:5432:5432 diff --git a/docker-compose.yaml b/docker-compose.yaml index b5e9e4c53..efa3ef3c7 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -18,15 +18,16 @@ services: volumes: - ./:/app depends_on: - - postgres + - dev_data restart: always - command: poetry run flask run --debug --host=0.0.0.0 --port=8080 --with-threads dev_data: <<: *app_env ports: [] restart: on-failure - command: make migrate-dev && ./scripts/dev_data.py + command: make dev-data + depends_on: + - postgres postgres: image: postgres:16.4-alpine3.20 @@ -34,3 +35,5 @@ services: POSTGRES_USER: hushline POSTGRES_PASSWORD: hushline POSTGRES_DB: hushline + ports: + - 127.0.0.1:5432:5432 diff --git a/hushline/__init__.py b/hushline/__init__.py index 3bda63293..c7c1fe46c 100644 --- a/hushline/__init__.py +++ b/hushline/__init__.py @@ -1,18 +1,20 @@ +import asyncio import logging import os from datetime import timedelta from typing import Any from flask import Flask, flash, redirect, request, session, url_for +from flask.cli import AppGroup from flask_migrate import Migrate from jinja2 import StrictUndefined from sqlalchemy.exc import ProgrammingError from werkzeug.middleware.proxy_fix import ProxyFix from werkzeug.wrappers.response import Response -from . import admin, routes, settings +from . import admin, premium, routes, settings from .db import db -from .model import HostOrganization, User +from .model import HostOrganization, Tier, User from .version import __version__ @@ -58,6 +60,9 @@ def create_app() -> Flask: app.config["SMTP_PASSWORD"] = os.environ.get("SMTP_PASSWORD", None) app.config["SMTP_ENCRYPTION"] = os.environ.get("SMTP_ENCRYPTION", "StartTLS") app.config["REQUIRE_PGP"] = os.environ.get("REQUIRE_PGP", "False").lower() == "true" + app.config["STRIPE_PUBLISHABLE_KEY"] = os.environ.get("STRIPE_PUBLISHABLE_KEY", None) + app.config["STRIPE_SECRET_KEY"] = os.environ.get("STRIPE_SECRET_KEY", None) + app.config["STRIPE_WEBHOOK_SECRET"] = os.environ.get("STRIPE_WEBHOOK_SECRET", None) # Handle the tips domain for profile verification app.config["SERVER_NAME"] = os.getenv("SERVER_NAME") @@ -78,10 +83,18 @@ def create_app() -> Flask: db.init_app(app) Migrate(app, db) + # Initialize Stripe + if app.config["STRIPE_SECRET_KEY"]: + with app.app_context(): + premium.init_stripe() + routes.init_app(app) for module in [admin, settings]: app.register_blueprint(module.create_blueprint()) + if app.config["STRIPE_SECRET_KEY"]: + app.register_blueprint(premium.create_blueprint(app)) + @app.errorhandler(404) def page_not_found(e: Exception) -> Response: flash("⛓️‍💥 That page doesn't exist.", "warning") @@ -102,6 +115,10 @@ def inject_host() -> dict[str, HostOrganization]: def inject_is_personal_server() -> dict[str, Any]: return {"is_personal_server": app.config["IS_PERSONAL_SERVER"]} + @app.context_processor + def inject_is_premium_enabled() -> dict[str, Any]: + return {"is_premium_enabled": bool(app.config.get("STRIPE_SECRET_KEY", False))} + # Add Onion-Location header to all responses if app.config["ONION_HOSTNAME"]: @@ -112,6 +129,9 @@ def add_onion_location_header(response: Response) -> Response: ) return response + # Register custom CLI commands + register_commands(app) + # we can't if app.config.get("FLASK_ENV", None) != "development": with app.app_context(): @@ -126,3 +146,43 @@ def add_onion_location_header(response: Response) -> Response: app.logger.warning("HostOrganization data not found in database.") return app + + +def register_commands(app: Flask) -> None: + stripe_cli = AppGroup("stripe") + + @stripe_cli.command("configure") + def configure() -> None: + """Configure Stripe and premium tiers""" + # Make sure tiers exist + with app.app_context(): + free_tier = Tier.free_tier() + if not free_tier: + free_tier = Tier(name="Free", monthly_amount=0) + db.session.add(free_tier) + db.session.commit() + business_tier = Tier.business_tier() + if not business_tier: + business_tier = Tier(name="Business", monthly_amount=2000) + db.session.add(business_tier) + db.session.commit() + + # Configure Stripe + if app.config["STRIPE_SECRET_KEY"]: + with app.app_context(): + premium.init_stripe() + premium.create_products_and_prices() + else: + app.logger.info("Skipping Stripe configuration because STRIPE_SECRET_KEY is not set") + + @stripe_cli.command("start-worker") + def start_worker() -> None: + """Start the Stripe worker""" + if not app.config["STRIPE_SECRET_KEY"]: + app.logger.error("Cannot start the Stripe worker without a STRIPE_SECRET_KEY") + return + + with app.app_context(): + asyncio.run(premium.worker(app)) + + app.cli.add_command(stripe_cli) diff --git a/hushline/admin.py b/hushline/admin.py index 6a0451d0a..d07ae6b20 100644 --- a/hushline/admin.py +++ b/hushline/admin.py @@ -1,8 +1,9 @@ -from flask import Blueprint, abort, flash, redirect, url_for +from flask import Blueprint, abort, flash, redirect, request, url_for from werkzeug.wrappers.response import Response from .db import db -from .model import User +from .model import Tier, User +from .premium import update_price from .utils import admin_authentication_required @@ -31,4 +32,37 @@ def toggle_admin(user_id: int) -> Response: flash("✅ User admin status toggled.", "success") return redirect(url_for("settings.index")) + @bp.route("/update_tier/", methods=["POST"]) + @admin_authentication_required + def update_tier(tier_id: int) -> Response: + tier = db.session.get(Tier, tier_id) + if tier is None: + abort(404) + + # Get monthly_price from the request + monthly_price = request.form.get("monthly_price") + if not monthly_price: + flash("❌ Monthly price is required.", "danger") + return redirect(url_for("settings.index")) + + # Convert the monthly_price to a float + try: + monthly_price_number = float(monthly_price) + except ValueError: + flash("❌ Monthly price must be a number.", "danger") + return redirect(url_for("settings.index")) + + # Convert to cents + monthly_amount = int(monthly_price_number * 100) + + # Update in the database + tier.monthly_amount = monthly_amount + db.session.commit() + + # Update in stripe + update_price(tier) + + flash("✅ Price updated.", "success") + return redirect(url_for("settings.index")) + return bp diff --git a/hushline/model.py b/hushline/model.py index ddc264eff..d6692ca15 100644 --- a/hushline/model.py +++ b/hushline/model.py @@ -6,8 +6,10 @@ from flask_sqlalchemy.model import Model from passlib.hash import scrypt +from sqlalchemy import Enum as SQLAlchemyEnum from sqlalchemy import Index from sqlalchemy.orm import Mapped, mapped_column, relationship +from stripe import Event, Invoice from .crypto import decrypt_field, encrypt_field from .db import db @@ -75,6 +77,35 @@ class ExtraField: is_verified: Optional[bool] +@enum.unique +class StripeInvoiceStatusEnum(enum.Enum): + DRAFT = "draft" + OPEN = "open" + PAID = "paid" + UNCOLLECTIBLE = "uncollectible" + VOID = "void" + + +@enum.unique +class StripeSubscriptionStatusEnum(enum.Enum): + INCOMPLETE = "incomplete" + INCOMPLETE_EXPIRED = "incomplete_expired" + TRIALING = "trialing" + ACTIVE = "active" + PAST_DUE = "past_due" + CANCELED = "canceled" + UNPAID = "unpaid" + PAUSED = "paused" + + +@enum.unique +class StripeEventStatusEnum(enum.Enum): + PENDING = "pending" + IN_PROGRESS = "in_progress" + ERROR = "error" + FINISHED = "finished" + + class Username(Model): """ Class representing a username and associated profile. @@ -94,6 +125,7 @@ class Username(Model): show_in_directory: Mapped[bool] = mapped_column(default=False) bio: Mapped[Optional[str]] = mapped_column(db.Text) + # Extra fields extra_field_label1: Mapped[Optional[str]] extra_field_value1: Mapped[Optional[str]] extra_field_label2: Mapped[Optional[str]] @@ -188,6 +220,23 @@ class User(Model): ) smtp_sender: Mapped[Optional[str]] + # Paid tier fields + tier_id: Mapped[int | None] = mapped_column(db.ForeignKey("tiers.id"), nullable=True) + tier: Mapped["Tier"] = relationship(backref=db.backref("tiers", lazy=True)) + + stripe_customer_id = mapped_column(db.String(255)) + stripe_subscription_id = mapped_column(db.String(255), nullable=True) + stripe_subscription_cancel_at_period_end = mapped_column(db.Boolean, default=False) + stripe_subscription_status: Mapped[Optional[StripeSubscriptionStatusEnum]] = mapped_column( + SQLAlchemyEnum(StripeSubscriptionStatusEnum) + ) + stripe_subscription_current_period_end = mapped_column( + db.DateTime(timezone=True), nullable=True + ) + stripe_subscription_current_period_start = mapped_column( + db.DateTime(timezone=True), nullable=True + ) + @property def password_hash(self) -> str: """Return the hashed password.""" @@ -256,6 +305,20 @@ def pgp_key(self, value: str) -> None: else: self._pgp_key = encrypt_field(value) + @property + def is_free_tier(self) -> bool: + return self.tier_id is None or self.tier_id == Tier.free_tier_id() + + @property + def is_business_tier(self) -> bool: + return self.tier_id == Tier.business_tier_id() + + def set_free_tier(self) -> None: + self.tier_id = Tier.free_tier_id() + + def set_business_tier(self) -> None: + self.tier_id = Tier.business_tier_id() + def __init__(self, **kwargs: Any) -> None: for key in ["password_hash", "_password_hash"]: if key in kwargs: @@ -349,3 +412,112 @@ def __init__(self) -> None: def __repr__(self) -> str: return f"" + + +# Paid tiers +class Tier(Model): + __tablename__ = "tiers" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(db.String(255), unique=True) + monthly_amount: Mapped[int] = mapped_column(db.Integer) # in cents USD + stripe_product_id: Mapped[Optional[str]] = mapped_column(db.String(255), unique=True) + stripe_price_id: Mapped[Optional[str]] = mapped_column(db.String(255), unique=True) + + def __init__(self, name: str, monthly_amount: int) -> None: + super().__init__() + self.name = name + self.monthly_amount = monthly_amount + + @staticmethod + def free_tier_id() -> int: + return 1 + + @staticmethod + def business_tier_id() -> int: + return 2 + + @staticmethod + def free_tier() -> Self | None: # type: ignore + return db.session.get(Tier, Tier.free_tier_id()) # type: ignore + + @staticmethod + def business_tier() -> Self | None: # type: ignore + return db.session.get(Tier, Tier.business_tier_id()) # type: ignore + + +class StripeEvent(Model): + __tablename__ = "stripe_events" + + id: Mapped[int] = mapped_column(primary_key=True) + event_id: Mapped[str] = mapped_column(db.String(255), unique=True, index=True) + event_type: Mapped[str] = mapped_column(db.String(255)) + event_created: Mapped[int] = mapped_column(db.Integer) + event_data: Mapped[str] = mapped_column(db.Text) + status: Mapped[Optional[StripeEventStatusEnum]] = mapped_column( + SQLAlchemyEnum(StripeEventStatusEnum), default=StripeEventStatusEnum.PENDING + ) + error_message: Mapped[Optional[str]] = mapped_column(db.Text) + + def __init__(self, event: Event, **kwargs: dict[str, Any]) -> None: + super().__init__(**kwargs) + self.event_id = event.id + self.event_created = event.created + self.event_type = event.type + self.event_data = str(event) + + +class StripeInvoice(Model): + __tablename__ = "stripe_invoices" + + id: Mapped[int] = mapped_column(primary_key=True) + customer_id: Mapped[str] = mapped_column(db.String(255)) + invoice_id: Mapped[str] = mapped_column(db.String(255), unique=True, index=True) + hosted_invoice_url: Mapped[str] = mapped_column(db.String(2048)) + total: Mapped[int] = mapped_column(db.Integer) + status: Mapped[Optional[StripeInvoiceStatusEnum]] = mapped_column( + SQLAlchemyEnum(StripeInvoiceStatusEnum) + ) + created_at: Mapped[datetime] = mapped_column(default=datetime.now) + + user_id: Mapped[int] = mapped_column(db.ForeignKey("users.id")) + tier_id: Mapped[int] = mapped_column(db.ForeignKey("tiers.id")) + + def __init__(self, invoice: Invoice) -> None: + if invoice.id: + self.invoice_id = invoice.id + if invoice.customer and isinstance(invoice.customer, str): + self.customer_id = invoice.customer + if invoice.hosted_invoice_url: + self.hosted_invoice_url = invoice.hosted_invoice_url + if invoice.total: + self.total = invoice.total + else: + self.total = 0 + if invoice.status: + self.status = StripeInvoiceStatusEnum(invoice.status) + if invoice.created: + self.created_at = datetime.fromtimestamp(invoice.created, tz=timezone.utc) + + # Look up the user by their customer ID + user = db.session.scalars( + db.select(User).filter_by(stripe_customer_id=invoice.customer) + ).one_or_none() + if user: + self.user_id = user.id + else: + raise ValueError(f"Could not find user with customer ID {invoice.customer}") + + # Look up the tier by the product_id + if invoice.lines.data[0].plan: + product_id = invoice.lines.data[0].plan.product + + tier = db.session.scalars( + db.select(Tier).filter_by(stripe_product_id=product_id) + ).one_or_none() + if tier: + self.tier_id = tier.id + else: + raise ValueError(f"Could not find tier with product ID {product_id}") + else: + raise ValueError("Invoice does not have a plan") diff --git a/hushline/premium.py b/hushline/premium.py new file mode 100644 index 000000000..78ead1af7 --- /dev/null +++ b/hushline/premium.py @@ -0,0 +1,607 @@ +import asyncio +import json +from datetime import datetime +from typing import Tuple + +import stripe +from flask import ( + Blueprint, + Flask, + abort, + current_app, + flash, + jsonify, + redirect, + render_template, + request, + session, + url_for, +) +from werkzeug.wrappers.response import Response + +from .db import db +from .model import ( + StripeEvent, + StripeEventStatusEnum, + StripeInvoice, + StripeInvoiceStatusEnum, + StripeSubscriptionStatusEnum, + Tier, + User, +) +from .utils import authentication_required + + +def init_stripe() -> None: + stripe.api_key = current_app.config["STRIPE_SECRET_KEY"] + + +def create_products_and_prices() -> None: + current_app.logger.info("Creating products and prices") + + # Make sure the products and prices are created in Stripe + business_tier = Tier.business_tier() + if not business_tier: + current_app.logger.error("Could not find business tier") + return + + # Check if the product exists in the db + create_product = False + if business_tier.stripe_product_id is None: + create_product = True + else: + try: + stripe_product = stripe.Product.retrieve(business_tier.stripe_product_id) + except stripe._error.InvalidRequestError: + create_product = True + + if create_product: + # Do we already have a product in Stripe? + found = False + stripe_products = stripe.Product.list(limit=100) + for stripe_product in stripe_products: + if stripe_product.name == business_tier.name: + current_app.logger.info(f"Found Stripe product for tier: {business_tier.name}") + found = True + business_tier.stripe_product_id = stripe_product.id + db.session.add(business_tier) + db.session.commit() + break + + # Create a product if we didn't find one + if not found: + current_app.logger.info(f"Creating Stripe product for tier: {business_tier.name}") + stripe_product = stripe.Product.create( + name=business_tier.name, + type="service", + tax_code="txcd_10103001", # Software as a service (SaaS) - business use + ) + business_tier.stripe_product_id = stripe_product.id + db.session.add(business_tier) + db.session.commit() + else: + current_app.logger.info(f"Product already exists for tier: {business_tier.name}") + + # Check if the price exists + create_price = False + if business_tier.stripe_price_id is None: + create_price = True + else: + try: + price = stripe.Price.retrieve(business_tier.stripe_price_id) + except stripe._error.InvalidRequestError: + create_price = True + + if create_price: + # Do we already have a price in Stripe? + found = False + if stripe_product.default_price: + try: + stripe_price = stripe.Price.retrieve(str(stripe_product.default_price)) + current_app.logger.info(f"Found Stripe price for tier: {business_tier.name}") + business_tier.stripe_price_id = stripe_price.id + if stripe_price.unit_amount: + business_tier.monthly_amount = stripe_price.unit_amount + db.session.add(business_tier) + db.session.commit() + found = True + except stripe._error.InvalidRequestError: + found = False + + # Create a price if we didn't find one + if not found: + current_app.logger.info(f"Creating price for tier: {business_tier.name}") + price = stripe.Price.create( + product=stripe_product.id, + unit_amount=business_tier.monthly_amount, + currency="usd", + recurring={"interval": "month"}, + ) + business_tier.stripe_price_id = price.id + db.session.add(business_tier) + db.session.commit() + else: + current_app.logger.info(f"Price already exists for tier: {business_tier.name}") + + +def update_price(tier: Tier) -> None: + current_app.logger.info(f"Updating price for tier {tier.name} to {tier.monthly_amount}") + + if not tier.stripe_product_id: + current_app.logger.error(f"Tier {tier.name} does not have a product ID") + return + + # See if we already have an appropriate price for this product + prices = stripe.Price.search(query=f'product:"{tier.stripe_product_id}"') + found_price_id = None + for price in prices: + if price.unit_amount == tier.monthly_amount: + found_price_id = price.id + break + + # If we found it, use it + if found_price_id is not None: + tier.stripe_price_id = found_price_id + db.session.add(tier) + db.session.commit() + + stripe.Product.modify(tier.stripe_product_id, default_price=found_price_id) + return + + # Otherwise, create a new price + price = stripe.Price.create( + product=tier.stripe_product_id, + unit_amount=tier.monthly_amount, + currency="usd", + recurring={"interval": "month"}, + ) + tier.stripe_price_id = price.id + db.session.add(tier) + db.session.commit() + + stripe.Product.modify(tier.stripe_product_id, default_price=price.id) + + +def create_customer(user: User) -> stripe.Customer: + email: str = user.email if user.email is not None else "" + + if user.stripe_customer_id is not None: + try: + return stripe.Customer.modify(user.stripe_customer_id, email=email) + except stripe._error.InvalidRequestError: + user.stripe_customer_id = None + + stripe_customer = stripe.Customer.create(email=email) + user.stripe_customer_id = stripe_customer.id + db.session.add(user) + db.session.commit() + return stripe_customer + + +def get_subscription(user: User) -> stripe.Subscription | None: + if user.stripe_subscription_id is None: + return None + + return stripe.Subscription.retrieve(user.stripe_subscription_id) + + +def get_business_price_string() -> str: + business_tier = Tier.business_tier() + if not business_tier: + current_app.logger.error("Could not find business tier") + return "NA" + + business_price = f"{business_tier.monthly_amount / 100:.2f}" + if business_price.endswith(".00"): + business_price = business_price[:-3] + elif business_price.endswith("0"): + business_price = business_price[:-1] + + return business_price + + +def handle_subscription_created(subscription: stripe.Subscription) -> None: + # customer.subscription.created + + user = db.session.scalars( + db.select(User).filter_by(stripe_customer_id=subscription.customer) + ).one_or_none() + if user: + user.stripe_subscription_id = subscription.id + user.stripe_subscription_status = StripeSubscriptionStatusEnum(subscription.status) + user.stripe_subscription_cancel_at_period_end = subscription.cancel_at_period_end + user.stripe_subscription_current_period_end = datetime.fromtimestamp( + subscription.current_period_end + ) + user.stripe_subscription_current_period_start = datetime.fromtimestamp( + subscription.current_period_start + ) + db.session.commit() + else: + raise ValueError(f"Could not find user with customer ID {subscription.customer}") + + +def handle_subscription_updated(subscription: stripe.Subscription) -> None: + # customer.subscription.updated + + # If subscription changes to cancel or unpaid, downgrade user + user = db.session.scalars( + db.select(User).filter_by(stripe_subscription_id=subscription.id) + ).one_or_none() + if user: + user.stripe_subscription_status = StripeSubscriptionStatusEnum(subscription.status) + user.stripe_subscription_cancel_at_period_end = subscription.cancel_at_period_end + user.stripe_subscription_current_period_end = datetime.fromtimestamp( + subscription.current_period_end + ) + user.stripe_subscription_current_period_start = datetime.fromtimestamp( + subscription.current_period_start + ) + + current_app.logger.info("status is: " + subscription.status) + if subscription.status in ["active", "trialing"]: + user.set_business_tier() + else: + user.set_free_tier() + + db.session.commit() + else: + raise ValueError(f"Could not find user with subscription ID {subscription.id}") + + +def handle_subscription_deleted(subscription: stripe.Subscription) -> None: + # customer.subscription.deleted + + user = db.session.scalars( + db.select(User).filter_by(stripe_subscription_id=subscription.id) + ).one_or_none() + if user: + user.set_free_tier() + user.stripe_subscription_id = None + user.stripe_subscription_status = None + user.stripe_subscription_cancel_at_period_end = None + user.stripe_subscription_current_period_end = None + user.stripe_subscription_current_period_start = None + db.session.commit() + else: + raise ValueError(f"Could not find user with subscription ID {subscription.id}") + + +def handle_invoice_created(invoice: stripe.Invoice) -> None: + # invoice.created + + try: + new_invoice = StripeInvoice(invoice) + except ValueError as e: + current_app.logger.error(f"Error creating invoice: {e}") + + db.session.add(new_invoice) + db.session.commit() + + +def handle_invoice_updated(invoice: stripe.Invoice) -> None: + # invoice.updated + + stripe_invoice = db.session.scalars( + db.select(StripeInvoice).filter_by(invoice_id=invoice.id) + ).one_or_none() + if stripe_invoice: + stripe_invoice.total = invoice.total + stripe_invoice.status = StripeInvoiceStatusEnum(invoice.status) + db.session.commit() + else: + raise ValueError(f"Could not find invoice with ID {invoice.id}") + + +async def worker(app: Flask) -> None: + current_app.logger.error("Starting worker") + with app.app_context(): + while True: + with db.session.begin() as transaction: # Start a transaction block + stripe_event = transaction.session.scalars( + db.select(StripeEvent) + .filter_by(status=StripeEventStatusEnum.PENDING) + .order_by(StripeEvent.event_created.asc()) + .with_for_update() + .limit(1) + ).one_or_none() + + if stripe_event: + stripe_event.status = StripeEventStatusEnum.IN_PROGRESS + transaction.session.add(stripe_event) + transaction.session.commit() + else: + await asyncio.sleep(2) + continue + + event_json = json.loads(stripe_event.event_data) + event = stripe.Event.construct_from(event_json, current_app.config["STRIPE_SECRET_KEY"]) + + current_app.logger.info( + f"Processing event {stripe_event.event_type} ({stripe_event.event_id})" + ) + try: + # subscription events + if event.type.startswith("customer.subscription."): + subscription: stripe.Subscription = stripe.Subscription.construct_from( + event.data.object, current_app.config["STRIPE_SECRET_KEY"] + ) + if event.type == "customer.subscription.created": + handle_subscription_created(subscription) + elif event.type == "customer.subscription.updated": + handle_subscription_updated(subscription) + elif event.type == "customer.subscription.deleted": + handle_subscription_deleted(subscription) + # invoice events + elif event.type.startswith("invoice."): + invoice: stripe.Invoice = stripe.Invoice.construct_from( + event.data.object, current_app.config["STRIPE_SECRET_KEY"] + ) + if event.type == "invoice.created": + handle_invoice_created(invoice) + elif event.type == "invoice.updated": + handle_invoice_updated(invoice) + + except Exception as e: + current_app.logger.error( + f"Error processing event {stripe_event.event_type} ({stripe_event.event_id}): {e}\n{stripe_event.event_data}", # noqa: E501 + exc_info=True, + ) + stripe_event.status = StripeEventStatusEnum.ERROR + stripe_event.error_message = str(e) + db.session.add(stripe_event) + db.session.commit() + continue + + stripe_event.status = StripeEventStatusEnum.FINISHED + db.session.add(stripe_event) + db.session.commit() + + +def create_blueprint(app: Flask) -> Blueprint: + # Now define the blueprint + bp = Blueprint("premium", __file__, url_prefix="/premium") + + @bp.route("/") + @authentication_required + def index() -> Response | str: + user = db.session.get(User, session.get("user_id")) + if not user: + session.clear() + return redirect(url_for("login")) + + # Check if we have an incomplete subscription + stripe_subscription = get_subscription(user) + if stripe_subscription and stripe_subscription["status"] == "incomplete": + flash("⚠️ Your subscription is incomplete. Please try again.", "warning") + + # Load the user's invoices + invoices = db.session.scalars( + db.select(StripeInvoice) + .filter_by(user_id=user.id) + .filter_by(status=StripeInvoiceStatusEnum.PAID) + .order_by(StripeInvoice.created_at.desc()) + ).all() + + return render_template( + "premium.html", user=user, invoices=invoices, business_price=get_business_price_string() + ) + + @bp.route("/select-tier") + @authentication_required + def select_tier() -> Response | str: + user = db.session.get(User, session.get("user_id")) + if not user: + session.clear() + return redirect(url_for("login")) + + return render_template( + "premium-select-tier.html", user=user, business_price=get_business_price_string() + ) + + @bp.route("/select-tier/free", methods=["POST"]) + @authentication_required + def select_free() -> Response | str: + user = db.session.get(User, session.get("user_id")) + if not user: + session.clear() + return redirect(url_for("login")) + + if user.tier_id is None: + user.set_free_tier() + db.session.add(user) + db.session.commit() + + return redirect(url_for("inbox")) + + @bp.route("/waiting") + @authentication_required + def waiting() -> Response | str: + return render_template("premium-waiting.html") + + @bp.route("/upgrade", methods=["POST"]) + @authentication_required + def upgrade() -> Response | str: + user = db.session.get(User, session.get("user_id")) + if not user: + session.clear() + return redirect(url_for("login")) + + # If the user is already on the business tier + if user.is_business_tier: + flash("👍 You're already upgraded.") + return redirect(url_for("premium.index")) + + # Select the business tier + business_tier = Tier.business_tier() + if not business_tier: + current_app.logger.error("Could not find business tier") + flash("⚠️ Something went wrong!") + return redirect(url_for("premium.index")) + if not business_tier.stripe_price_id: + current_app.logger.error("Business tier does not have a price ID") + flash("⚠️ Something went wrong!") + return redirect(url_for("premium.index")) + + # Make sure the user has a Stripe customer + try: + create_customer(user) + except stripe._error.StripeError as e: + current_app.logger.error(f"Failed to create Stripe customer: {e}", exc_info=True) + flash("⚠️ Something went wrong!") + return redirect(url_for("premium.index")) + + # Create a Stripe Checkout session + try: + checkout_session = stripe.checkout.Session.create( + client_reference_id=str(user.id), + customer=user.stripe_customer_id, + line_items=[{"price": business_tier.stripe_price_id, "quantity": 1}], + mode="subscription", + success_url=url_for("premium.waiting", _external=True), + automatic_tax={"enabled": True}, + customer_update={"address": "auto"}, + ) + except stripe._error.StripeError as e: + current_app.logger.error( + f"Failed to create Stripe Checkout session: {e}", exc_info=True + ) + return abort(500) + + if checkout_session.url: + return redirect(checkout_session.url) + + return abort(500) + + @bp.route("/disable-autorenew", methods=["POST"]) + @authentication_required + def disable_autorenew() -> Response | str | Tuple[Response | str, int]: + user = db.session.get(User, session.get("user_id")) + if not user: + session.clear() + return redirect(url_for("login")) + + if user.stripe_subscription_id: + try: + stripe.Subscription.modify(user.stripe_subscription_id, cancel_at_period_end=True) + except stripe._error.StripeError as e: + current_app.logger.error(f"Stripe error: {e}", exc_info=True) + return jsonify(success=False), 400 + else: + return jsonify(success=False), 400 + + user.stripe_subscription_cancel_at_period_end = True + db.session.add(user) + db.session.commit() + + current_app.logger.info( + f"Autorenew disabled for subscription {user.stripe_subscription_id} for user {user.id}" + ) + + flash("Autorenew has been disabled.") + return jsonify(success=True) + + @bp.route("/enable-autorenew", methods=["POST"]) + @authentication_required + def enable_autorenew() -> Response | str | Tuple[Response | str, int]: + user = db.session.get(User, session.get("user_id")) + if not user: + session.clear() + return redirect(url_for("login")) + + if user.stripe_subscription_id: + try: + stripe.Subscription.modify(user.stripe_subscription_id, cancel_at_period_end=False) + except stripe._error.StripeError as e: + current_app.logger.error(f"Stripe error: {e}", exc_info=True) + return jsonify(success=False), 400 + + user.stripe_subscription_cancel_at_period_end = False + db.session.add(user) + db.session.commit() + + current_app.logger.info( + f"Autorenew enabled for subscription {user.stripe_subscription_id} for user {user.id}" # noqa: E501 + ) + + flash("Autorenew has been enabled.") + return jsonify(success=True) + + return jsonify(success=False), 400 + + @bp.route("/cancel", methods=["POST"]) + @authentication_required + def cancel() -> Response | str | Tuple[Response | str, int]: + user = db.session.get(User, session.get("user_id")) + if not user: + session.clear() + return redirect(url_for("login")) + + if user.stripe_subscription_id: + try: + # Cancel the subscription + stripe.Subscription.delete(user.stripe_subscription_id) + + # Downgrade the user (the subscription ID will get removed in the webhook) + user.set_free_tier() + db.session.add(user) + db.session.commit() + + current_app.logger.info( + f"Subscription {user.stripe_subscription_id} canceled for user {user.id}" + ) + + flash("💔 Sorry to see you go!") + return jsonify(success=True) + except stripe._error.StripeError as e: + current_app.logger.error(f"Stripe error: {e}", exc_info=True) + return jsonify(success=False), 400 + + return jsonify(success=False), 400 + + @bp.route("/status.json") + @authentication_required + def status() -> Response | str: + user = db.session.get(User, session.get("user_id")) + if not user: + session.clear() + return redirect(url_for("login")) + + if user.is_business_tier: + flash("🔥 Congratulations, you've upgraded your account!") + + return jsonify({"tier_id": user.tier_id}) + + @bp.route("/webhook", methods=["POST"]) + def webhook() -> Response | str | Tuple[Response | str, int]: + sig_header = request.headers["STRIPE_SIGNATURE"] + + # Parse the event + try: + event = stripe.Webhook.construct_event( + request.data, sig_header, current_app.config.get("STRIPE_WEBHOOK_SECRET") + ) + except ValueError as e: + current_app.logger.error(f"Invalid payload: {e}") + return jsonify(success=False), 400 + except stripe._error.SignatureVerificationError as e: + current_app.logger.error(f"Error verifying webhook signature: {e}") + return jsonify(success=False), 400 + + # Have we seen this one before? + stripe_event = db.session.scalars( + db.select(StripeEvent).filter_by(event_id=event.id) + ).one_or_none() + if stripe_event: + current_app.logger.info(f"Event already seen: {event}") + return jsonify(success=True) + + # Log it + current_app.logger.info(f"Received event: {event.type}") + stripe_event = StripeEvent(event) + db.session.add(stripe_event) + db.session.commit() + + return jsonify(success=True) + + return bp diff --git a/hushline/routes.py b/hushline/routes.py index d9df24c15..d6a18fcd0 100644 --- a/hushline/routes.py +++ b/hushline/routes.py @@ -403,6 +403,12 @@ def login() -> Response | str: db.session.add(auth_log) db.session.commit() + # If premium features are enabled, prompt the user to select a tier if they haven't + if app.config["STRIPE_SECRET_KEY"]: + user = db.session.get(User, username.user_id) + if user and user.tier_id is None: + return redirect(url_for("premium.select_tier")) + return redirect(url_for("inbox")) flash("⛔️ Invalid username or password") @@ -477,6 +483,11 @@ def verify_2fa_login() -> Response | str | tuple[Response | str, int]: db.session.commit() session["is_authenticated"] = True + + # If premium features are enabled, prompt the user to select a tier if they haven't + if app.config["STRIPE_SECRET_KEY"] and user.tier_id is None: + return redirect(url_for("premium.select_tier")) + return redirect(url_for("inbox")) auth_log = AuthenticationLog(user_id=user.id, successful=False) diff --git a/hushline/settings/__init__.py b/hushline/settings/__init__.py index 20d2f8a8c..d89181512 100644 --- a/hushline/settings/__init__.py +++ b/hushline/settings/__init__.py @@ -26,7 +26,7 @@ from ..crypto import is_valid_pgp_key from ..db import db from ..forms import TwoFactorForm -from ..model import HostOrganization, Message, SMTPEncryption, User, Username +from ..model import HostOrganization, Message, SMTPEncryption, Tier, User, Username from ..utils import ( admin_authentication_required, authentication_required, @@ -283,6 +283,16 @@ async def index() -> str | Response: ) user_count = len(all_users) + # Load the business tier price + business_tier = Tier.business_tier() + business_tier_display_price = "" + if business_tier: + price_usd = business_tier.monthly_amount / 100 + if price_usd % 1 == 0: + business_tier_display_price = str(int(price_usd)) + else: + business_tier_display_price = f"{price_usd:.2f}" + # Prepopulate form fields email_forwarding_form.forwarding_enabled.data = user.email is not None if not user.pgp_key: @@ -325,6 +335,8 @@ async def index() -> str | Response: pgp_key_percentage=pgp_key_percentage, directory_visibility_form=directory_visibility_form, default_forwarding_enabled=bool(current_app.config["NOTIFICATIONS_ADDRESS"]), + # Premium-specific data + business_tier_display_price=business_tier_display_price, ) @bp.route("/toggle-2fa", methods=["POST"]) diff --git a/hushline/static/css/style.css b/hushline/static/css/style.css index 27a58d735..cf4b41ebf 100644 --- a/hushline/static/css/style.css +++ b/hushline/static/css/style.css @@ -361,6 +361,11 @@ color: white; } + .upgrade { + background-color: var(--color-brand); + color: white; + } + @media screen and (max-width: 768px) { .dropdown-content a:hover { background-color: white; @@ -386,6 +391,14 @@ .icon.verifiedURL { background-image: url("../img/app/icon-verified-lm.png"); } + + .plan { + border: var(--border); + } + + #invoice-wrapper { + border: var(--border); + } } /* Dark Mode */ @@ -675,6 +688,16 @@ color: #333; } + .upgrade .btn { + border: 1px solid #333; + color: #333; + } + + .upgrade { + background-color: var(--color-brand-dark); + color: #333; + } + @media screen and (max-width: 768px) { .dropdown-content a:hover { background-color: var(--color-dark-bg-alt); @@ -701,9 +724,17 @@ background-image: url("../img/app/icon-verified-dm.png"); } + .plan { + border: var(--border-dark); + } + .icon.chevron { background-image: url("../img/app/icon-chevron-dm.png"); } + + #invoice-wrapper { + border: var(--border-dark); + } } body { @@ -788,6 +819,10 @@ h4 { margin-top: 0; } +.centered-heading { + text-align: center; +} + p { word-break: break-word; margin: 0.5rem 0; @@ -1061,6 +1096,11 @@ header nav li { font-size: 0.875rem; } +.nav-emoji { + position: absolute; + transform: translateY(-0.325rem) translateX(-1.5rem); +} + header nav a { text-decoration: none; } @@ -1908,6 +1948,11 @@ input#captcha_answer { .admin-highlights { gap: 0.5rem; } + + .nav-emoji { + position: absolute; + transform: translateY(-0.325rem) translateX(3.75rem); + } } .toggle-ui { @@ -2139,6 +2184,90 @@ p.bio + .extra-fields { margin-top: 0.5rem; } +#plan-wrapper { + display: flex; + flex-wrap: wrap; + gap: 1rem; + margin-top: 1rem; +} + +.plan { + flex: 1; + box-sizing: border-box; + padding: 20px; + border-radius: 5px; + text-align: center; + align-items: center; + display: flex; + flex-direction: column; +} + +.plan p { + margin: 0.25rem 0; +} + +.plan .badge { + margin-bottom: 1.5rem; + margin-top: 0.625rem; +} + +.plan .plan-status { + text-align: center; +} + +.plan-recommended .plan-status { + font-family: var(--font-sans-bold); +} + +#card-element { + margin-bottom: 1rem; + padding: 1rem; + border: 1px solid #666; +} + +.upgrade { + padding: 1rem; + margin-bottom: 2rem; + border-radius: 0.5rem; + display: flex; + justify-content: space-between; + align-items: center; + gap: 1rem; +} + +.upgrade .btn { + width: fit-content; + box-sizing: border-box; + text-align: center; + margin: 0; +} + +.upgrade h3 { + font-size: var(--font-size-4); + margin-bottom: 0.25rem; +} + +.plan h3 { + margin-bottom: 0.75rem; +} + +.upgrade p { + font-size: var(--font-size-small); + margin: 0; +} + +.feature-meta { + display: flex; + flex-direction: row; + justify-content: space-between; + width: 100%; + font-size: var(--font-size-small); +} + +.feature-meta p:first-of-type { + font-family: var(--font-sans-bold); +} + @media (max-width: 640px) { .extra-fields { flex-direction: column; @@ -2148,6 +2277,18 @@ p.bio + .extra-fields { .extra-field { width: 100%; } + + #plan-wrapper { + flex-direction: column; + } +} + +@media (max-width: 480px) { + .upgrade { + flex-direction: column; + align-items: start; + gap: 0.75rem; + } } .alias-list { @@ -2226,3 +2367,51 @@ p.bio + .extra-fields { .drill-in .checkbox-group { margin-bottom: 0; } + +#invoice-wrapper { + margin-top: 1rem; + padding: 1.5rem; + border-radius: 0.25rem; +} + +#invoice-wrapper table { + text-align: left; + padding: 0; + margin: 0; + width: 100%; +} + +#invoice-wrapper table tr { + width: 100%; +} + +#invoice-wrapper table th, +#invoice-wrapper table td { + padding-top: 0.325rem; + padding-bottom: 0.325rem; +} + +.adv-form { + display: flex; + margin-bottom: 2rem; + flex-direction: column; +} + +#enable-autorenew-form, +#cancel-form { + margin-bottom: 0; + margin-top: 0; +} + +#enable-autorenew-form button, +#cancel-form button { + margin-bottom: 0; +} + +#enable-autorenew-form { + margin-top: 1rem; +} + +.plan .sub-info { + margin-top: 1.25rem; +} diff --git a/hushline/static/js/premium-waiting.js b/hushline/static/js/premium-waiting.js new file mode 100644 index 000000000..605033afd --- /dev/null +++ b/hushline/static/js/premium-waiting.js @@ -0,0 +1,18 @@ +document.addEventListener("DOMContentLoaded", async function () { + const pathPrefix = window.location.pathname.split("/").slice(0, -1).join("/"); + + // Check for payment status every 2 seconds + setInterval(() => { + fetch(`${pathPrefix}/status.json`) + .then((response) => response.json()) + .then((data) => { + if (data.tier_id === 2) { + // Payment successful, redirect to premium home + window.location.href = pathPrefix; + } else { + console.log("Payment status not yet confirmed.", data); + } + }) + .catch((error) => console.error("Failed to load payment status:", error)); + }, 2000); +}); diff --git a/hushline/static/js/premium.js b/hushline/static/js/premium.js new file mode 100644 index 000000000..63662e197 --- /dev/null +++ b/hushline/static/js/premium.js @@ -0,0 +1,86 @@ +document.addEventListener("DOMContentLoaded", async function () { + const pathPrefix = window.location.pathname.split("/").slice(0, -1).join("/"); + const disableAutorenewForm = document.querySelector( + "#disable-autorenew-form", + ); + const enableAutorenewForm = document.querySelector("#enable-autorenew-form"); + const cancelForm = document.querySelector("#cancel-form"); + + if (disableAutorenewForm) { + disableAutorenewForm.addEventListener("submit", async (e) => { + e.preventDefault(); + + // Show confirmation dialog + const confirmed = confirm( + "Are you sure you want to not renew your subscription?", + ); + if (!confirmed) return; + + // Send disable autorenew request + const response = await fetch(`${pathPrefix}/disable-autorenew`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + }); + + if (response.ok) { + window.location.reload(); // Reload the page to reflect changes + } else { + const errorData = await response.json(); + console.log("Error disabling autorenew subscription:", errorData); + alert("Error disabling autorenew. Please contact Science & Design."); + } + }); + } + + if (enableAutorenewForm) { + enableAutorenewForm.addEventListener("submit", async (e) => { + e.preventDefault(); + + // Send enable autorenew request + const response = await fetch(`${pathPrefix}/enable-autorenew`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + }); + + if (response.ok) { + window.location.reload(); // Reload the page to reflect changes + } else { + const errorData = await response.json(); + console.log("Error enabling autorenew subscription:", errorData); + alert("Error enabling autorenew. Please contact Science & Design."); + } + }); + } + + if (cancelForm) { + cancelForm.addEventListener("submit", async (e) => { + e.preventDefault(); + + // Show confirmation dialog + const confirmed = confirm( + "Are you sure you want to cancel your subscription?", + ); + if (!confirmed) return; + + // Send downgrade request + const response = await fetch(`${pathPrefix}/cancel`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + }); + + if (response.ok) { + window.location.reload(); // Reload the page to reflect changes + } else { + const errorData = await response.json(); + console.log("Error canceling subscription:", errorData); + alert("Error canceling subscription. Please contact Science & Design."); + } + }); + } +}); diff --git a/hushline/templates/base.html b/hushline/templates/base.html index 919e6f34d..19544f51f 100644 --- a/hushline/templates/base.html +++ b/hushline/templates/base.html @@ -106,6 +106,14 @@

{{ host_org.brand_app_name }}