-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
157 lines (111 loc) · 4.22 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
from typing import List, Optional
from urllib.parse import quote

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel
from sqlalchemy import Column, Integer, String, Float, ForeignKey, DateTime
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.future import select
from sqlalchemy.orm import sessionmaker, relationship, selectinload
# Load variables from a .env file (when one is found) into the process
# environment so the os.getenv() calls that follow can read them.
load_dotenv()
db_config = {
"user": os.getenv('DB_USER'),
"password": os.getenv('DB_PASSWORD'),
"host": os.getenv('DB_HOST'),
"database": os.getenv('DB_DATABASE'),
}
DATABASE_URL = (f"postgresql+asyncpg://"
f"{db_config['user']}:{db_config['password']}@{db_config['host']}/{db_config['database']}")
# Declarative base shared by all ORM models below.
Base = declarative_base()

# SQL echo is on by default (the previous hard-wired behaviour). Set
# DB_ECHO=0/false/no/off to stop SQLAlchemy from logging every statement and
# its bound parameters, e.g. in production.
_ECHO_SQL = os.getenv("DB_ECHO", "true").strip().lower() not in {"0", "false", "no", "off"}
engine = create_async_engine(DATABASE_URL, echo=_ECHO_SQL)

# Factory for AsyncSession objects. expire_on_commit=False keeps loaded
# attributes readable after a commit, as SQLAlchemy's asyncio docs recommend
# to avoid implicit I/O on attribute access.
async_session = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
class Customer(Base):
    """ORM mapping for the ``myapp_customer`` table."""

    __tablename__ = 'myapp_customer'
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True)
    # Indexed but not declared unique in this mapping.
    email = Column(String, index=True)
    # Orders whose customer_id references this row; paired with Order.customer.
    orders = relationship('Order', back_populates='customer')
class Product(Base):
    """ORM mapping for the ``myapp_product`` table."""

    __tablename__ = 'myapp_product'
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, index=True)
    # NOTE(review): Float is binary floating point; if this column holds money,
    # Numeric would avoid rounding error -- confirm against the real schema.
    price = Column(Float)
class Order(Base):
    """ORM mapping for the ``myapp_order`` table.

    An order references at most one Customer through ``customer_id`` and is
    linked to products through ``OrderProduct`` association rows.
    """

    __tablename__ = 'myapp_order'
    id = Column(Integer, primary_key=True, index=True)
    # Nullable as declared (no nullable=False), so an order may have no customer.
    customer_id = Column(Integer, ForeignKey('myapp_customer.id'))
    # NOTE(review): plain DateTime without timezone=True; confirm which
    # timezone convention the writer of this column uses.
    created_at = Column(DateTime)
    customer = relationship('Customer', back_populates='orders')
    # Despite its name, this yields OrderProduct association rows, not Product
    # objects; the products are reached through each row's ``.product``.
    products = relationship('OrderProduct', back_populates='order')
class OrderProduct(Base):
    """Association row linking one Order to one Product (``myapp_order_products``).

    As mapped here, the composite primary key (order_id, product_id) allows a
    given product at most once per order, and there is no quantity column.
    """

    __tablename__ = 'myapp_order_products'
    order_id = Column(Integer, ForeignKey('myapp_order.id'), primary_key=True)
    product_id = Column(Integer, ForeignKey('myapp_product.id'), primary_key=True)
    order = relationship('Order', back_populates='products')
    # One-directional: Product declares no back-reference to these rows.
    product = relationship('Product')
class CustomerSchema(BaseModel):
    """API representation of a customer: id, name and email."""

    id: int
    # NOTE(review): Customer.name/email are nullable in the ORM mapping but
    # required str here, so a NULL value would fail validation.
    name: str
    email: str

    class Config:
        # orm_mode is the Pydantic v1 spelling, from_attributes the v2 one;
        # setting both lets from_orm() build this model from ORM objects under
        # either major version (v2 warns about the old key).
        orm_mode = True
        from_attributes = True
