90 lines
2.6 KiB
Python
90 lines
2.6 KiB
Python
from typing import Any, Dict, List, Tuple
|
|
|
|
import mysql.connector
|
|
|
|
from .data_source_db import DataSourceDb
|
|
|
|
|
|
class MySQLDataSource(DataSourceDb):
|
|
def __init__(self, mysql_connection_info: Dict[str, str]) -> None:
|
|
self.user = mysql_connection_info["user"]
|
|
self.host = mysql_connection_info["host"]
|
|
self.password = mysql_connection_info["password"]
|
|
|
|
def save_record(
|
|
self,
|
|
database: str,
|
|
tablename: str,
|
|
data: Dict[str, Any],
|
|
use_obj_connection: bool = False,
|
|
close_obj_connection: bool = True,
|
|
):
|
|
query, values = self.create_query(tablename, data)
|
|
|
|
if use_obj_connection:
|
|
try:
|
|
self.cursor.execute(query, values)
|
|
|
|
if close_obj_connection:
|
|
self.db.commit()
|
|
self.db.close()
|
|
|
|
except mysql.connector.Error:
|
|
self.db.commit()
|
|
self.db.close()
|
|
raise
|
|
|
|
else:
|
|
db = mysql.connector.connect(
|
|
user=self.user,
|
|
host=self.host,
|
|
password=self.password,
|
|
database=database,
|
|
)
|
|
try:
|
|
cursor = db.cursor()
|
|
cursor.execute(query, values)
|
|
db.commit()
|
|
db.close()
|
|
except mysql.connector.Error:
|
|
db.close()
|
|
raise
|
|
|
|
def bulk_save(self, export_data: List[Dict[str, Any]]) -> None:
|
|
for index, unit in enumerate(export_data):
|
|
database, table_name, data = unit.values()
|
|
|
|
if (
|
|
not hasattr(self, "db")
|
|
or not hasattr(self, "cursor")
|
|
or not self.db.is_connected()
|
|
):
|
|
self.db = mysql.connector.connect(
|
|
user=self.user,
|
|
host=self.host,
|
|
password=self.password,
|
|
database=database,
|
|
)
|
|
try:
|
|
self.cursor = self.db.cursor()
|
|
except:
|
|
self.db.close()
|
|
raise
|
|
|
|
self.save_record(
|
|
database, table_name, data, True, index + 1 == len(export_data)
|
|
)
|
|
|
|
def create_query(self, table_name: str, data: dict) -> Tuple[str, list]:
|
|
columns = list(data.keys())
|
|
values = list(data.values())
|
|
|
|
query = (
|
|
f"REPLACE INTO {table_name} ("
|
|
+ ", ".join(columns)
|
|
+ ") VALUES ("
|
|
+ ", ".join(["%s"] * len(columns))
|
|
+ ")"
|
|
)
|
|
return (query, values)
|