This commit is contained in:
deathbybandaid 2022-02-09 15:58:27 -05:00
parent 1bea7a43ff
commit 4bb824c595
4 changed files with 315 additions and 7 deletions

View File

@ -2,6 +2,7 @@
from .config import Config from .config import Config
from .versions import Versions from .versions import Versions
from .logger import Logger from .logger import Logger
from .database import Database
class SpiceBotCore_OBJ(): class SpiceBotCore_OBJ():
@ -19,3 +20,14 @@ class SpiceBotCore_OBJ():
# Parse Version Information for the ENV # Parse Version Information for the ENV
self.versions = Versions(self.config, self.logger) self.versions = Versions(self.config, self.logger)
# Mimic Sopel DB, with enhancements
self.database = Database(self.config)
def setup(self, bot):
"""This runs with the plugin setup routine"""
self.bot = bot
# Re-initialize the bot config properly during plugin setup routine
self.config.config = bot.config

View File

@ -18,11 +18,13 @@ class Config():
# Load config # Load config
self.config = get_configuration(self.opts) self.config = get_configuration(self.opts)
self.setup_config() @property
def basename(self):
return os.path.basename(self.config.filename).rsplit('.', 1)[0]
def setup_config(self): @property
self.config.core.basename = os.path.basename(self.config.filename).rsplit('.', 1)[0] def prefix_list(self):
self.config.core.prefix_list = str(self.config.core.prefix).replace("\\", '').split("|") return str(self.config.core.prefix).replace("\\", '').split("|")
def define_section(self, name, cls_, validate=True): def define_section(self, name, cls_, validate=True):
return self.config.define_section(name, cls_, validate) return self.config.define_section(name, cls_, validate)

View File