class ProductSchema(BaseModel):
    """API representation of a product: id, name and price."""

    id: int
    # NOTE(review): Product.name/price are nullable in the ORM mapping but
    # required here, so a NULL value would fail validation.
    name: str
    price: float

    class Config:
        # v1 (orm_mode) and v2 (from_attributes) spellings of "build from ORM
        # object attributes", kept together for cross-version compatibility.
        orm_mode = True
        from_attributes = True
class OrderSchema(BaseModel):
    """API representation of an order together with its customer."""

    id: int
    # NOTE(review): Order.customer_id is nullable in the ORM mapping; an order
    # without a customer would still fail validation here.
    customer_id: int
    # ISO-8601 text (get_order calls datetime.isoformat()), or None when the
    # order row has no created_at value -- get_order passes None explicitly in
    # that case, so the field must accept it.
    created_at: Optional[str] = None
    customer: CustomerSchema

    class Config:
        # v1 (orm_mode) and v2 (from_attributes) spellings, kept together.
        orm_mode = True
        from_attributes = True
class CombinedData(BaseModel):
    """Response body of GET /api/orders/{order_id}/: an order plus its products."""

    order: OrderSchema
    products: List[ProductSchema]
app = FastAPI()


@app.on_event("startup")
async def startup():
    """Create any mapped tables that do not exist yet when the app starts.

    ``create_all`` is synchronous, so it is run via ``run_sync`` on a
    connection opened inside a transaction (``engine.begin()``).
    """
    # NOTE(review): the myapp_* table names suggest the schema may be managed
    # elsewhere; create_all only adds missing tables and never alters existing ones.
    create_tables = Base.metadata.create_all
    async with engine.begin() as connection:
        await connection.run_sync(create_tables)


@app.on_event("shutdown")
async def shutdown():
    """Dispose of the engine (and its pooled connections) when the app stops."""
    await engine.dispose()
async def get_db():
    """FastAPI dependency yielding one AsyncSession per request.

    The session is created from ``async_session`` and closed when the
    ``async with`` block exits, after the request has been handled.
    """
    session_context = async_session()
    async with session_context as request_session:
        yield request_session
@app.get("/api/orders/{order_id}/", response_model=CombinedData)
async def get_order(order_id: int, db: AsyncSession = Depends(get_db)):
    """Return one order, its customer, and the products attached to it.

    Args:
        order_id: Primary key of the ``Order`` to fetch (path parameter).
        db: Request-scoped session supplied by ``get_db``. That dependency owns
            the session's lifetime, so it is used directly rather than being
            re-entered (and closed early) here.

    Returns:
        ``CombinedData`` with the serialized order (including its customer) and
        every ``Product`` linked to the order through ``OrderProduct``.

    Raises:
        HTTPException: 404 if no order with ``order_id`` exists.
    """
    # Only the customer is eager-loaded; products come from the join below, so
    # eager-loading Order.products as well would be a wasted SELECT.
    result = await db.execute(
        select(Order)
        .options(selectinload(Order.customer))
        .where(Order.id == order_id)
    )
    order = result.scalars().first()
    if order is None:
        raise HTTPException(status_code=404, detail="Order not found")

    # Join through the association table to get Product rows directly
    # (Order.products yields OrderProduct rows, not products).
    product_result = await db.execute(
        select(Product)
        .join(OrderProduct, OrderProduct.product_id == Product.id)
        .where(OrderProduct.order_id == order_id)
    )
    products = product_result.scalars().all()

    # NOTE(review): order.customer can be None (customer_id is nullable);
    # from_orm(None) would then fail -- confirm whether that case can occur.
    order_data = OrderSchema(
        id=order.id,
        customer_id=order.customer_id,
        created_at=order.created_at.isoformat() if order.created_at else None,
        customer=CustomerSchema.from_orm(order.customer),
    )
    products_data = [ProductSchema.from_orm(product) for product in products]
    return CombinedData(order=order_data, products=products_data)
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces,
    # port 8000.
    from uvicorn import run as serve

    serve(app, port=8000, host="0.0.0.0")