@ -0,0 +1,294 @@
# coding=utf8
from __future__ import unicode_literals, absolute_import, division, print_function
import json
from sopel.tools import Identifier
from sopel.db import SopelDB, NickValues, ChannelValues, PluginValues
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.exc import SQLAlchemyError
BASE = declarative_base()
class SpiceDB(object):
# NICK FUNCTIONS
def adjust_nick_value(self, nick, key, value):
"""Sets the value for a given key to be associated with the nick."""
nick = Identifier(nick)
value = json.dumps(value, ensure_ascii=False)
nick_id = self.get_nick_id(nick)
session = self.ssession()
try:
result = session.query(NickValues) \
.filter(NickValues.nick_id == nick_id) \
.filter(NickValues.key == key) \
.one_or_none()
# NickValue exists, update
if result:
result.value = float(result.value) + float(value)
session.commit()
# DNE - Insert
else:
new_nickvalue = NickValues(nick_id=nick_id, key=key, value=float(value))
session.add(new_nickvalue)
session.commit()
except SQLAlchemyError:
session.rollback()
raise
finally:
session.close()
def adjust_nick_list(self, nick, key, entries, adjustmentdirection):
"""Sets the value for a given key to be associated with the nick."""
nick = Identifier(nick)
if not isinstance(entries, list):
entries = [entries]
entries = json.dumps(entries, ensure_ascii=False)
nick_id = self.get_nick_id(nick)
session = self.ssession()
try:
result = session.query(NickValues) \
.filter(NickValues.nick_id == nick_id) \
.filter(NickValues.key == key) \
.one_or_none()
# NickValue exists, update
if result:
if adjustmentdirection == 'add':
for entry in entries:
if entry not in result.value:
result.value.append(entry)
elif adjustmentdirection == 'del':
for entry in entries:
while entry in result.value:
result.value.remove(entry)
session.commit()
# DNE - Insert
else:
values = []
if adjustmentdirection == 'add':
for entry in entries:
if entry not in values:
values.append(entry)
elif adjustmentdirection == 'del':
for entry in entries:
while entry in values:
values.remove(entry)
new_nickvalue = NickValues(nick_id=nick_id, key=key, value=values)
session.add(new_nickvalue)
session.commit()
except SQLAlchemyError:
session.rollback()
raise
finally:
session.close()
# CHANNEL FUNCTIONS
def adjust_channel_value(self, channel, key, value):
"""Sets the value for a given key to be associated with the channel."""
channel = Identifier(channel).lower()
value = json.dumps(value, ensure_ascii=False)
session = self.ssession()
try:
result = session.query(ChannelValues) \
.filter(ChannelValues.channel == channel)\
.filter(ChannelValues.key == key) \
.one_or_none()
# ChannelValue exists, update
if result:
result.value = float(result.value) + float(value)
session.commit()
# DNE - Insert
else:
new_channelvalue = ChannelValues(channel=channel, key=key, value=float(value))
session.add(new_channelvalue)
session.commit()
except SQLAlchemyError:
session.rollback()
raise
finally:
session.close()
def adjust_channel_list(self, channel, key, entries, adjustmentdirection):
"""Sets the value for a given key to be associated with the channel."""
channel = Identifier(channel).lower()
if not isinstance(entries, list):
entries = [entries]
entries = json.dumps(entries, ensure_ascii=False)
session = self.ssession()
try:
result = session.query(ChannelValues) \
.filter(ChannelValues.channel == channel)\
.filter(ChannelValues.key == key) \
.one_or_none()
# ChannelValue exists, update
if result:
if adjustmentdirection == 'add':
for entry in entries:
if entry not in result.value:
result.value.append(entry)
elif adjustmentdirection == 'del':
for entry in entries:
while entry in result.value:
result.value.remove(entry)
session.commit()
# DNE - Insert
else:
values = []
if adjustmentdirection == 'add':
for entry in entries:
if entry not in values:
values.append(entry)
elif adjustmentdirection == 'del':
for entry in entries:
while entry in values:
values.remove(entry)
new_channelvalue = ChannelValues(channel=channel, key=key, value=values)
session.add(new_channelvalue)
session.commit()
except SQLAlchemyError:
session.rollback()
raise
finally:
session.close()
# PLUGIN FUNCTIONS
def adjust_plugin_value(self, plugin, key, value):
"""Sets the value for a given key to be associated with the plugin."""
plugin = plugin.lower()
value = json.dumps(value, ensure_ascii=False)
session = self.ssession()
try:
result = session.query(PluginValues) \
.filter(PluginValues.plugin == plugin)\
.filter(PluginValues.key == key) \
.one_or_none()
# PluginValue exists, update
if result:
result.value = float(result.value) + float(value)
session.commit()
# DNE - Insert
else:
new_pluginvalue = PluginValues(plugin=plugin, key=key, value=float(value))
session.add(new_pluginvalue)
session.commit()
except SQLAlchemyError:
session.rollback()
raise
finally:
session.close()
def adjust_plugin_list(self, plugin, key, entries, adjustmentdirection):
"""Sets the value for a given key to be associated with the plugin."""
plugin = plugin.lower()
if not isinstance(entries, list):
entries = [entries]
entries = json.dumps(entries, ensure_ascii=False)
session = self.ssession()
try:
result = session.query(PluginValues) \
.filter(PluginValues.plugin == plugin)\
.filter(PluginValues.key == key) \
.one_or_none()
# PluginValue exists, update
if result:
if adjustmentdirection == 'add':
for entry in entries:
if entry not in result.value:
result.value.append(entry)
elif adjustmentdirection == 'del':
for entry in entries:
while entry in result.value:
result.value.remove(entry)
session.commit()
# DNE - Insert
else:
values = []
if adjustmentdirection == 'add':
for entry in entries:
if entry not in values:
values.append(entry)
elif adjustmentdirection == 'del':
for entry in entries:
while entry in values:
values.remove(entry)
new_pluginvalue = PluginValues(plugin=plugin, key=key, value=values)
session.add(new_pluginvalue)
session.commit()
except SQLAlchemyError:
session.rollback()
raise
finally:
session.close()
class Database():
def __init__(self, config):
SopelDB.adjust_nick_value = SpiceDB.adjust_nick_value
SopelDB.adjust_nick_list = SpiceDB.adjust_nick_list
SopelDB.adjust_channel_value = SpiceDB.adjust_channel_value
SopelDB.adjust_channel_list = SpiceDB.adjust_channel_list
SopelDB.adjust_plugin_value = SpiceDB.adjust_plugin_value
SopelDB.adjust_plugin_list = SpiceDB.adjust_plugin_list
self.db = SopelDB(config)
BASE.metadata.create_all(self.db.engine)
@property
def botnick(self):
return self.config.core.nick
def __getattr__(self, name):
''' will only get called for undefined attributes '''
if hasattr(self.db, name):
return eval("self.db." + name)
else:
return None
"""Nick"""
def adjust_nick_value(self, nick, key, value):
return self.db.adjust_nick_value(nick, key, value)
def adjust_nick_list(self, nick, key, entries, adjustmentdirection):
return self.db.adjust_nick_list(nick, key, entries, adjustmentdirection)
"""Bot"""
def get_bot_value(self, key):
return self.db.get_nick_value(self.botnick, key)
def set_bot_value(self, key, value):
return self.db.set_nick_value(self.botnick, key, value)
def delete_bot_value(self, key):
return self.db.delete_nick_value(self.botnick, key)
def adjust_bot_value(self, key, value):
return self.db.adjust_nick_value(self.botnick, key, value)
def adjust_bot_list(self, key, entries, adjustmentdirection):
return self.db.adjust_nick_list(self.botnick, key, entries, adjustmentdirection)
"""Channels"""
def adjust_channel_value(self, channel, key, value):
return self.db.adjust_channel_value(channel, key, value)
def adjust_channel_list(self, nick, key, entries, adjustmentdirection):
return self.db.adjust_channel_list(nick, key, entries, adjustmentdirection)
"""Plugins"""
def adjust_plugin_value(self, plugin, key, value):
return self.db.adjust_plugin_value(plugin, key, value)
def adjust_plugin_list(self, plugin, key, entries, adjustmentdirection):
return self.db.adjust_plugin_list(plugin, key, entries, adjustmentdirection)

View File

@ -16,12 +16,12 @@ SCRIPT_DIR = pathlib.Path(os.path.dirname(os.path.abspath(__file__)))
sbcore = SpiceBotCore_OBJ(SCRIPT_DIR) sbcore = SpiceBotCore_OBJ(SCRIPT_DIR)
# def setup(bot): def setup(bot):
# sbcore.setup(bot) sbcore.setup(bot)
@plugin.nickname_command('test') @plugin.nickname_command('test')
def test(bot, trigger): def sb_nickname_command(bot, trigger):
bot.say("Testing the bot") bot.say("Testing the bot")
bot.say("Attributes: %s" % [x for x in dir(sbcore) if not x.startswith("__")]) bot.say("Attributes: %s" % [x for x in dir(sbcore) if not x.startswith("__")])
bot.say("%s" % sbcore.versions.dict) bot.say("%s" % sbcore.versions.dict)